From 936f41ec0ce1736e74b2fda2c65249956498c777 Mon Sep 17 00:00:00 2001 From: Dan Goodliffe Date: Sat, 12 Dec 2020 17:20:40 +0000 Subject: Smart pointer stmt to fix leak --- libmysqlpp/my-command.cpp | 13 ++++++------- libmysqlpp/my-command.h | 2 +- libmysqlpp/my-modifycommand.cpp | 6 +++--- libmysqlpp/my-selectcommand.cpp | 20 ++++++++++---------- 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/libmysqlpp/my-command.cpp b/libmysqlpp/my-command.cpp index 539fbf6..7e7275a 100644 --- a/libmysqlpp/my-command.cpp +++ b/libmysqlpp/my-command.cpp @@ -5,15 +5,15 @@ #include MySQL::Command::Command(const Connection * conn, const std::string & sql) : - DB::Command(sql), c(conn), stmt(mysql_stmt_init(&conn->conn)), paramsNeedBinding(false) + DB::Command(sql), c(conn), stmt(mysql_stmt_init(&conn->conn), &mysql_stmt_close), paramsNeedBinding(false) { if (!stmt) { throw Error(&conn->conn); } - if (mysql_stmt_prepare(stmt, sql.c_str(), sql.length())) { - throw Error(stmt); + if (mysql_stmt_prepare(stmt.get(), sql.c_str(), sql.length())) { + throw Error(stmt.get()); } - binds.resize(mysql_stmt_param_count(stmt)); + binds.resize(mysql_stmt_param_count(stmt.get())); if (binds.size()) { paramsNeedBinding = true; for (auto & b : binds) { @@ -29,7 +29,6 @@ MySQL::Command::~Command() // NOLINTNEXTLINE(hicpp-no-malloc) free(b.buffer); } - mysql_stmt_close(stmt); } void * @@ -160,8 +159,8 @@ void MySQL::Command::bindParams() { if (paramsNeedBinding) { - if (mysql_stmt_bind_param(stmt, &binds.front())) { - throw Error(stmt); + if (mysql_stmt_bind_param(stmt.get(), &binds.front())) { + throw Error(stmt.get()); } } } diff --git a/libmysqlpp/my-command.h b/libmysqlpp/my-command.h index 9b6d72a..4a6426b 100644 --- a/libmysqlpp/my-command.h +++ b/libmysqlpp/my-command.h @@ -37,7 +37,7 @@ namespace MySQL { void * realloc(void * buffer, size_t size); const Connection * c; - MYSQL_STMT * stmt; + std::unique_ptr stmt; typedef std::vector Binds; Binds binds; bool paramsNeedBinding; diff --git a/libmysqlpp/my-modifycommand.cpp b/libmysqlpp/my-modifycommand.cpp index 7610a27..dd607b3 100644 --- a/libmysqlpp/my-modifycommand.cpp +++ b/libmysqlpp/my-modifycommand.cpp @@ -12,10 +12,10 @@ unsigned int MySQL::ModifyCommand::execute(bool anc) { bindParams(); - if (mysql_stmt_execute(stmt)) { - throw Error(stmt); + if (mysql_stmt_execute(stmt.get())) { + throw Error(stmt.get()); } - int rows = mysql_stmt_affected_rows(stmt); + int rows = mysql_stmt_affected_rows(stmt.get()); if (rows == 0 && !anc) { throw DB::NoRowsAffected(); } diff --git a/libmysqlpp/my-selectcommand.cpp b/libmysqlpp/my-selectcommand.cpp index 2bf336f..73ae7cc 100644 --- a/libmysqlpp/my-selectcommand.cpp +++ b/libmysqlpp/my-selectcommand.cpp @@ -14,11 +14,11 @@ MySQL::SelectCommand::execute() { if (!prepared) { bindParams(); - fields.resize(mysql_stmt_field_count(stmt)); + fields.resize(mysql_stmt_field_count(stmt.get())); for (auto & b : fields) { memset(&b, 0, sizeof(MYSQL_BIND)); } - MYSQL_RES * prepare_meta_result = mysql_stmt_result_metadata(stmt); + MYSQL_RES * prepare_meta_result = mysql_stmt_result_metadata(stmt.get()); MYSQL_FIELD * fieldDefs = mysql_fetch_fields(prepare_meta_result); for (std::size_t i = 0; i < fields.size(); i += 1) { switch (fieldDefs[i].type) { @@ -65,17 +65,17 @@ MySQL::SelectCommand::execute() } } mysql_free_result(prepare_meta_result); - if (mysql_stmt_bind_result(stmt, &fields.front())) { - throw Error(stmt); + if (mysql_stmt_bind_result(stmt.get(), &fields.front())) { + throw Error(stmt.get()); } prepared = true; } if (!executed) { - if (mysql_stmt_execute(stmt)) { - throw Error(stmt); + if (mysql_stmt_execute(stmt.get())) { + throw Error(stmt.get()); } - if (mysql_stmt_store_result(stmt)) { - throw Error(stmt); + if (mysql_stmt_store_result(stmt.get())) { + throw Error(stmt.get()); } executed = true; } @@ -85,13 +85,13 @@ bool MySQL::SelectCommand::fetch() { execute(); - switch (mysql_stmt_fetch(stmt)) { + switch (mysql_stmt_fetch(stmt.get())) { case 0: return true; case MYSQL_NO_DATA: executed = false; return false; default: - throw Error(stmt); + throw Error(stmt.get()); } } -- cgit v1.2.3