summaryrefslogtreecommitdiff
path: root/lib/dbStmt.h
blob: a6d4a65cf3435449acdc3c7660fe28e3fb470a6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#ifndef MYGRATE_DBSTMT_H
#define MYGRATE_DBSTMT_H

#include "fixedString.h"
#include <algorithm>
#include <cstddef>
#include <dbConn.h>
#include <dbRecordSet.h>
#include <memory>
#include <stdexcept>
#include <string_view>
#include <type_traits>

namespace MyGrate {
	/**
	 * Compile-time wrapper around a single SQL statement.
	 *
	 * The statement text is the template argument S, so properties of the statement (whether it
	 * yields a record set, how many bind parameters it takes) are computed at compile time and the
	 * parameter count is checked against every call to execute().
	 *
	 * @tparam S The SQL text, as a Support::basic_fixed_string.
	 */
	template<Support::basic_fixed_string S> class DbStmt {
	public:
		/**
		 * True when the statement is treated as producing a record set: it starts with "SELECT" or
		 * "SHOW", or contains "RETURNING" anywhere in its text.
		 *
		 * These are case-sensitive prefix/substring checks on the raw text. Leading whitespace,
		 * lower-case keywords and common table expressions ("WITH ... SELECT") are not recognised,
		 * and "RETURNING" is matched even inside identifiers, comments or string literals.
		 */
		static constexpr auto isSelect {S.v().starts_with("SELECT") || S.v().starts_with("SHOW")
				|| S.v().find("RETURNING") != std::string_view::npos};

		/**
		 * Number of bind parameters the statement expects for the given placeholder style.
		 *
		 * - ParamMode::None: always 0.
		 * - ParamMode::DollarNum: the highest N among all "$N" placeholders, so "$1 ... $3" gives 3
		 *   and a statement without numbered placeholders gives 0.
		 * - ParamMode::QMark: the number of '?' characters.
		 *
		 * String literals and comments are not skipped, so a '$' or '?' inside them is counted;
		 * such statements are best avoided.
		 *
		 * @param pm Placeholder style, normally the connection type's paramMode.
		 * @return The expected number of bind parameters.
		 * @throws std::logic_error if pm is not one of the modes listed above. During constant
		 *         evaluation (e.g. the static_assert in execute()) this becomes a compile error.
		 */
		static constexpr std::size_t
		paramCount(ParamMode pm)
		{
			const auto sql {S.v()};
			switch (pm) {
				case ParamMode::None:
					return 0LU;
				case ParamMode::DollarNum: {
					std::size_t highest {0};
					for (std::size_t i {0}; i < sql.size(); ++i) {
						if (sql[i] != '$') {
							continue;
						}
						// Parse the decimal number directly after the '$'. This leaves i on its last
						// digit, so the outer loop resumes with the following character.
						std::size_t n {0};
						while (i + 1 < sql.size() && sql[i + 1] >= '0' && sql[i + 1] <= '9') {
							n = (n * 10) + static_cast<std::size_t>(sql[++i] - '0');
						}
						highest = std::max(highest, n);
					}
					return highest;
				}
				case ParamMode::QMark:
					// std::count yields a signed difference_type; it is never negative here.
					return static_cast<std::size_t>(std::count(sql.begin(), sql.end(), '?'));
			}
			// Only reachable with a ParamMode value not handled above. Do not silently return a
			// count: that could let a wrong argument count pass the check in execute().
			throw std::logic_error("Unsupported ParamMode");
		}

		/// Result type of execute(): a RecordSetPtr when isSelect, otherwise the std::size_t
		/// returned by the prepared statement's rows().
		using Return = std::conditional_t<isSelect, RecordSetPtr, std::size_t>;

		/**
		 * Prepare the statement on connection c, bind the given parameters and execute it.
		 *
		 * The number of arguments is checked at compile time against
		 * paramCount(ConnType::paramMode).
		 *
		 * @param c Connection to prepare the statement on. It is dereferenced, so it must not be
		 *          null.
		 * @param p Bind parameter values, forwarded into a braced list for the prepared
		 *          statement's execute().
		 * @return stmt->recordSet() when isSelect, otherwise stmt->rows().
		 */
		template<typename ConnType, typename... P>
		static Return
		execute(ConnType * c, P &&... p)
		{
			static_assert(sizeof...(P) == paramCount(ConnType::paramMode), "Wrong number of parameters for statement");
			auto stmt {c->prepare(S, sizeof...(P))};
			stmt->execute({std::forward<P>(p)...});
			if constexpr (isSelect) {
				return stmt->recordSet();
			}
			else {
				return stmt->rows();
			}
		}
	};
}

#endif