diff src/lib-sql/sql-api.c @ 22514:7d189977dfe3

lib-sql: Add support for prepared SQL statements. This initial implementation doesn't use prepared statements in drivers, but simply generates the query string internally.
author Timo Sirainen <timo.sirainen@dovecot.fi>
date Tue, 22 Aug 2017 11:37:17 +0300
parents 26e81cfafb1c
children cb108f786fb4
line wrap: on
line diff
--- a/src/lib-sql/sql-api.c	Tue Aug 22 11:31:51 2017 +0300
+++ b/src/lib-sql/sql-api.c	Tue Aug 22 11:37:17 2017 +0300
@@ -3,10 +3,16 @@
 #include "lib.h"
 #include "array.h"
 #include "ioloop.h"
+#include "str.h"
 #include "sql-api-private.h"
 
 #include <time.h>
 
+struct default_sql_prepared_statement {
+	struct sql_prepared_statement prep_stmt;
+	char *query_template;
+};
+
 struct sql_db_module_register sql_db_module_register = { 0 };
 ARRAY_TYPE(sql_drivers) sql_drivers;
 
@@ -148,6 +154,226 @@
 	return db->v.query_s(db, query);
 }
 
+static struct sql_prepared_statement *
+default_sql_prepared_statement_init(struct sql_db *db,
+				    const char *query_template)
+{
+	struct default_sql_prepared_statement *prep_stmt;
+
+	prep_stmt = i_new(struct default_sql_prepared_statement, 1);
+	prep_stmt->prep_stmt.db = db;
+	prep_stmt->query_template = i_strdup(query_template);
+	return &prep_stmt->prep_stmt;
+}
+
+static void
+default_sql_prepared_statement_deinit(struct sql_prepared_statement *_prep_stmt)
+{
+	struct default_sql_prepared_statement *prep_stmt =
+		(struct default_sql_prepared_statement *)_prep_stmt;
+
+	i_free(prep_stmt->query_template);
+	i_free(prep_stmt);
+}
+
+static struct sql_statement *
+default_sql_statement_init_prepared(struct sql_prepared_statement *_stmt)
+{
+	struct default_sql_prepared_statement *stmt =
+		(struct default_sql_prepared_statement *)_stmt;
+	return sql_statement_init(_stmt->db, stmt->query_template);
+}
+
+const char *sql_statement_get_query(struct sql_statement *stmt)
+{
+	string_t *query = t_str_new(128);
+	const char *const *args;
+	unsigned int i, args_count, arg_pos = 0;
+
+	args = array_get(&stmt->args, &args_count);
+
+	for (i = 0; stmt->query_template[i] != '\0'; i++) {
+		if (stmt->query_template[i] == '?') {
+			if (arg_pos >= args_count ||
+			    args[arg_pos] == NULL) {
+				i_panic("lib-sql: Missing bind for arg #%u in statement: %s",
+					arg_pos, stmt->query_template);
+			}
+			str_append(query, args[arg_pos++]);
+		} else {
+			str_append_c(query, stmt->query_template[i]);
+		}
+	}
+	if (arg_pos != args_count) {
+		i_panic("lib-sql: Too many bind args (%u) for statement: %s",
+			args_count, stmt->query_template);
+	}
+	return str_c(query);
+}
+
+static void
+default_sql_statement_query(struct sql_statement *stmt,
+			    sql_query_callback_t *callback, void *context)
+{
+	sql_query(stmt->db, sql_statement_get_query(stmt),
+		  callback, context);
+	pool_unref(&stmt->pool);
+}
+
+static struct sql_result *
+default_sql_statement_query_s(struct sql_statement *stmt)
+{
+	struct sql_result *result =
+		sql_query_s(stmt->db, sql_statement_get_query(stmt));
+	pool_unref(&stmt->pool);
+	return result;
+}
+
+static void default_sql_update_stmt(struct sql_transaction_context *ctx,
+				    struct sql_statement *stmt,
+				    unsigned int *affected_rows)
+{
+	ctx->db->v.update(ctx, sql_statement_get_query(stmt),
+			  affected_rows);
+	pool_unref(&stmt->pool);
+}
+
+struct sql_prepared_statement *
+sql_prepared_statement_init(struct sql_db *db, const char *query_template)
+{
+	if (db->v.prepared_statement_init != NULL)
+		return db->v.prepared_statement_init(db, query_template);
+	else
+		return default_sql_prepared_statement_init(db, query_template);
+}
+
+void sql_prepared_statement_deinit(struct sql_prepared_statement **_prep_stmt)
+{
+	struct sql_prepared_statement *prep_stmt = *_prep_stmt;
+
+	*_prep_stmt = NULL;
+	if (prep_stmt->db->v.prepared_statement_deinit != NULL)
+		prep_stmt->db->v.prepared_statement_deinit(prep_stmt);
+	else
+		default_sql_prepared_statement_deinit(prep_stmt);
+}
+
+static void
+sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db)
+{
+	stmt->db = db;
+	p_array_init(&stmt->args, stmt->pool, 8);
+}
+
+struct sql_statement *
+sql_statement_init(struct sql_db *db, const char *query_template)
+{
+	struct sql_statement *stmt;
+
+	if (db->v.statement_init != NULL)
+		stmt = db->v.statement_init(db, query_template);
+	else {
+		pool_t pool = pool_alloconly_create("sql statement", 1024);
+		stmt = p_new(pool, struct sql_statement, 1);
+		stmt->pool = pool;
+	}
+	stmt->query_template = p_strdup(stmt->pool, query_template);
+	sql_statement_init_fields(stmt, db);
+	return stmt;
+}
+
+struct sql_statement *
+sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt)
+{
+	struct sql_statement *stmt;
+
+	if (prep_stmt->db->v.statement_init_prepared == NULL)
+		return default_sql_statement_init_prepared(prep_stmt);
+
+	stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt);
+	sql_statement_init_fields(stmt, prep_stmt->db);
+	return stmt;
+}
+
+void sql_statement_abort(struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_abort != NULL)
+		stmt->db->v.statement_abort(stmt);
+	pool_unref(&stmt->pool);
+}
+
+void sql_statement_set_timestamp(struct sql_statement *stmt,
+				 const struct timespec *ts)
+{
+	if (stmt->db->v.statement_set_timestamp != NULL)
+		stmt->db->v.statement_set_timestamp(stmt, ts);
+}
+
+void sql_statement_bind_str(struct sql_statement *stmt,
+			    unsigned int column_idx, const char *value)
+{
+	const char *escaped_value =
+		p_strdup_printf(stmt->pool, "'%s'",
+				sql_escape_string(stmt->db, value));
+	array_idx_set(&stmt->args, column_idx, &escaped_value);
+
+	if (stmt->db->v.statement_bind_str != NULL)
+		stmt->db->v.statement_bind_str(stmt, column_idx, value);
+}
+
+void sql_statement_bind_binary(struct sql_statement *stmt,
+			       unsigned int column_idx, const void *value,
+			       size_t value_size)
+{
+	const char *value_str =
+		p_strdup_printf(stmt->pool, "%s",
+				sql_escape_blob(stmt->db, value, value_size));
+	array_idx_set(&stmt->args, column_idx, &value_str);
+
+	if (stmt->db->v.statement_bind_binary != NULL) {
+		stmt->db->v.statement_bind_binary(stmt, column_idx,
+						  value, value_size);
+	}
+}
+
+void sql_statement_bind_int64(struct sql_statement *stmt,
+			      unsigned int column_idx, int64_t value)
+{
+	const char *value_str = p_strdup_printf(stmt->pool, "%lld",
+						(long long)value);
+	array_idx_set(&stmt->args, column_idx, &value_str);
+
+	if (stmt->db->v.statement_bind_int64 != NULL)
+		stmt->db->v.statement_bind_int64(stmt, column_idx, value);
+}
+
+#undef sql_statement_query
+void sql_statement_query(struct sql_statement **_stmt,
+			 sql_query_callback_t *callback, void *context)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_query != NULL)
+		stmt->db->v.statement_query(stmt, callback, context);
+	else
+		default_sql_statement_query(stmt, callback, context);
+}
+
+struct sql_result *sql_statement_query_s(struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_query_s != NULL)
+		return stmt->db->v.statement_query_s(stmt);
+	else
+		return default_sql_statement_query_s(stmt);
+}
+
 void sql_result_ref(struct sql_result *result)
 {
 	result->refcount++;
@@ -464,12 +690,37 @@
 	ctx->db->v.update(ctx, query, NULL);
 }
 
+void sql_update_stmt(struct sql_transaction_context *ctx,
+		     struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (ctx->db->v.update_stmt != NULL)
+		ctx->db->v.update_stmt(ctx, stmt, NULL);
+	else
+		default_sql_update_stmt(ctx, stmt, NULL);
+}
+
 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
 			 unsigned int *affected_rows)
 {
 	ctx->db->v.update(ctx, query, affected_rows);
 }
 
+void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
+			      struct sql_statement **_stmt,
+			      unsigned int *affected_rows)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (ctx->db->v.update_stmt != NULL)
+		ctx->db->v.update_stmt(ctx, stmt, affected_rows);
+	else
+		default_sql_update_stmt(ctx, stmt, affected_rows);
+}
+
 void sql_db_set_state(struct sql_db *db, enum sql_db_state state)
 {
 	enum sql_db_state old_state = db->state;