changeset 22514:7d189977dfe3

lib-sql: Add support for prepared SQL statements. This initial implementation doesn't use prepared statements in drivers, but simply generates the query string internally.
author Timo Sirainen <timo.sirainen@dovecot.fi>
date Tue, 22 Aug 2017 11:37:17 +0300
parents 59b0c5ee8adf
children b6c3fc6d0cc0
files src/lib-sql/sql-api-private.h src/lib-sql/sql-api.c src/lib-sql/sql-api.h
diffstat 3 files changed, 328 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/src/lib-sql/sql-api-private.h	Tue Aug 22 11:31:51 2017 +0300
+++ b/src/lib-sql/sql-api-private.h	Tue Aug 22 11:37:17 2017 +0300
@@ -81,6 +81,33 @@
 	void (*transaction_commit2)(struct sql_transaction_context *ctx,
 				    sql_commit2_callback_t *callback,
 				    void *context);
+
+	struct sql_prepared_statement *
+		(*prepared_statement_init)(struct sql_db *db,
+					   const char *query_template);
+	void (*prepared_statement_deinit)(struct sql_prepared_statement *prep_stmt);
+
+
+	struct sql_statement *
+		(*statement_init)(struct sql_db *db, const char *query_template);
+	struct sql_statement *
+		(*statement_init_prepared)(struct sql_prepared_statement *prep_stmt);
+	void (*statement_abort)(struct sql_statement *stmt);
+	void (*statement_set_timestamp)(struct sql_statement *stmt,
+					const struct timespec *ts);
+	void (*statement_bind_str)(struct sql_statement *stmt,
+				   unsigned int column_idx, const char *value);
+	void (*statement_bind_binary)(struct sql_statement *stmt,
+				      unsigned int column_idx, const void *value,
+				      size_t value_size);
+	void (*statement_bind_int64)(struct sql_statement *stmt,
+				     unsigned int column_idx, int64_t value);
+	void (*statement_query)(struct sql_statement *stmt,
+				sql_query_callback_t *callback, void *context);
+	struct sql_result *(*statement_query_s)(struct sql_statement *stmt);
+	void (*update_stmt)(struct sql_transaction_context *ctx,
+			    struct sql_statement *stmt,
+			    unsigned int *affected_rows);
 };
 
 struct sql_db {
@@ -130,6 +157,18 @@
 		     sql_query_callback_t *callback, void *context);
 };
 
+struct sql_prepared_statement {
+	struct sql_db *db;
+};
+
+struct sql_statement {
+	struct sql_db *db;
+
+	pool_t pool;
+	const char *query_template;
+	ARRAY_TYPE(const_string) args;
+};
+
 struct sql_field_map {
 	enum sql_field_type type;
 	size_t offset;
@@ -172,5 +211,6 @@
 
 void sql_transaction_add_query(struct sql_transaction_context *ctx, pool_t pool,
 			       const char *query, unsigned int *affected_rows);
+const char *sql_statement_get_query(struct sql_statement *stmt);
 
 #endif
--- a/src/lib-sql/sql-api.c	Tue Aug 22 11:31:51 2017 +0300
+++ b/src/lib-sql/sql-api.c	Tue Aug 22 11:37:17 2017 +0300
@@ -3,10 +3,16 @@
 #include "lib.h"
 #include "array.h"
 #include "ioloop.h"
+#include "str.h"
 #include "sql-api-private.h"
 
 #include <time.h>
 
+struct default_sql_prepared_statement {
+	struct sql_prepared_statement prep_stmt;
+	char *query_template;
+};
+
 struct sql_db_module_register sql_db_module_register = { 0 };
 ARRAY_TYPE(sql_drivers) sql_drivers;
 
@@ -148,6 +154,226 @@
 	return db->v.query_s(db, query);
 }
 
+static struct sql_prepared_statement *
+default_sql_prepared_statement_init(struct sql_db *db,
+				    const char *query_template)
+{
+	struct default_sql_prepared_statement *prep_stmt;
+
+	prep_stmt = i_new(struct default_sql_prepared_statement, 1);
+	prep_stmt->prep_stmt.db = db;
+	prep_stmt->query_template = i_strdup(query_template);
+	return &prep_stmt->prep_stmt;
+}
+
+static void
+default_sql_prepared_statement_deinit(struct sql_prepared_statement *_prep_stmt)
+{
+	struct default_sql_prepared_statement *prep_stmt =
+		(struct default_sql_prepared_statement *)_prep_stmt;
+
+	i_free(prep_stmt->query_template);
+	i_free(prep_stmt);
+}
+
+static struct sql_statement *
+default_sql_statement_init_prepared(struct sql_prepared_statement *_stmt)
+{
+	struct default_sql_prepared_statement *stmt =
+		(struct default_sql_prepared_statement *)_stmt;
+	return sql_statement_init(_stmt->db, stmt->query_template);
+}
+
+const char *sql_statement_get_query(struct sql_statement *stmt)
+{
+	string_t *query = t_str_new(128);
+	const char *const *args;
+	unsigned int i, args_count, arg_pos = 0;
+
+	args = array_get(&stmt->args, &args_count);
+
+	for (i = 0; stmt->query_template[i] != '\0'; i++) {
+		if (stmt->query_template[i] == '?') {
+			if (arg_pos >= args_count ||
+			    args[arg_pos] == NULL) {
+				i_panic("lib-sql: Missing bind for arg #%u in statement: %s",
+					arg_pos, stmt->query_template);
+			}
+			str_append(query, args[arg_pos++]);
+		} else {
+			str_append_c(query, stmt->query_template[i]);
+		}
+	}
+	if (arg_pos != args_count) {
+		i_panic("lib-sql: Too many bind args (%u) for statement: %s",
+			args_count, stmt->query_template);
+	}
+	return str_c(query);
+}
+
+static void
+default_sql_statement_query(struct sql_statement *stmt,
+			    sql_query_callback_t *callback, void *context)
+{
+	sql_query(stmt->db, sql_statement_get_query(stmt),
+		  callback, context);
+	pool_unref(&stmt->pool);
+}
+
+static struct sql_result *
+default_sql_statement_query_s(struct sql_statement *stmt)
+{
+	struct sql_result *result =
+		sql_query_s(stmt->db, sql_statement_get_query(stmt));
+	pool_unref(&stmt->pool);
+	return result;
+}
+
+static void default_sql_update_stmt(struct sql_transaction_context *ctx,
+				    struct sql_statement *stmt,
+				    unsigned int *affected_rows)
+{
+	ctx->db->v.update(ctx, sql_statement_get_query(stmt),
+			  affected_rows);
+	pool_unref(&stmt->pool);
+}
+
+struct sql_prepared_statement *
+sql_prepared_statement_init(struct sql_db *db, const char *query_template)
+{
+	if (db->v.prepared_statement_init != NULL)
+		return db->v.prepared_statement_init(db, query_template);
+	else
+		return default_sql_prepared_statement_init(db, query_template);
+}
+
+void sql_prepared_statement_deinit(struct sql_prepared_statement **_prep_stmt)
+{
+	struct sql_prepared_statement *prep_stmt = *_prep_stmt;
+
+	*_prep_stmt = NULL;
+	if (prep_stmt->db->v.prepared_statement_deinit != NULL)
+		prep_stmt->db->v.prepared_statement_deinit(prep_stmt);
+	else
+		default_sql_prepared_statement_deinit(prep_stmt);
+}
+
+static void
+sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db)
+{
+	stmt->db = db;
+	p_array_init(&stmt->args, stmt->pool, 8);
+}
+
+struct sql_statement *
+sql_statement_init(struct sql_db *db, const char *query_template)
+{
+	struct sql_statement *stmt;
+
+	if (db->v.statement_init != NULL)
+		stmt = db->v.statement_init(db, query_template);
+	else {
+		pool_t pool = pool_alloconly_create("sql statement", 1024);
+		stmt = p_new(pool, struct sql_statement, 1);
+		stmt->pool = pool;
+	}
+	stmt->query_template = p_strdup(stmt->pool, query_template);
+	sql_statement_init_fields(stmt, db);
+	return stmt;
+}
+
+struct sql_statement *
+sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt)
+{
+	struct sql_statement *stmt;
+
+	if (prep_stmt->db->v.statement_init_prepared == NULL)
+		return default_sql_statement_init_prepared(prep_stmt);
+
+	stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt);
+	sql_statement_init_fields(stmt, prep_stmt->db);
+	return stmt;
+}
+
+void sql_statement_abort(struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_abort != NULL)
+		stmt->db->v.statement_abort(stmt);
+	pool_unref(&stmt->pool);
+}
+
+void sql_statement_set_timestamp(struct sql_statement *stmt,
+				 const struct timespec *ts)
+{
+	if (stmt->db->v.statement_set_timestamp != NULL)
+		stmt->db->v.statement_set_timestamp(stmt, ts);
+}
+
+void sql_statement_bind_str(struct sql_statement *stmt,
+			    unsigned int column_idx, const char *value)
+{
+	const char *escaped_value =
+		p_strdup_printf(stmt->pool, "'%s'",
+				sql_escape_string(stmt->db, value));
+	array_idx_set(&stmt->args, column_idx, &escaped_value);
+
+	if (stmt->db->v.statement_bind_str != NULL)
+		stmt->db->v.statement_bind_str(stmt, column_idx, value);
+}
+
+void sql_statement_bind_binary(struct sql_statement *stmt,
+			       unsigned int column_idx, const void *value,
+			       size_t value_size)
+{
+	const char *value_str =
+		p_strdup_printf(stmt->pool, "%s",
+				sql_escape_blob(stmt->db, value, value_size));
+	array_idx_set(&stmt->args, column_idx, &value_str);
+
+	if (stmt->db->v.statement_bind_binary != NULL) {
+		stmt->db->v.statement_bind_binary(stmt, column_idx,
+						  value, value_size);
+	}
+}
+
+void sql_statement_bind_int64(struct sql_statement *stmt,
+			      unsigned int column_idx, int64_t value)
+{
+	const char *value_str = p_strdup_printf(stmt->pool, "%lld",
+						(long long)value);
+	array_idx_set(&stmt->args, column_idx, &value_str);
+
+	if (stmt->db->v.statement_bind_int64 != NULL)
+		stmt->db->v.statement_bind_int64(stmt, column_idx, value);
+}
+
+#undef sql_statement_query
+void sql_statement_query(struct sql_statement **_stmt,
+			 sql_query_callback_t *callback, void *context)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_query != NULL)
+		stmt->db->v.statement_query(stmt, callback, context);
+	else
+		default_sql_statement_query(stmt, callback, context);
+}
+
+struct sql_result *sql_statement_query_s(struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (stmt->db->v.statement_query_s != NULL)
+		return stmt->db->v.statement_query_s(stmt);
+	else
+		return default_sql_statement_query_s(stmt);
+}
+
 void sql_result_ref(struct sql_result *result)
 {
 	result->refcount++;
@@ -464,12 +690,37 @@
 	ctx->db->v.update(ctx, query, NULL);
 }
 
+void sql_update_stmt(struct sql_transaction_context *ctx,
+		     struct sql_statement **_stmt)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (ctx->db->v.update_stmt != NULL)
+		ctx->db->v.update_stmt(ctx, stmt, NULL);
+	else
+		default_sql_update_stmt(ctx, stmt, NULL);
+}
+
 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
 			 unsigned int *affected_rows)
 {
 	ctx->db->v.update(ctx, query, affected_rows);
 }
 
+void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
+			      struct sql_statement **_stmt,
+			      unsigned int *affected_rows)
+{
+	struct sql_statement *stmt = *_stmt;
+
+	*_stmt = NULL;
+	if (ctx->db->v.update_stmt != NULL)
+		ctx->db->v.update_stmt(ctx, stmt, affected_rows);
+	else
+		default_sql_update_stmt(ctx, stmt, affected_rows);
+}
+
 void sql_db_set_state(struct sql_db *db, enum sql_db_state state)
 {
 	enum sql_db_state old_state = db->state;
--- a/src/lib-sql/sql-api.h	Tue Aug 22 11:31:51 2017 +0300
+++ b/src/lib-sql/sql-api.h	Tue Aug 22 11:37:17 2017 +0300
@@ -8,7 +8,11 @@
 	/* Set if queries are not executed asynchronously */
 	SQL_DB_FLAG_BLOCKING		= 0x01,
 	/* Set if database wants to use connection pooling */
-	SQL_DB_FLAG_POOLED		= 0x02
+	SQL_DB_FLAG_POOLED		= 0x02,
+	/* Prepared statements are supported by the database. If they aren't,
+	   the functions can still be used, but they're just internally
+	   convered into regular statements. */
+	SQL_DB_FLAG_PREP_STATEMENTS	= 0x04,
 };
 
 enum sql_field_type {
@@ -113,6 +117,33 @@
 /* Execute blocking SQL query and return result. */
 struct sql_result *sql_query_s(struct sql_db *db, const char *query);
 
+struct sql_prepared_statement *
+sql_prepared_statement_init(struct sql_db *db, const char *query_template);
+void sql_prepared_statement_deinit(struct sql_prepared_statement **prep_stmt);
+
+struct sql_statement *
+sql_statement_init(struct sql_db *db, const char *query_template);
+struct sql_statement *
+sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt);
+void sql_statement_abort(struct sql_statement **stmt);
+void sql_statement_set_timestamp(struct sql_statement *stmt,
+				 const struct timespec *ts);
+void sql_statement_bind_str(struct sql_statement *stmt,
+			    unsigned int column_idx, const char *value);
+void sql_statement_bind_binary(struct sql_statement *stmt,
+			       unsigned int column_idx, const void *value,
+			       size_t value_size);
+void sql_statement_bind_int64(struct sql_statement *stmt,
+			      unsigned int column_idx, int64_t value);
+void sql_statement_query(struct sql_statement **stmt,
+			 sql_query_callback_t *callback, void *context);
+#define sql_statement_query(stmt, callback, context) \
+	sql_statement_query(stmt, \
+		(sql_query_callback_t *)callback, context + \
+		CALLBACK_TYPECHECK(callback, void (*)( \
+			struct sql_result *, typeof(context))))
+struct sql_result *sql_statement_query_s(struct sql_statement **stmt);
+
 void sql_result_setup_fetch(struct sql_result *result,
 			    const struct sql_field_def *fields,
 			    void *dest, size_t dest_size);
@@ -187,9 +218,14 @@
 
 /* Execute query in given transaction. */
 void sql_update(struct sql_transaction_context *ctx, const char *query);
+void sql_update_stmt(struct sql_transaction_context *ctx,
+		     struct sql_statement **stmt);
 /* Save the number of rows updated by this query. The value is set before
    commit callback is called. */
 void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
 			 unsigned int *affected_rows);
+void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
+			      struct sql_statement **stmt,
+			      unsigned int *affected_rows);
 
 #endif