diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/dbConn.h | 14 | ||||
-rw-r--r-- | lib/dbRecordSet.h | 19 | ||||
-rw-r--r-- | lib/dbStmt.h | 66 | ||||
-rw-r--r-- | lib/dbTypes.h | 9 | ||||
-rw-r--r-- | lib/input/mysqlBindings.h | 178 | ||||
-rw-r--r-- | lib/input/mysqlConn.cpp | 73 | ||||
-rw-r--r-- | lib/input/mysqlConn.h | 3 | ||||
-rw-r--r-- | lib/input/mysqlRecordSet.cpp | 99 | ||||
-rw-r--r-- | lib/input/mysqlRecordSet.h | 37 | ||||
-rw-r--r-- | lib/input/mysqlStmt.cpp | 35 | ||||
-rw-r--r-- | lib/input/mysqlStmt.h | 29 | ||||
-rw-r--r-- | lib/output/pq/pqBindings.h | 57 | ||||
-rw-r--r-- | lib/output/pq/pqConn.cpp | 73 | ||||
-rw-r--r-- | lib/output/pq/pqConn.h | 14 | ||||
-rw-r--r-- | lib/output/pq/pqRecordSet.cpp | 86 | ||||
-rw-r--r-- | lib/output/pq/pqRecordSet.h | 25 | ||||
-rw-r--r-- | lib/output/pq/pqStmt.cpp | 54 | ||||
-rw-r--r-- | lib/output/pq/pqStmt.h | 37 |
18 files changed, 784 insertions, 124 deletions
diff --git a/lib/dbConn.h b/lib/dbConn.h index 26b9a64..e4b056c 100644 --- a/lib/dbConn.h +++ b/lib/dbConn.h @@ -1,13 +1,27 @@ #ifndef MYGRATE_DBCONN_H #define MYGRATE_DBCONN_H +#include <dbRecordSet.h> #include <dbTypes.h> #include <initializer_list> namespace MyGrate { + class DbPrepStmt { + public: + virtual ~DbPrepStmt() = default; + virtual void execute(const std::initializer_list<DbValue> &) = 0; + virtual std::size_t rows() const = 0; + virtual RecordSetPtr recordSet() = 0; + }; + using DbPrepStmtPtr = std::unique_ptr<DbPrepStmt>; + class DbConn { + public: + virtual ~DbConn() = default; virtual void query(const char * const) = 0; virtual void query(const char * const, const std::initializer_list<DbValue> &) = 0; + + virtual DbPrepStmtPtr prepare(const char * const, std::size_t nParams) = 0; }; } diff --git a/lib/dbRecordSet.h b/lib/dbRecordSet.h new file mode 100644 index 0000000..9bddc01 --- /dev/null +++ b/lib/dbRecordSet.h @@ -0,0 +1,19 @@ +#ifndef MYGRATE_DBRECORDSET_H +#define MYGRATE_DBRECORDSET_H + +#include <dbTypes.h> +#include <memory> + +namespace MyGrate { + class RecordSet { + public: + virtual ~RecordSet() = default; + + virtual std::size_t rows() const = 0; + virtual std::size_t columns() const = 0; + virtual DbValue at(std::size_t, std::size_t) const = 0; + }; + using RecordSetPtr = std::unique_ptr<RecordSet>; +} + +#endif diff --git a/lib/dbStmt.h b/lib/dbStmt.h new file mode 100644 index 0000000..3e98b34 --- /dev/null +++ b/lib/dbStmt.h @@ -0,0 +1,66 @@ +#ifndef MYGRATE_DBSTMT_H +#define MYGRATE_DBSTMT_H + +#include <compileTimeFormatter.h> +#include <dbConn.h> +#include <dbRecordSet.h> +#include <memory> +#include <string_view> +#include <type_traits> + +namespace MyGrate { + class DbConn; + enum class ParamMode { None, DollarNum, QMark }; + + template<AdHoc::support::basic_fixed_string S, ParamMode pm = ParamMode::None> class DbStmt { + public: + // This don't account for common table expressions, hopefully won't need those :) + static constexpr auto isSelect {S.v().starts_with("SELECT") || S.v().starts_with("SHOW") + || S.v().find("RETURNING") != std::string_view::npos}; + + // These don't account for string literals, which we'd prefer to avoid anyway :) + static constexpr auto paramCount {[]() -> std::size_t { + switch (pm) { + case ParamMode::None: + return 0LU; + case ParamMode::DollarNum: { + const auto pn = [](const char * c, const char * const e) { + std::size_t n {0}; + while (++c != e && *c >= '0' && *c <= '9') { + n = (n * 10) + (*c - '0'); + } + return n; + }; + return pn(std::max_element(S.v().begin(), S.v().end(), + [pn, e = S.v().end()](const char & a, const char & b) { + return (a == '$' ? pn(&a, e) : 0) < (b == '$' ? pn(&b, e) : 0); + }), + S.v().end()); + } + case ParamMode::QMark: + return std::count_if(S.v().begin(), S.v().end(), [](char c) { + return c == '?'; + }); + } + }()}; + + using Return = std::conditional_t<isSelect, RecordSetPtr, std::size_t>; + + template<typename... P> + static Return + execute(DbConn * c, P &&... p) + { + static_assert(sizeof...(P) == paramCount); + auto stmt {c->prepare(S, sizeof...(P))}; + stmt->execute({std::forward<P...>(p)...}); + if constexpr (isSelect) { + return stmt->recordSet(); + } + else { + return stmt->rows(); + } + } + }; +} + +#endif diff --git a/lib/dbTypes.h b/lib/dbTypes.h index ba0cd70..b35b036 100644 --- a/lib/dbTypes.h +++ b/lib/dbTypes.h @@ -11,16 +11,25 @@ struct timespec; namespace MyGrate { struct Date { + inline Date() { } + inline Date(uint16_t y, uint8_t m, uint8_t d) : year {y}, month {m}, day {d} { } + explicit inline Date(const tm & tm) : Date(tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday) { } uint16_t year; uint8_t month; uint8_t day; }; struct Time { + inline Time() { } + inline Time(uint8_t h, uint8_t m, uint8_t s) : hour {h}, minute {m}, second {s} { } + explicit inline Time(const tm & tm) : Time(tm.tm_hour, tm.tm_min, tm.tm_sec) { } uint8_t hour; uint8_t minute; uint8_t second; }; struct DateTime : public Date, public Time { + inline DateTime() { } + inline DateTime(const Date & d, const Time & t) : Date {d}, Time {t} { } + explicit inline DateTime(const tm & tm) : Date {tm}, Time {tm} { } }; using Blob = std::span<const std::byte>; diff --git a/lib/input/mysqlBindings.h b/lib/input/mysqlBindings.h new file mode 100644 index 0000000..dcb3ebf --- /dev/null +++ b/lib/input/mysqlBindings.h @@ -0,0 +1,178 @@ +#ifndef MYGRATE_INPUT_MYSQLBINDINGS_H +#define MYGRATE_INPUT_MYSQLBINDINGS_H + +#include <dbTypes.h> +#include <helpers.h> +#include <initializer_list> +#include <mysql.h> +#include <mysql_types.h> +#include <variant> +#include <vector> + +namespace MyGrate::Input { + struct BingingData { + explicit BingingData(unsigned long l, my_bool n = 0) : len {l}, null {n} { } + unsigned long len; + my_bool null; + }; + + struct Bindings { + // NOLINTNEXTLINE(hicpp-explicit-conversions) + explicit Bindings(const std::initializer_list<DbValue> & vs) + { + binds.reserve(vs.size()); + data.reserve(vs.size()); + for (const auto & v : vs) { + std::visit(*this, v); + } + } + template<std::integral T> + void + operator()(const T & v) + { + auto & b = binds.emplace_back(); + b.buffer_type = MySQL::CType<T>::type; + b.buffer = const_cast<T *>(&v); + b.is_unsigned = std::unsigned_integral<T>; + } + template<std::floating_point T> + void + operator()(const T & v) + { + auto & b = binds.emplace_back(); + b.buffer_type = MySQL::CType<T>::type; + b.buffer = const_cast<T *>(&v); + } + template<Viewable T> + void + operator()(const T & v) + { + auto & b = binds.emplace_back(); + b.buffer_type = MySQL::CType<T>::type; + b.buffer = const_cast<typename T::value_type *>(v.data()); + b.length = &data.emplace_back(v.size(), 0).len; + } + void + operator()(const std::nullptr_t &) + { + auto & b = binds.emplace_back(); + b.buffer = nullptr; + b.is_null = &data.emplace_back(0, 1).null; + } + template<typename T> + void + operator()(const T &) + { + throw std::runtime_error("Not implemented"); + } + std::vector<MYSQL_BIND> binds; + std::vector<BingingData> data; + }; + + class ResultData : public BingingData { + public: + ResultData() : BingingData {0} { } + virtual ~ResultData() = default; + + [[nodiscard]] virtual DbValue getValue() const = 0; + }; + + template<typename T> class ResultDataT : public ResultData { + public: + ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f) + { + b.buffer = &buf; + b.buffer_length = sizeof(T); + b.is_null = &this->null; + b.length = &this->len; + b.is_unsigned = std::is_unsigned_v<T>; + b.buffer_type = f.type; + } + + [[nodiscard]] DbValue + getValue() const override + { + return buf; + } + + private: + T buf {}; + }; + + template<> class ResultDataT<std::string_view> : public ResultData { + public: + ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f) : buf(f.length) + { + b.buffer_length = buf.size(); + b.buffer = buf.data(); + b.is_null = &this->null; + b.length = &this->len; + b.buffer_type = f.type; + } + + [[nodiscard]] DbValue + getValue() const override + { + return std::string_view {buf.data(), this->len}; + } + + private: + std::vector<char> buf; + }; + + template<> class ResultDataT<Blob> : public ResultData { + public: + ResultDataT(MYSQL_BIND & b, const MYSQL_FIELD & f) : buf(f.length) + { + b.buffer_length = buf.size(); + b.buffer = buf.data(); + b.is_null = &this->null; + b.length = &this->len; + b.buffer_type = f.type; + } + + [[nodiscard]] DbValue + getValue() const override + { + return Blob {buf.data(), this->len}; + } + + private: + std::vector<std::byte> buf; + }; + + template<typename Out> class ResultDataTime : public ResultData { + public: + ResultDataTime(MYSQL_BIND & b, const MYSQL_FIELD & f) + { + b.buffer_length = sizeof(MYSQL_TIME); + b.buffer = &buf; + b.is_null = &this->null; + b.length = &this->len; + b.buffer_type = f.type; + } + + [[nodiscard]] DbValue + getValue() const override + { + return Out {*this}; + } + + private: + operator Date() const + { + return Date(buf.year, buf.month, buf.day); + } + operator Time() const + { + return Time(buf.hour, buf.minute, buf.second); + } + operator DateTime() const + { + return DateTime(*this, *this); + } + MYSQL_TIME buf; + }; +} + +#endif diff --git a/lib/input/mysqlConn.cpp b/lib/input/mysqlConn.cpp index 179f9d5..46ed7c6 100644 --- a/lib/input/mysqlConn.cpp +++ b/lib/input/mysqlConn.cpp @@ -1,18 +1,17 @@ #include "mysqlConn.h" -#include "helpers.h" +#include "mysqlBindings.h" +#include "mysqlStmt.h" #include <cstddef> #include <cstring> +#include <dbConn.h> #include <dbTypes.h> +#include <helpers.h> #include <memory> #include <mysql.h> -#include <mysql_types.h> #include <stdexcept> -#include <variant> #include <vector> namespace MyGrate::Input { - using StmtPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_close)>; - MySQLConn::MySQLConn( const char * const host, const char * const user, const char * const pass, unsigned short port) : st_mysql {} @@ -36,64 +35,6 @@ namespace MyGrate::Input { verify<std::runtime_error>(!mysql_query(this, q), q); } - struct Bindings { - // NOLINTNEXTLINE(hicpp-explicit-conversions) - explicit Bindings(const std::initializer_list<DbValue> & vs) - { - binds.reserve(vs.size()); - extras.reserve(vs.size()); - for (const auto & v : vs) { - std::visit(*this, v); - } - } - template<std::integral T> - void - operator()(const T & v) - { - auto & b = binds.emplace_back(); - b.buffer_type = MySQL::CType<T>::type; - b.buffer = const_cast<T *>(&v); - b.is_unsigned = std::unsigned_integral<T>; - } - template<std::floating_point T> - void - operator()(const T & v) - { - auto & b = binds.emplace_back(); - b.buffer_type = MySQL::CType<T>::type; - b.buffer = const_cast<T *>(&v); - } - template<Viewable T> - void - operator()(const T & v) - { - auto & b = binds.emplace_back(); - b.buffer_type = MySQL::CType<T>::type; - b.buffer = const_cast<typename T::value_type *>(v.data()); - b.length = &extras.emplace_back(v.size(), 0).len; - } - void - operator()(const std::nullptr_t &) - { - auto & b = binds.emplace_back(); - b.buffer = nullptr; - b.is_null = &extras.emplace_back(0, 1).null; - } - template<typename T> - void - operator()(const T &) - { - throw std::runtime_error("Not implemented"); - } - struct extra { - explicit extra(unsigned long l, my_bool n = 0) : len {l}, null {n} { } - unsigned long len; - my_bool null; - }; - std::vector<MYSQL_BIND> binds; - std::vector<extra> extras; - }; - void MySQLConn::query(const char * const q, const std::initializer_list<DbValue> & vs) { @@ -104,4 +45,10 @@ namespace MyGrate::Input { verify<std::runtime_error>(!mysql_stmt_bind_param(stmt.get(), b.binds.data()), "Param count mismatch"); verify<std::runtime_error>(!mysql_stmt_execute(stmt.get()), q); } + + DbPrepStmtPtr + MySQLConn::prepare(const char * const q, std::size_t) + { + return std::make_unique<MySQLPrepStmt>(q, this); + } } diff --git a/lib/input/mysqlConn.h b/lib/input/mysqlConn.h index 9e4ec25..2f71262 100644 --- a/lib/input/mysqlConn.h +++ b/lib/input/mysqlConn.h @@ -1,6 +1,7 @@ #ifndef MYGRATE_INPUT_MYSQLCONN_H #define MYGRATE_INPUT_MYSQLCONN_H +#include <cstddef> #include <dbConn.h> #include <dbTypes.h> #include <initializer_list> @@ -14,6 +15,8 @@ namespace MyGrate::Input { void query(const char * const) override; void query(const char * const q, const std::initializer_list<DbValue> &) override; + + DbPrepStmtPtr prepare(const char * const, std::size_t) override; }; } diff --git a/lib/input/mysqlRecordSet.cpp b/lib/input/mysqlRecordSet.cpp new file mode 100644 index 0000000..51be15f --- /dev/null +++ b/lib/input/mysqlRecordSet.cpp @@ -0,0 +1,99 @@ +#include "mysqlRecordSet.h" +#include "mysqlBindings.h" +#include "mysqlStmt.h" +#include <cstdint> +#include <dbTypes.h> +#include <helpers.h> +#include <stdexcept> +#include <string_view> +#include <utility> +// IWYU pragma: no_include <ext/alloc_traits.h> + +namespace MyGrate::Input { + MySQLRecordSet::MySQLRecordSet(StmtPtr s) : + stmt {std::move(s)}, stmtres {nullptr, nullptr}, fields(mysql_stmt_field_count(stmt.get())), + extras(fields.size()) + { + auto getBind = [](const MYSQL_FIELD & f, MYSQL_BIND & b) -> std::unique_ptr<ResultData> { + switch (f.type) { + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_DOUBLE: + return std::make_unique<ResultDataT<double>>(b, f); + case MYSQL_TYPE_FLOAT: + return std::make_unique<ResultDataT<float>>(b, f); + case MYSQL_TYPE_TINY: + return std::make_unique<ResultDataT<int8_t>>(b, f); + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_YEAR: + return std::make_unique<ResultDataT<int16_t>>(b, f); + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + return std::make_unique<ResultDataT<int32_t>>(b, f); + case MYSQL_TYPE_LONGLONG: + return std::make_unique<ResultDataT<int64_t>>(b, f); + case MYSQL_TYPE_NULL: + return std::make_unique<ResultDataT<std::nullptr_t>>(b, f); + case MYSQL_TYPE_TIMESTAMP: + case MYSQL_TYPE_TIMESTAMP2: + case MYSQL_TYPE_DATETIME: + case MYSQL_TYPE_DATETIME2: + return std::make_unique<ResultDataTime<DateTime>>(b, f); + case MYSQL_TYPE_TIME: + case MYSQL_TYPE_TIME2: + return std::make_unique<ResultDataTime<Time>>(b, f); + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_NEWDATE: + return std::make_unique<ResultDataTime<Date>>(b, f); + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_VAR_STRING: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_JSON: + case MYSQL_TYPE_ENUM: + return std::make_unique<ResultDataT<std::string_view>>(b, f); + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_BLOB: + return std::make_unique<ResultDataT<Blob>>(b, f); + case MAX_NO_FIELD_TYPES: + case MYSQL_TYPE_BIT: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_GEOMETRY:; + } + throw std::runtime_error("Unsupported column type"); + }; + ResPtr meta {mysql_stmt_result_metadata(stmt.get()), mysql_free_result}; + const auto fieldDefs = mysql_fetch_fields(meta.get()); + verify<std::runtime_error>(fieldDefs, "Fetch fields"); + for (std::size_t i = 0; i < fields.size(); i += 1) { + extras[i] = getBind(fieldDefs[i], fields[i]); + } + verify<std::runtime_error>(!mysql_stmt_bind_result(stmt.get(), fields.data()), "Store result error"); + verify<std::runtime_error>(!mysql_stmt_store_result(stmt.get()), "Store result error"); + stmtres = {stmt.get(), mysql_stmt_free_result}; + verify<std::runtime_error>(!mysql_stmt_fetch(stmt.get()), "Fetch"); + } + + std::size_t + MySQLRecordSet::rows() const + { + return mysql_stmt_num_rows(stmt.get()); + } + + std::size_t + MySQLRecordSet::columns() const + { + return fields.size(); + } + + DbValue + MySQLRecordSet::at(std::size_t row, std::size_t col) const + { + mysql_stmt_data_seek(stmt.get(), row); + if (extras[col]->null) { + return nullptr; + } + return extras[col]->getValue(); + } +} diff --git a/lib/input/mysqlRecordSet.h b/lib/input/mysqlRecordSet.h new file mode 100644 index 0000000..849a653 --- /dev/null +++ b/lib/input/mysqlRecordSet.h @@ -0,0 +1,37 @@ +#ifndef MYGRATE_INPUT_MYSQLRECORDSET_H +#define MYGRATE_INPUT_MYSQLRECORDSET_H + +#include "mysqlStmt.h" +#include <cstddef> +#include <dbRecordSet.h> +#include <dbTypes.h> +#include <memory> +#include <mysql.h> +#include <vector> + +namespace MyGrate::Input { + class ResultData; + + class MySQLRecordSet : public RecordSet { + public: + using ResPtr = std::unique_ptr<MYSQL_RES, decltype(&mysql_free_result)>; + using StmtResPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_free_result)>; + using ResultDataPtr = std::unique_ptr<ResultData>; + + explicit MySQLRecordSet(StmtPtr s); + + std::size_t rows() const override; + + std::size_t columns() const override; + + DbValue at(std::size_t row, std::size_t col) const override; + + private: + StmtPtr stmt; + StmtResPtr stmtres; + std::vector<MYSQL_BIND> fields; + std::vector<ResultDataPtr> extras; + }; +} + +#endif diff --git a/lib/input/mysqlStmt.cpp b/lib/input/mysqlStmt.cpp new file mode 100644 index 0000000..08d1303 --- /dev/null +++ b/lib/input/mysqlStmt.cpp @@ -0,0 +1,35 @@ +#include "mysqlStmt.h" +#include "mysqlBindings.h" +#include "mysqlRecordSet.h" +#include <cstring> +#include <helpers.h> +#include <stdexcept> +#include <utility> +#include <vector> + +namespace MyGrate::Input { + MySQLPrepStmt::MySQLPrepStmt(const char * const q, MYSQL * c) : stmt {mysql_stmt_init(c), &mysql_stmt_close} + { + verify<std::runtime_error>(!mysql_stmt_prepare(stmt.get(), q, strlen(q)), q); + } + + void + MySQLPrepStmt::execute(const std::initializer_list<DbValue> & vs) + { + Bindings b {vs}; + verify<std::runtime_error>(!mysql_stmt_bind_param(stmt.get(), b.binds.data()), "Param count mismatch"); + verify<std::runtime_error>(!mysql_stmt_execute(stmt.get()), "Prepared statement execute"); + } + + std::size_t + MySQLPrepStmt::rows() const + { + return mysql_stmt_affected_rows(stmt.get()); + } + + RecordSetPtr + MySQLPrepStmt::recordSet() + { + return std::make_unique<MySQLRecordSet>(std::move(stmt)); + } +} diff --git a/lib/input/mysqlStmt.h b/lib/input/mysqlStmt.h new file mode 100644 index 0000000..a42e0db --- /dev/null +++ b/lib/input/mysqlStmt.h @@ -0,0 +1,29 @@ +#ifndef MYGRATE_INPUT_MYSQLSTMT_H +#define MYGRATE_INPUT_MYSQLSTMT_H + +#include "dbConn.h" +#include "dbRecordSet.h" +#include "dbTypes.h" +#include <cstddef> +#include <initializer_list> +#include <memory> +#include <mysql.h> + +namespace MyGrate::Input { + using StmtPtr = std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_close)>; + + class MySQLPrepStmt : public DbPrepStmt { + public: + MySQLPrepStmt(const char * const q, MYSQL * c); + void execute(const std::initializer_list<DbValue> & vs) override; + + std::size_t rows() const override; + + RecordSetPtr recordSet() override; + + private: + StmtPtr stmt; + }; +} + +#endif diff --git a/lib/output/pq/pqBindings.h b/lib/output/pq/pqBindings.h new file mode 100644 index 0000000..ef0df84 --- /dev/null +++ b/lib/output/pq/pqBindings.h @@ -0,0 +1,57 @@ +#ifndef MYGRATE_OUTPUT_PQ_PQBINDINGS +#define MYGRATE_OUTPUT_PQ_PQBINDINGS + +#include <dbTypes.h> +#include <helpers.h> +#include <initializer_list> +#include <variant> +#include <vector> + +namespace MyGrate::Output::Pq { + struct Bindings { + // NOLINTNEXTLINE(hicpp-explicit-conversions) + explicit Bindings(const std::initializer_list<DbValue> & vs) + { + bufs.reserve(vs.size()); + values.reserve(vs.size()); + lengths.reserve(vs.size()); + for (const auto & v : vs) { + std::visit(*this, v); + } + } + template<Stringable T> + void + operator()(const T & v) + { + bufs.emplace_back(std::to_string(v)); + const auto & vw {bufs.back()}; + values.emplace_back(vw.data()); + lengths.emplace_back(vw.length()); + } + template<Viewable T> + void + operator()(const T & v) + { + values.emplace_back(v.data()); + lengths.emplace_back(v.size()); + } + template<typename T> + void + operator()(const T &) + { + throw std::runtime_error("Not implemented"); + } + void + operator()(const std::nullptr_t &) + { + values.emplace_back(nullptr); + lengths.emplace_back(0); + } + + std::vector<std::string> bufs; + std::vector<const char *> values; + std::vector<int> lengths; + }; +} + +#endif diff --git a/lib/output/pq/pqConn.cpp b/lib/output/pq/pqConn.cpp index 4f55ba8..81d9610 100644 --- a/lib/output/pq/pqConn.cpp +++ b/lib/output/pq/pqConn.cpp @@ -1,89 +1,44 @@ #include "pqConn.h" +#include "pqBindings.h" +#include "pqStmt.h" +#include <dbConn.h> #include <dbTypes.h> #include <helpers.h> #include <libpq-fe.h> #include <memory> -#include <sstream> #include <stdexcept> #include <string> -#include <variant> #include <vector> namespace MyGrate::Output::Pq { - using ResPtr = std::unique_ptr<PGresult, decltype(&PQclear)>; - - PqConn::PqConn(const char * const str) : conn {PQconnectdb(str)} + PqConn::PqConn(const char * const str) : conn {PQconnectdb(str), PQfinish} { - verify<std::runtime_error>(PQstatus(conn) == CONNECTION_OK, "Connection failure"); - PQsetNoticeProcessor(conn, notice_processor, this); - } - - PqConn::~PqConn() - { - PQfinish(conn); + verify<std::runtime_error>(PQstatus(conn.get()) == CONNECTION_OK, "Connection failure"); + PQsetNoticeProcessor(conn.get(), notice_processor, this); } void PqConn::query(const char * const q) { - ResPtr res {PQexec(conn, q), &PQclear}; + ResPtr res {PQexec(conn.get(), q), &PQclear}; verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q); } - struct Bindings { - // NOLINTNEXTLINE(hicpp-explicit-conversions) - explicit Bindings(const std::initializer_list<DbValue> & vs) - { - bufs.reserve(vs.size()); - values.reserve(vs.size()); - lengths.reserve(vs.size()); - for (const auto & v : vs) { - std::visit(*this, v); - } - } - template<Stringable T> - void - operator()(const T & v) - { - bufs.emplace_back(std::to_string(v)); - const auto & vw {bufs.back()}; - values.emplace_back(vw.data()); - lengths.emplace_back(vw.length()); - } - template<Viewable T> - void - operator()(const T & v) - { - values.emplace_back(v.data()); - lengths.emplace_back(v.size()); - } - template<typename T> - void - operator()(const T &) - { - throw std::runtime_error("Not implemented"); - } - void - operator()(const std::nullptr_t &) - { - values.emplace_back(nullptr); - lengths.emplace_back(0); - } - - std::vector<std::string> bufs; - std::vector<const char *> values; - std::vector<int> lengths; - }; - void PqConn::query(const char * const q, const std::initializer_list<DbValue> & vs) { Bindings b {vs}; - ResPtr res {PQexecParams(conn, q, (int)vs.size(), nullptr, b.values.data(), b.lengths.data(), nullptr, 0), + ResPtr res {PQexecParams(conn.get(), q, (int)vs.size(), nullptr, b.values.data(), b.lengths.data(), nullptr, 0), &PQclear}; verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q); } + DbPrepStmtPtr + PqConn::prepare(const char * const q, std::size_t n) + { + return std::make_unique<PqPrepStmt>(q, n, this); + } + void PqConn::notice_processor(void * p, const char * n) { diff --git a/lib/output/pq/pqConn.h b/lib/output/pq/pqConn.h index 613af6f..856683d 100644 --- a/lib/output/pq/pqConn.h +++ b/lib/output/pq/pqConn.h @@ -1,25 +1,35 @@ #ifndef MYGRATE_OUTPUT_PQ_PQCONN_H #define MYGRATE_OUTPUT_PQ_PQCONN_H +#include <cstddef> #include <dbConn.h> #include <dbTypes.h> +#include <functional> #include <initializer_list> #include <libpq-fe.h> +#include <map> +#include <memory> +#include <string> namespace MyGrate::Output::Pq { class PqConn : public DbConn { public: explicit PqConn(const char * const str); - virtual ~PqConn(); + virtual ~PqConn() = default; void query(const char * const) override; void query(const char * const, const std::initializer_list<DbValue> &) override; + DbPrepStmtPtr prepare(const char * const, std::size_t nParams) override; + private: static void notice_processor(void *, const char *); virtual void notice_processor(const char *) const; - PGconn * const conn; + std::unique_ptr<PGconn, decltype(&PQfinish)> const conn; + + friend class PqPrepStmt; + std::map<std::string, std::string, std::less<>> stmts; }; } diff --git a/lib/output/pq/pqRecordSet.cpp b/lib/output/pq/pqRecordSet.cpp new file mode 100644 index 0000000..71ddee4 --- /dev/null +++ b/lib/output/pq/pqRecordSet.cpp @@ -0,0 +1,86 @@ +#include "pqRecordSet.h" +#include "dbTypes.h" +#include "pqStmt.h" +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <ctime> +#include <helpers.h> +#include <libpq-fe.h> +#include <server/catalog/pg_type_d.h> +#include <stdexcept> +#include <string_view> +#include <utility> + +namespace MyGrate::Output::Pq { + PqRecordSet::PqRecordSet(ResPtr r) : res {std::move(r)} { } + + std::size_t + PqRecordSet::rows() const + { + return PQntuples(res.get()); + } + + std::size_t + PqRecordSet::columns() const + { + return PQnfields(res.get()); + } + + DbValue + PqRecordSet::at(std::size_t row, std::size_t col) const + { + if (PQgetisnull(res.get(), (int)row, (int)col)) { + return nullptr; + } + const auto value {PQgetvalue(res.get(), (int)row, (int)col)}; + const auto size {static_cast<size_t>(PQgetlength(res.get(), (int)row, (int)col))}; + const auto type {PQftype(res.get(), (int)col)}; + switch (type) { + // case BITOID: TODO bool + // case BOOLOID: TODO bool + // case BOOLARRAYOID: + case VARBITOID: + case BYTEAOID: + // This is wrong :) + return Blob {reinterpret_cast<const std::byte *>(value), size}; + case INT2OID: + return static_cast<int16_t>(std::strtol(value, nullptr, 10)); + case INT4OID: + return static_cast<int32_t>(std::strtol(value, nullptr, 10)); + case INT8OID: + return std::strtol(value, nullptr, 10); + case FLOAT4OID: + return std::strtof(value, nullptr); + case FLOAT8OID: + case CASHOID: + case NUMERICOID: + return std::strtod(value, nullptr); + case DATEOID: { + tm tm {}; + const auto end = strptime(value, "%F", &tm); + verify<std::runtime_error>(end && !*end, "Invalid date string"); + return Date {tm}; + } + case TIMEOID: { + tm tm {}; + const auto end = strptime(value, "%T", &tm); + verify<std::runtime_error>(end && !*end, "Invalid time string"); + return Time {tm}; + } + case TIMESTAMPOID: { + tm tm {}; + const auto end = strptime(value, "%FT%T", &tm); + verify<std::runtime_error>(end && !*end, "Invalid timestamp string"); + return DateTime {tm}; + } + // case TIMESTAMPTZOID: Maybe add TZ support? + // case INTERVALOID: Maybe add interval support? + // case TIMETZOID: Maybe add TZ support? + case VOIDOID: + return nullptr; + default: + return std::string_view {value, size}; + } + } +} diff --git a/lib/output/pq/pqRecordSet.h b/lib/output/pq/pqRecordSet.h new file mode 100644 index 0000000..2934d84 --- /dev/null +++ b/lib/output/pq/pqRecordSet.h @@ -0,0 +1,25 @@ +#ifndef MYGRATE_OUTPUT_PQ_PQRECORDSET_H +#define MYGRATE_OUTPUT_PQ_PQRECORDSET_H + +#include "dbRecordSet.h" +#include "dbTypes.h" +#include "pqStmt.h" +#include <cstddef> + +namespace MyGrate::Output::Pq { + class PqRecordSet : public RecordSet { + public: + explicit PqRecordSet(ResPtr r); + + std::size_t rows() const override; + + std::size_t columns() const override; + + DbValue at(std::size_t row, std::size_t col) const override; + + private: + ResPtr res; + }; +} + +#endif diff --git a/lib/output/pq/pqStmt.cpp b/lib/output/pq/pqStmt.cpp new file mode 100644 index 0000000..04b48c6 --- /dev/null +++ b/lib/output/pq/pqStmt.cpp @@ -0,0 +1,54 @@ +#include "pqStmt.h" +#include "libpq-fe.h" +#include "pqBindings.h" +#include "pqConn.h" +#include "pqRecordSet.h" +#include <compileTimeFormatter.h> +#include <cstdlib> +#include <functional> +#include <helpers.h> +#include <map> +#include <stdexcept> +#include <utility> +#include <vector> + +namespace MyGrate::Output::Pq { + PqPrepStmt::PqPrepStmt(const char * const q, std::size_t n, PqConn * c) : + conn {c->conn.get()}, name {prepareAsNeeded(q, n, c)}, res {nullptr, nullptr} + { + } + + void + PqPrepStmt::execute(const std::initializer_list<DbValue> & vs) + { + Bindings b {vs}; + res = {PQexecPrepared(conn, name.c_str(), (int)vs.size(), b.values.data(), b.lengths.data(), nullptr, 0), + &PQclear}; + verify<std::runtime_error>( + PQresultStatus(res.get()) == PGRES_COMMAND_OK || PQresultStatus(res.get()) == PGRES_TUPLES_OK, name); + } + + std::size_t + PqPrepStmt::rows() const + { + return std::strtoul(PQcmdTuples(res.get()), nullptr, 10); + } + + RecordSetPtr + PqPrepStmt::recordSet() + { + return std::make_unique<PqRecordSet>(std::move(res)); + } + + std::string + PqPrepStmt::prepareAsNeeded(const char * const q, std::size_t n, PqConn * c) + { + if (const auto i = c->stmts.find(q); i != c->stmts.end()) { + return i->second; + } + auto nam {AdHoc::scprintf<"pst%0x">(c->stmts.size())}; + ResPtr res {PQprepare(c->conn.get(), nam.c_str(), q, (int)n, nullptr), PQclear}; + verify<std::runtime_error>(PQresultStatus(res.get()) == PGRES_COMMAND_OK, q); + return c->stmts.emplace(q, std::move(nam)).first->second; + } +} diff --git a/lib/output/pq/pqStmt.h b/lib/output/pq/pqStmt.h new file mode 100644 index 0000000..b806617 --- /dev/null +++ b/lib/output/pq/pqStmt.h @@ -0,0 +1,37 @@ +#ifndef MYGRATE_OUTPUT_PQ_PQSTMT_H +#define MYGRATE_OUTPUT_PQ_PQSTMT_H + +#include "dbConn.h" +#include "dbRecordSet.h" +#include "dbTypes.h" +#include <cstddef> +#include <initializer_list> +#include <libpq-fe.h> +#include <memory> +#include <string> + +namespace MyGrate::Output::Pq { + class PqConn; + + using ResPtr = std::unique_ptr<PGresult, decltype(&PQclear)>; + + class PqPrepStmt : public DbPrepStmt { + public: + PqPrepStmt(const char * const q, std::size_t n, PqConn * c); + + void execute(const std::initializer_list<DbValue> & vs) override; + + std::size_t rows() const override; + + RecordSetPtr recordSet() override; + + private: + static std::string prepareAsNeeded(const char * const q, std::size_t n, PqConn * c); + + PGconn * conn; + std::string name; + ResPtr res; + }; +} + +#endif |