summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/dbConn.h14
-rw-r--r--lib/dbRecordSet.h19
-rw-r--r--lib/dbStmt.h66
-rw-r--r--lib/dbTypes.h9
-rw-r--r--lib/input/mysqlBindings.h178
-rw-r--r--lib/input/mysqlConn.cpp73
-rw-r--r--lib/input/mysqlConn.h3
-rw-r--r--lib/input/mysqlRecordSet.cpp99
-rw-r--r--lib/input/mysqlRecordSet.h37
-rw-r--r--lib/input/mysqlStmt.cpp35
-rw-r--r--lib/input/mysqlStmt.h29
-rw-r--r--lib/output/pq/pqBindings.h57
-rw-r--r--lib/output/pq/pqConn.cpp73
-rw-r--r--lib/output/pq/pqConn.h14
-rw-r--r--lib/output/pq/pqRecordSet.cpp86
-rw-r--r--lib/output/pq/pqRecordSet.h25
-rw-r--r--lib/output/pq/pqStmt.cpp54
-rw-r--r--lib/output/pq/pqStmt.h37
18 files changed, 784 insertions, 124 deletions
diff --git a/lib/dbConn.h b/lib/dbConn.h
index 26b9a64..e4b056c 100644
--- a/lib/dbConn.h
+++ b/lib/dbConn.h
@@ -1,13 +1,27 @@
#ifndef MYGRATE_DBCONN_H
#define MYGRATE_DBCONN_H
+#include <dbRecordSet.h>
#include <dbTypes.h>
#include <initializer_list>
namespace MyGrate {
+ class DbPrepStmt {
+ public:
+ virtual ~DbPrepStmt() = default;
+ virtual void execute(const std::initializer_list<DbValue> &) = 0;
+ virtual std::size_t rows() const = 0;
+ virtual RecordSetPtr recordSet() = 0;
+ };
+ using DbPrepStmtPtr = std::unique_ptr<DbPrepStmt>;
+
class DbConn {
+ public:
+ virtual ~DbConn() = default;
virtual void query(const char * const) = 0;
virtual void query(const char * const, const std::initializer_list<DbValue> &) = 0;
+
+ virtual DbPrepStmtPtr prepare(const char * const, std::size_t nParams) = 0;
};
}
diff --git a/lib/dbRecordSet.h b/lib/dbRecordSet.h
new file mode 100644
index 0000000..9bddc01
--- /dev/null
+++ b/lib/dbRecordSet.h
@@ -0,0 +1,19 @@
+#ifndef MYGRATE_DBRECORDSET_H
+#define MYGRATE_DBRECORDSET_H
+
+#include <dbTypes.h>
+#include <memory>
+
+namespace MyGrate {
+ class RecordSet {
+ public:
+ virtual ~RecordSet() = default;
+
+ virtual std::size_t rows() const = 0;
+ virtual std::size_t columns() const = 0;
+ virtual DbValue at(std::size_t, std::size_t) const = 0;
+ };
+ using RecordSetPtr = std::unique_ptr<RecordSet>;
+}
+
+#endif
diff --git a/lib/dbStmt.h b/lib/dbStmt.h
new file mode 100644
index 0000000..3e98b34
--- /dev/null
+++ b/lib/dbStmt.h
@@ -0,0 +1,66 @@
+#ifndef MYGRATE_DBSTMT_H
+#define MYGRATE_DBSTMT_H
+
+#include <compileTimeFormatter.h>
+#include <dbConn.h>
+#include <dbRecordSet.h>
+#include <memory>
+#include <string_view>
+#include <type_traits>
+
+namespace MyGrate {
+ class DbConn;
+ enum class ParamMode { None, DollarNum, QMark };
+
+ template<AdHoc::support::basic_fixed_string S, ParamMode pm = ParamMode::None> class DbStmt {
+ public:
+ // This don't account for common table expressions, hopefully won't need those :)
+ static constexpr auto isSelect {S.v().starts_with("SELECT") || S.v().starts_with("SHOW")
+ || S.v().find("RETURNING") != std::string_view::npos};
+
+ // These don't account for string literals, which we'd prefer to avoid anyway :)
+ static constexpr auto paramCount {[]() -> std::size_t {
+ switch (pm) {
+ case ParamMode::None:
+ return 0LU;
+ case ParamMode::DollarNum: {
+ const auto pn = [](const char * c, const char * const e) {
+ std::size_t n {0};
+ while (++c != e && *c >= '0' && *c <= '9') {
+ n = (n * 10) + (*c - '0');
+ }
+ return n;
+ };
+ return pn(std::max_element(S.v().begin(), S.v().end(),
+ [pn, e = S.v().end()](const char & a, const char & b) {
+ return (a == '$' ? pn(&a, e) : 0) < (b == '$' ? pn(&b, e) : 0);
+ }),
+ S.v().end());
+ }
+ case ParamMode::QMark:
+ return std::count_if(S.v().begin(), S.v().end(), [](char c) {
+ return c == '?';
+ });
+ }
+ }()};
+
+ using Return = std::conditional_t<isSelect, RecordSetPtr, std::size_t>;
+
+ template<typename... P>
+ static Return
+ execute(DbConn * c, P &&... p)
+ {
+ static_assert(sizeof...(P) == paramCount);
+ auto stmt {c->prepare(S, sizeof...(P))};
+ stmt->execute({std::forward<P...>(p)...});
+ if constexpr (isSelect) {
+ return stmt->recordSet();
+ }
+ else {
+ return stmt->rows();
+ }
+ }
+ };
+}
+
+#endif
diff --git a/lib/dbTypes.h b/lib/dbTypes.h
index ba0cd70..b35b036 100644
--- a/lib/dbTypes.h
+++ b/lib/dbTypes.h
@@ -11,16 +11,25 @@ struct timespec;
namespace MyGrate {
struct Date {
+ inline Date() { }
+ inline Date(uint16_t y, uint8_t m, uint8_t d) : year {y}, month {m}, day {d} { }
+ explicit inline Date(const tm & tm) : Date(tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday) { }
uint16_t year;
uint8_t month;
uint8_t day;
};
struct Time {
+ inline Time() { }
+ inline Time(uint8_t h, uint8_t m, uint8_t s) : hour {h}, minute {m}, second {s} { }
+ explicit inline Time(const tm & tm) : Time(tm.tm_hour, tm.tm_min, tm.tm_sec) { }
uint8_t hour;
uint8_t minute;
uint8_t second;
};
struct DateTime : public Date, public Time {
+ inline DateTime() { }
+ inline DateTime(const Date & d, const Time & t) : Date {d}, Time {t} { }
+ explicit inline DateTime(const tm & tm) : Date {tm}, Time {tm} { }
};
using Blob = std::span<const std::byte>;
diff --git a/lib/input/mysqlBindings.h b/lib/input/mysqlBindings.h
new file mode 100644
index 0000000..dcb3ebf
--- /dev/null
+++ b/lib/input/mysqlBindings.h
@@ -0,0 +1,178 @@
+#ifndef MYGRATE_INPUT_MYSQLBINDINGS_H
+#define MYGRATE_INPUT_MYSQLBINDINGS_H
+
+#include <dbTypes.h>
+#include <helpers.h>
+#include <initializer_list>
+#include <mysql.h>
+#include <mysql_types.h>
+#include <variant>
+#include <vector>
+
+namespace MyGrate::Input {
+ struct BingingData {
+ explicit BingingData(unsigned long l, my_bool n = 0) : len {l}, null {n} { }
+ unsigned long len;
+ my_bool null;
+ };
+
+ struct Bindings {
+ // NOLINTNEXTLINE(hicpp-explicit-conversions)
+ explicit Bindings(const std::initializer_list<DbValue> & vs)
+ {
+ binds.reserve(vs.size());
+ data.reserve(vs.size());
+ for (const auto & v : vs) {
+ std::visit(*this, v);
+ }
+ }
+ template<std::integral T>
+ void
+ operator()(const T & v)
+ {
+ auto & b = binds.emplace_back();
+ b.buffer_type = MySQL::CType<T>::type;
+ b.buffer = const_cast<T *>(&v);
+ b.is_unsigned = std::unsigned_integral<T>;
+ }
+ template<std::floating_point T>
+ void
+ operator()(const T & v)
+ {
+ auto & b = binds.emplace_back();
+ b.buffer_type = MySQL::CType<T>::type;
+ b.buffer = const_cast<T *>(&v);
+ }
+ template<Viewable T>
+ void
+ operator()(const T & v)
+ {
+ auto & b = binds.emplace_back();
+ b.buffer_type = MySQL::CType<T>::type;
+ b.buffer = const_cast<typename T::value_type *>(v.data());
+ b.length = &data.emplace_back(v.size(), 0).len;
+ }
+ void
+ operator()(const std::nullptr_t &)
+ {
+ auto & b = binds.emplace_back();
+ b.buffer = nullptr;
+ b.is_null = &data.emplace_back(0, 1).null;
+ }
+ template<typename T>
+ void
+ operator()(const T &)
+ {
+ throw std::runtime_error("Not implemented");
+ }
+ std::vector<MYSQL_BIND> binds;
+ std::vector<BingingData> data;
+ };
+
+ class ResultData : public BingingData {
+ public:
+ ResultData() : BingingData {0} { }
+ virtual ~ResultData() = default;
+
+ [[nodiscard]] virtual DbValue getValue() const = 0;
+ };
+
+ template<typename T> class ResultDataT : public ResultData {
+ public:
+ ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f)
+ {
+ b.buffer = &buf;
+ b.buffer_length = sizeof(T);
+ b.is_null = &this->null;
+ b.length = &this->len;
+ b.is_unsigned = std::is_unsigned_v<T>;
+ b.buffer_type = f.type;
+ }
+
+ [[nodiscard]] DbValue
+ getValue() const override
+ {
+ return buf;
+ }
+
+ private:
+ T buf {};
+ };
+
+ template<> class ResultDataT<std::string_view> : public ResultData {
+ public:
+ ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f) : buf(f.length)
+ {
+ b.buffer_length = buf.size();
+ b.buffer = buf.data();
+ b.is_null = &this->null;
+ b.length = &this->len;
+ b.buffer_type = f.type;
+ }
+
+ [[nodiscard]] DbValue
+ getValue() const override
+ {
+ return std::string_view {buf.data(), this->len};
+ }
+
+ private:
+ std::vector<char> buf;
+ };
+
+ template<> class ResultDataT<Blob> : public ResultData {
+ public:
+ ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f) : buf(f.length)
+ {
+ b.buffer_length = buf.size();
+ b.buffer = buf.data();
+ b.is_null = &this->null;
+ b.length = &this->len;
+ b.buffer_type = f.type;
+ }
+
+ [[nodiscard]] DbValue
+ getValue() const override
+ {
+ return Blob {buf.data(), this->len};
+ }
+
+ private:
+ std::vector<std::byte> buf;
+ };
+
+ template<typename Out> class ResultDataTime : public ResultData {
+ public:
+ ResultDataTime(MYSQL_BIND & b, const MYSQL_FIELD & f)
+ {
+ b.buffer_length = sizeof(MYSQL_TIME);
+ b.buffer = &buf;
+ b.is_null = &this->null;
+ b.length = &this->len;
+ b.buffer_type = f.type;
+ }
+
+ [[nodiscard]] DbValue
+ getValue() const override
+ {
+ return Out {*this};
+ }
+
+ private:
+ operator Date() const
+ {
+ return Date(buf.year, buf.month, buf.day);
+ }
+ operator Time() const
+ {
+ return Time(buf.hour, buf.minute, buf.second);
+ }
+ operator DateTime() const
+ {
+ return DateTime(*this, *this);
+ }
+ MYSQL_TIME buf;
+ };
+}
+
+#endif
diff --git a/lib/input/mysqlConn.cpp b/lib/input/mysqlConn.cpp
index 179f9d5..46ed7c6 100644
--- a/lib/input/mysqlConn.cpp
+++ b/lib/input/mysqlConn.cpp
@@ -1,18 +1,17 @@
#include "mysqlConn.h"
-#include "helpers.h"
+#include "mysqlBindings.h"
+#include "mysqlStmt.h"
#include <cstddef>
#include <cstring>
+#include <dbConn.h>
#include <dbTypes.h>
+#include <helpers.h>
#include <memory>
#include <mysql.h>
-#include <mysql_types.h>
#include <stdexcept>
-#include <variant>
#include <vector>
namespace MyGrate::Input {
- using StmtPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_close)>;
-
MySQLConn::MySQLConn(
const char * const host, const char * const user, const char * const pass, unsigned short port) :
st_mysql {}
@@ -36,64 +35,6 @@ namespace MyGrate::Input {
verify<std::runtime_error>(!mysql_query(this, q), q);
}
- struct Bindings {
- // NOLINTNEXTLINE(hicpp-explicit-conversions)
- explicit Bindings(const std::initializer_list<DbValue> & vs)
- {
- binds.reserve(vs.size());
- extras.reserve(vs.size());
- for (const auto & v : vs) {
- std::visit(*this, v);
- }
- }
- template<std::integral T>
- void
- operator()(const T & v)
- {
- auto & b = binds.emplace_back();
- b.buffer_type = MySQL::CType<T>::type;
- b.buffer = const_cast<T *>(&v);
- b.is_unsigned = std::unsigned_integral<T>;
- }
- template<std::floating_point T>
- void
- operator()(const T & v)
- {
- auto & b = binds.emplace_back();
- b.buffer_type = MySQL::CType<T>::type;
- b.buffer = const_cast<T *>(&v);
- }
- template<Viewable T>
- void
- operator()(const T & v)
- {
- auto & b = binds.emplace_back();
- b.buffer_type = MySQL::CType<T>::type;
- b.buffer = const_cast<typename T::value_type *>(v.data());
- b.length = &extras.emplace_back(v.size(), 0).len;
- }
- void
- operator()(const std::nullptr_t &)
- {
- auto & b = binds.emplace_back();
- b.buffer = nullptr;
- b.is_null = &extras.emplace_back(0, 1).null;
- }
- template<typename T>
- void
- operator()(const T &)
- {
- throw std::runtime_error("Not implemented");
- }
- struct extra {
- explicit extra(unsigned long l, my_bool n = 0) : len {l}, null {n} { }
- unsigned long len;
- my_bool null;
- };
- std::vector<MYSQL_BIND> binds;
- std::vector<extra> extras;
- };
-
void
MySQLConn::query(const char * const q, const std::initializer_list<DbValue> & vs)
{
@@ -104,4 +45,10 @@ namespace MyGrate::Input {
verify<std::runtime_error>(!mysql_stmt_bind_param(stmt.get(), b.binds.data()), "Param count mismatch");
verify<std::runtime_error>(!mysql_stmt_execute(stmt.get()), q);
}
+
+ DbPrepStmtPtr
+ MySQLConn::prepare(const char * const q, std::size_t)
+ {
+ return std::make_unique<MySQLPrepStmt>(q, this);
+ }
}
diff --git a/lib/input/mysqlConn.h b/lib/input/mysqlConn.h
index 9e4ec25..2f71262 100644
--- a/lib/input/mysqlConn.h
+++ b/lib/input/mysqlConn.h
@@ -1,6 +1,7 @@
#ifndef MYGRATE_INPUT_MYSQLCONN_H
#define MYGRATE_INPUT_MYSQLCONN_H
+#include <cstddef>
#include <dbConn.h>
#include <dbTypes.h>
#include <initializer_list>
@@ -14,6 +15,8 @@ namespace MyGrate::Input {
void query(const char * const) override;
void query(const char * const q, const std::initializer_list<DbValue> &) override;
+
+ DbPrepStmtPtr prepare(const char * const, std::size_t) override;
};
}
diff --git a/lib/input/mysqlRecordSet.cpp b/lib/input/mysqlRecordSet.cpp
new file mode 100644
index 0000000..51be15f
--- /dev/null
+++ b/lib/input/mysqlRecordSet.cpp
@@ -0,0 +1,99 @@
+#include "mysqlRecordSet.h"
+#include "mysqlBindings.h"
+#include "mysqlStmt.h"
+#include <cstdint>
+#include <dbTypes.h>
+#include <helpers.h>
+#include <stdexcept>
+#include <string_view>
+#include <utility>
+// IWYU pragma: no_include <ext/alloc_traits.h>
+
+namespace MyGrate::Input {
+ MySQLRecordSet::MySQLRecordSet(StmtPtr s) :
+ stmt {std::move(s)}, stmtres {nullptr, nullptr}, fields(mysql_stmt_field_count(stmt.get())),
+ extras(fields.size())
+ {
+ auto getBind = [](const MYSQL_FIELD & f, MYSQL_BIND & b) -> std::unique_ptr<ResultData> {
+ switch (f.type) {
+ case MYSQL_TYPE_DECIMAL:
+ case MYSQL_TYPE_NEWDECIMAL:
+ case MYSQL_TYPE_DOUBLE:
+ return std::make_unique<ResultDataT<double>>(b, f);
+ case MYSQL_TYPE_FLOAT:
+ return std::make_unique<ResultDataT<float>>(b, f);
+ case MYSQL_TYPE_TINY:
+ return std::make_unique<ResultDataT<int8_t>>(b, f);
+ case MYSQL_TYPE_SHORT:
+ case MYSQL_TYPE_YEAR:
+ return std::make_unique<ResultDataT<int16_t>>(b, f);
+ case MYSQL_TYPE_LONG:
+ case MYSQL_TYPE_INT24:
+ return std::make_unique<ResultDataT<int32_t>>(b, f);
+ case MYSQL_TYPE_LONGLONG:
+ return std::make_unique<ResultDataT<int64_t>>(b, f);
+ case MYSQL_TYPE_NULL:
+ return std::make_unique<ResultDataT<std::nullptr_t>>(b, f);
+ case MYSQL_TYPE_TIMESTAMP:
+ case MYSQL_TYPE_TIMESTAMP2:
+ case MYSQL_TYPE_DATETIME:
+ case MYSQL_TYPE_DATETIME2:
+ return std::make_unique<ResultDataTime<DateTime>>(b, f);
+ case MYSQL_TYPE_TIME:
+ case MYSQL_TYPE_TIME2:
+ return std::make_unique<ResultDataTime<Time>>(b, f);
+ case MYSQL_TYPE_DATE:
+ case MYSQL_TYPE_NEWDATE:
+ return std::make_unique<ResultDataTime<Date>>(b, f);
+ case MYSQL_TYPE_VARCHAR:
+ case MYSQL_TYPE_VAR_STRING:
+ case MYSQL_TYPE_STRING:
+ case MYSQL_TYPE_JSON:
+ case MYSQL_TYPE_ENUM:
+ return std::make_unique<ResultDataT<std::string_view>>(b, f);
+ case MYSQL_TYPE_TINY_BLOB:
+ case MYSQL_TYPE_MEDIUM_BLOB:
+ case MYSQL_TYPE_LONG_BLOB:
+ case MYSQL_TYPE_BLOB:
+ return std::make_unique<ResultDataT<Blob>>(b, f);
+ case MAX_NO_FIELD_TYPES:
+ case MYSQL_TYPE_BIT:
+ case MYSQL_TYPE_SET:
+ case MYSQL_TYPE_GEOMETRY:;
+ }
+ throw std::runtime_error("Unsupported column type");
+ };
+ ResPtr meta {mysql_stmt_result_metadata(stmt.get()), mysql_free_result};
+ const auto fieldDefs = mysql_fetch_fields(meta.get());
+ verify<std::runtime_error>(fieldDefs, "Fetch fields");
+ for (std::size_t i = 0; i < fields.size(); i += 1) {
+ extras[i] = getBind(fieldDefs[i], fields[i]);
+ }
+ verify<std::runtime_error>(!mysql_stmt_bind_result(stmt.get(), fields.data()), "Store result error");
+ verify<std::runtime_error>(!mysql_stmt_store_result(stmt.get()), "Store result error");
+ stmtres = {stmt.get(), mysql_stmt_free_result};
+ verify<std::runtime_error>(!mysql_stmt_fetch(stmt.get()), "Fetch");
+ }
+
+ std::size_t
+ MySQLRecordSet::rows() const
+ {
+ return mysql_stmt_num_rows(stmt.get());
+ }
+
+ std::size_t
+ MySQLRecordSet::columns() const
+ {
+ return fields.size();
+ }
+
+ DbValue
+ MySQLRecordSet::at(std::size_t row, std::size_t col) const
+ {
+ mysql_stmt_data_seek(stmt.get(), row);
+ if (extras[col]->null) {
+ return nullptr;
+ }
+ return extras[col]->getValue();
+ }
+}
diff --git a/lib/input/mysqlRecordSet.h b/lib/input/mysqlRecordSet.h
new file mode 100644
index 0000000..849a653
--- /dev/null
+++ b/lib/input/mysqlRecordSet.h
@@ -0,0 +1,37 @@
+#ifndef MYGRATE_INPUT_MYSQLRECORDSET_H
+#define MYGRATE_INPUT_MYSQLRECORDSET_H
+
+#include "mysqlStmt.h"
+#include <cstddef>
+#include <dbRecordSet.h>
+#include <dbTypes.h>
+#include <memory>
+#include <mysql.h>
+#include <vector>
+
+namespace MyGrate::Input {
+ class ResultData;
+
+ class MySQLRecordSet : public RecordSet {
+ public:
+ using ResPtr = std::unique_ptr<MYSQL_RES, decltype(&mysql_free_result)>;
+ using StmtResPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_free_result)>;
+ using ResultDataPtr = std::unique_ptr<ResultData>;
+
+ explicit MySQLRecordSet(StmtPtr s);
+
+ std::size_t rows() const override;
+
+ std::size_t columns() const override;
+
+ DbValue at(std::size_t row, std::size_t col) const override;
+
+ private:
+ StmtPtr stmt;
+ StmtResPtr stmtres;
+ std::vector<MYSQL_BIND> fields;
+ std::vector<ResultDataPtr> extras;
+ };
+}
+
+#endif
diff --git a/lib/input/mysqlStmt.cpp b/lib/input/mysqlStmt.cpp
new file mode 100644
index 0000000..08d1303
--- /dev/null
+++ b/lib/input/mysqlStmt.cpp
@@ -0,0 +1,35 @@
+#include "mysqlStmt.h"
+#include "mysqlBindings.h"
+#include "mysqlRecordSet.h"
+#include <cstring>
+#include <helpers.h>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace MyGrate::Input {
+ MySQLPrepStmt::MySQLPrepStmt(const char * const q, MYSQL * c) : stmt {mysql_stmt_init(c), &mysql_stmt_close}
+ {
+ verify<std::runtime_error>(!mysql_stmt_prepare(stmt.get(), q, strlen(q)), q);
+ }
+
+ void
+ MySQLPrepStmt::execute(const std::initializer_list<DbValue> & vs)
+ {
+ Bindings b {vs};
+ verify<std::runtime_error>(!mysql_stmt_bind_param(stmt.get(), b.binds.data()), "Param count mismatch");
+ verify<std::runtime_error>(!mysql_stmt_execute(stmt.get()), "Prepared statement execute");
+ }
+
+ std::size_t
+ MySQLPrepStmt::rows() const
+ {
+ return mysql_stmt_affected_rows(stmt.get());
+ }
+
+ RecordSetPtr
+ MySQLPrepStmt::recordSet()
+ {
+ return std::make_unique<MySQLRecordSet>(std::move(stmt));
+ }
+}
diff --git a/lib/input/mysqlStmt.h b/lib/input/mysqlStmt.h
new file mode 100644
index 0000000..a42e0db
--- /dev/null
+++ b/lib/input/mysqlStmt.h
@@ -0,0 +1,29 @@
+#ifndef MYGRATE_INPUT_MYSQLSTMT_H
+#define MYGRATE_INPUT_MYSQLSTMT_H
+
+#include "dbConn.h"
+#include "dbRecordSet.h"
+#include "dbTypes.h"
+#include <cstddef>
+#include <initializer_list>
+#include <memory>
+#include <mysql.h>
+
+namespace MyGrate::Input {
+ using StmtPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_close)>;
+
+ class MySQLPrepStmt : public DbPrepStmt {
+ public:
+ MySQLPrepStmt(const char * const q, MYSQL * c);
+ void execute(const std::initializer_list<DbValue> & vs) override;
+
+ std::size_t rows() const override;
+
+ RecordSetPtr recordSet() override;
+
+ private:
+ StmtPtr stmt;
+ };
+}
+
+#endif
diff --git a/lib/output/pq/pqBindings.h b/lib/output/pq/pqBindings.h
new file mode 100644
index 0000000..ef0df84
--- /dev/null
+++ b/lib/output/pq/pqBindings.h
@@ -0,0 +1,57 @@
+#ifndef MYGRATE_OUTPUT_PQ_PQBINDINGS
+#define MYGRATE_OUTPUT_PQ_PQBINDINGS
+
+#include <dbTypes.h>
+#include <helpers.h>
+#include <initializer_list>
+#include <variant>
+#include <vector>
+
+namespace MyGrate::Output::Pq {
+ struct Bindings {
+ // NOLINTNEXTLINE(hicpp-explicit-conversions)
+ explicit Bindings(const std::initializer_list<DbValue> & vs)
+ {
+ bufs.reserve(vs.size());
+ values.reserve(vs.size());
+ lengths.reserve(vs.size());
+ for (const auto & v : vs) {
+ std::visit(*this, v);
+ }
+ }
+ template<Stringable T>
+ void
+ operator()(const T & v)
+ {
+ bufs.emplace_back(std::to_string(v));
+ const auto & vw {bufs.back()};
+ values.emplace_back(vw.data());
+ lengths.emplace_back(vw.length());
+ }
+ template<Viewable T>
+ void
+ operator()(const T & v)
+ {
+ values.emplace_back(v.data());
+ lengths.emplace_back(v.size());
+ }
+ template<typename T>
+ void
+ operator()(const T &)
+ {
+ throw std::runtime_error("Not implemented");
+ }
+ void
+ operator()(const std::nullptr_t &)
+ {
+ values.emplace_back(nullptr);
+ lengths.emplace_back(0);
+ }
+
+ std::vector<std::string> bufs;
+ std::vector<const char *> values;
+ std::vector<int> lengths;
+ };
+}
+
+#endif
diff --git a/lib/output/pq/pqConn.cpp b/lib/output/pq/pqConn.cpp
index 4f55ba8..81d9610 100644
--- a/lib/output/pq/pqConn.cpp
+++ b/lib/output/pq/pqConn.cpp
@@ -1,89 +1,44 @@
#include "pqConn.h"
+#include "pqBindings.h"
+#include "pqStmt.h"
+#include <dbConn.h>
#include <dbTypes.h>
#include <helpers.h>
#include <libpq-fe.h>
#include <memory>
-#include <sstream>
#include <stdexcept>
#include <string>
-#include <variant>
#include <vector>
namespace MyGrate::Output::Pq {
- using ResPtr = std::unique_ptr<PGresult, decltype(&PQclear)>;
-
- PqConn::PqConn(const char * const str) : conn {PQconnectdb(str)}
+ PqConn::PqConn(const char * const str) : conn {PQconnectdb(str), PQfinish}
{
- verify<std::runtime_error>(PQstatus(conn) == CONNECTION_OK, "Connection failure");
- PQsetNoticeProcessor(conn, notice_processor, this);
- }
-
- PqConn::~PqConn()
- {
- PQfinish(conn);
+ verify<std::runtime_error>(PQstatus(conn.get()) == CONNECTION_OK, "Connection failure");
+ PQsetNoticeProcessor(conn.get(), notice_processor, this);
}
void
PqConn::query(const char * const q)
{
- ResPtr res {PQexec(conn, q), &PQclear};
+ ResPtr res {PQexec(conn.get(), q), &PQclear};
verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q);
}
- struct Bindings {
- // NOLINTNEXTLINE(hicpp-explicit-conversions)
- explicit Bindings(const std::initializer_list<DbValue> & vs)
- {
- bufs.reserve(vs.size());
- values.reserve(vs.size());
- lengths.reserve(vs.size());
- for (const auto & v : vs) {
- std::visit(*this, v);
- }
- }
- template<Stringable T>
- void
- operator()(const T & v)
- {
- bufs.emplace_back(std::to_string(v));
- const auto & vw {bufs.back()};
- values.emplace_back(vw.data());
- lengths.emplace_back(vw.length());
- }
- template<Viewable T>
- void
- operator()(const T & v)
- {
- values.emplace_back(v.data());
- lengths.emplace_back(v.size());
- }
- template<typename T>
- void
- operator()(const T &)
- {
- throw std::runtime_error("Not implemented");
- }
- void
- operator()(const std::nullptr_t &)
- {
- values.emplace_back(nullptr);
- lengths.emplace_back(0);
- }
-
- std::vector<std::string> bufs;
- std::vector<const char *> values;
- std::vector<int> lengths;
- };
-
void
PqConn::query(const char * const q, const std::initializer_list<DbValue> & vs)
{
Bindings b {vs};
- ResPtr res {PQexecParams(conn, q, (int)vs.size(), nullptr, b.values.data(), b.lengths.data(), nullptr, 0),
+ ResPtr res {PQexecParams(conn.get(), q, (int)vs.size(), nullptr, b.values.data(), b.lengths.data(), nullptr, 0),
&PQclear};
verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q);
}
+ DbPrepStmtPtr
+ PqConn::prepare(const char * const q, std::size_t n)
+ {
+ return std::make_unique<PqPrepStmt>(q, n, this);
+ }
+
void
PqConn::notice_processor(void * p, const char * n)
{
diff --git a/lib/output/pq/pqConn.h b/lib/output/pq/pqConn.h
index 613af6f..856683d 100644
--- a/lib/output/pq/pqConn.h
+++ b/lib/output/pq/pqConn.h
@@ -1,25 +1,35 @@
#ifndef MYGRATE_OUTPUT_PQ_PQCONN_H
#define MYGRATE_OUTPUT_PQ_PQCONN_H
+#include <cstddef>
#include <dbConn.h>
#include <dbTypes.h>
+#include <functional>
#include <initializer_list>
#include <libpq-fe.h>
+#include <map>
+#include <memory>
+#include <string>
namespace MyGrate::Output::Pq {
class PqConn : public DbConn {
public:
explicit PqConn(const char * const str);
- virtual ~PqConn();
+ virtual ~PqConn() = default;
void query(const char * const) override;
void query(const char * const, const std::initializer_list<DbValue> &) override;
+ DbPrepStmtPtr prepare(const char * const, std::size_t nParams) override;
+
private:
static void notice_processor(void *, const char *);
virtual void notice_processor(const char *) const;
- PGconn * const conn;
+ std::unique_ptr<PGconn, decltype(&PQfinish)> const conn;
+
+ friend class PqPrepStmt;
+ std::map<std::string, std::string, std::less<>> stmts;
};
}
diff --git a/lib/output/pq/pqRecordSet.cpp b/lib/output/pq/pqRecordSet.cpp
new file mode 100644
index 0000000..71ddee4
--- /dev/null
+++ b/lib/output/pq/pqRecordSet.cpp
@@ -0,0 +1,86 @@
+#include "pqRecordSet.h"
+#include "dbTypes.h"
+#include "pqStmt.h"
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <helpers.h>
+#include <libpq-fe.h>
+#include <server/catalog/pg_type_d.h>
+#include <stdexcept>
+#include <string_view>
+#include <utility>
+
+namespace MyGrate::Output::Pq {
+ PqRecordSet::PqRecordSet(ResPtr r) : res {std::move(r)} { }
+
+ std::size_t
+ PqRecordSet::rows() const
+ {
+ return PQntuples(res.get());
+ }
+
+ std::size_t
+ PqRecordSet::columns() const
+ {
+ return PQnfields(res.get());
+ }
+
+ DbValue
+ PqRecordSet::at(std::size_t row, std::size_t col) const
+ {
+ if (PQgetisnull(res.get(), (int)row, (int)col)) {
+ return nullptr;
+ }
+ const auto value {PQgetvalue(res.get(), (int)row, (int)col)};
+ const auto size {static_cast<size_t>(PQgetlength(res.get(), (int)row, (int)col))};
+ const auto type {PQftype(res.get(), (int)col)};
+ switch (type) {
+ // case BITOID: TODO bool
+ // case BOOLOID: TODO bool
+ // case BOOLARRAYOID:
+ case VARBITOID:
+ case BYTEAOID:
+ // This is wrong :)
+ return Blob {reinterpret_cast<const std::byte *>(value), size};
+ case INT2OID:
+ return static_cast<int16_t>(std::strtol(value, nullptr, 10));
+ case INT4OID:
+ return static_cast<int32_t>(std::strtol(value, nullptr, 10));
+ case INT8OID:
+ return std::strtol(value, nullptr, 10);
+ case FLOAT4OID:
+ return std::strtof(value, nullptr);
+ case FLOAT8OID:
+ case CASHOID:
+ case NUMERICOID:
+ return std::strtod(value, nullptr);
+ case DATEOID: {
+ tm tm {};
+ const auto end = strptime(value, "%F", &tm);
+ verify<std::runtime_error>(end && !*end, "Invalid date string");
+ return Date {tm};
+ }
+ case TIMEOID: {
+ tm tm {};
+ const auto end = strptime(value, "%T", &tm);
+ verify<std::runtime_error>(end && !*end, "Invalid time string");
+ return Time {tm};
+ }
+ case TIMESTAMPOID: {
+ tm tm {};
+ const auto end = strptime(value, "%FT%T", &tm);
+ verify<std::runtime_error>(end && !*end, "Invalid timestamp string");
+ return DateTime {tm};
+ }
+ // case TIMESTAMPTZOID: Maybe add TZ support?
+ // case INTERVALOID: Maybe add interval support?
+ // case TIMETZOID: Maybe add TZ support?
+ case VOIDOID:
+ return nullptr;
+ default:
+ return std::string_view {value, size};
+ }
+ }
+}
diff --git a/lib/output/pq/pqRecordSet.h b/lib/output/pq/pqRecordSet.h
new file mode 100644
index 0000000..2934d84
--- /dev/null
+++ b/lib/output/pq/pqRecordSet.h
@@ -0,0 +1,25 @@
+#ifndef MYGRATE_OUTPUT_PQ_PQRECORDSET_H
+#define MYGRATE_OUTPUT_PQ_PQRECORDSET_H
+
+#include "dbRecordSet.h"
+#include "dbTypes.h"
+#include "pqStmt.h"
+#include <cstddef>
+
+namespace MyGrate::Output::Pq {
+ class PqRecordSet : public RecordSet {
+ public:
+ explicit PqRecordSet(ResPtr r);
+
+ std::size_t rows() const override;
+
+ std::size_t columns() const override;
+
+ DbValue at(std::size_t row, std::size_t col) const override;
+
+ private:
+ ResPtr res;
+ };
+}
+
+#endif
diff --git a/lib/output/pq/pqStmt.cpp b/lib/output/pq/pqStmt.cpp
new file mode 100644
index 0000000..04b48c6
--- /dev/null
+++ b/lib/output/pq/pqStmt.cpp
@@ -0,0 +1,54 @@
+#include "pqStmt.h"
+#include "libpq-fe.h"
+#include "pqBindings.h"
+#include "pqConn.h"
+#include "pqRecordSet.h"
+#include <compileTimeFormatter.h>
+#include <cstdlib>
+#include <functional>
+#include <helpers.h>
+#include <map>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace MyGrate::Output::Pq {
+ PqPrepStmt::PqPrepStmt(const char * const q, std::size_t n, PqConn * c) :
+ conn {c->conn.get()}, name {prepareAsNeeded(q, n, c)}, res {nullptr, nullptr}
+ {
+ }
+
+ void
+ PqPrepStmt::execute(const std::initializer_list<DbValue> & vs)
+ {
+ Bindings b {vs};
+ res = {PQexecPrepared(conn, name.c_str(), (int)vs.size(), b.values.data(), b.lengths.data(), nullptr, 0),
+ &PQclear};
+ verify<std::runtime_error>(
+ PQresultStatus(res.get()) == PGRES_COMMAND_OK || PQresultStatus(res.get()) == PGRES_TUPLES_OK, name);
+ }
+
+ std::size_t
+ PqPrepStmt::rows() const
+ {
+ return std::strtoul(PQcmdTuples(res.get()), nullptr, 10);
+ }
+
+ RecordSetPtr
+ PqPrepStmt::recordSet()
+ {
+ return std::make_unique<PqRecordSet>(std::move(res));
+ }
+
+ std::string
+ PqPrepStmt::prepareAsNeeded(const char * const q, std::size_t n, PqConn * c)
+ {
+ if (const auto i = c->stmts.find(q); i != c->stmts.end()) {
+ return i->second;
+ }
+ auto nam {AdHoc::scprintf<"pst%0x">(c->stmts.size())};
+ ResPtr res {PQprepare(c->conn.get(), nam.c_str(), q, (int)n, nullptr), PQclear};
+ verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q);
+ return c->stmts.emplace(q, std::move(nam)).first->second;
+ }
+}
diff --git a/lib/output/pq/pqStmt.h b/lib/output/pq/pqStmt.h
new file mode 100644
index 0000000..b806617
--- /dev/null
+++ b/lib/output/pq/pqStmt.h
@@ -0,0 +1,37 @@
+#ifndef MYGRATE_OUTPUT_PQ_PQSTMT_H
+#define MYGRATE_OUTPUT_PQ_PQSTMT_H
+
+#include "dbConn.h"
+#include "dbRecordSet.h"
+#include "dbTypes.h"
+#include <cstddef>
+#include <initializer_list>
+#include <libpq-fe.h>
+#include <memory>
+#include <string>
+
+namespace MyGrate::Output::Pq {
+ class PqConn;
+
+ using ResPtr = std::unique_ptr<PGresult, decltype(&PQclear)>;
+
+ class PqPrepStmt : public DbPrepStmt {
+ public:
+ PqPrepStmt(const char * const q, std::size_t n, PqConn * c);
+
+ void execute(const std::initializer_list<DbValue> & vs) override;
+
+ std::size_t rows() const override;
+
+ RecordSetPtr recordSet() override;
+
+ private:
+ static std::string prepareAsNeeded(const char * const q, std::size_t n, PqConn * c);
+
+ PGconn * conn;
+ std::string name;
+ ResPtr res;
+ };
+}
+
+#endif