summaryrefslogtreecommitdiff
path: root/lib/dbStmt.h
diff options
context:
space:
mode:
Diffstat (limited to 'lib/dbStmt.h')
-rw-r--r--lib/dbStmt.h66
1 files changed, 66 insertions, 0 deletions
diff --git a/lib/dbStmt.h b/lib/dbStmt.h
new file mode 100644
index 0000000..3e98b34
--- /dev/null
+++ b/lib/dbStmt.h
@@ -0,0 +1,66 @@
+#ifndef MYGRATE_DBSTMT_H
+#define MYGRATE_DBSTMT_H
+
+#include <compileTimeFormatter.h>
+#include <dbConn.h>
+#include <dbRecordSet.h>
+#include <memory>
+#include <string_view>
+#include <type_traits>
+
+namespace MyGrate {
+ class DbConn;
+ enum class ParamMode { None, DollarNum, QMark };
+
+ template<AdHoc::support::basic_fixed_string S, ParamMode pm = ParamMode::None> class DbStmt {
+ public:
+ // This don't account for common table expressions, hopefully won't need those :)
+ static constexpr auto isSelect {S.v().starts_with("SELECT") || S.v().starts_with("SHOW")
+ || S.v().find("RETURNING") != std::string_view::npos};
+
+ // These don't account for string literals, which we'd prefer to avoid anyway :)
+ static constexpr auto paramCount {[]() -> std::size_t {
+ switch (pm) {
+ case ParamMode::None:
+ return 0LU;
+ case ParamMode::DollarNum: {
+ const auto pn = [](const char * c, const char * const e) {
+ std::size_t n {0};
+ while (++c != e && *c >= '0' && *c <= '9') {
+ n = (n * 10) + (*c - '0');
+ }
+ return n;
+ };
+ return pn(std::max_element(S.v().begin(), S.v().end(),
+ [pn, e = S.v().end()](const char & a, const char & b) {
+ return (a == '$' ? pn(&a, e) : 0) < (b == '$' ? pn(&b, e) : 0);
+ }),
+ S.v().end());
+ }
+ case ParamMode::QMark:
+ return std::count_if(S.v().begin(), S.v().end(), [](char c) {
+ return c == '?';
+ });
+ }
+ }()};
+
+ using Return = std::conditional_t<isSelect, RecordSetPtr, std::size_t>;
+
+ template<typename... P>
+ static Return
+ execute(DbConn * c, P &&... p)
+ {
+ static_assert(sizeof...(P) == paramCount);
+ auto stmt {c->prepare(S, sizeof...(P))};
+ stmt->execute({std::forward<P...>(p)...});
+ if constexpr (isSelect) {
+ return stmt->recordSet();
+ }
+ else {
+ return stmt->rows();
+ }
+ }
+ };
+}
+
+#endif