# HG changeset patch # User Timo Sirainen # Date 1503391037 -10800 # Node ID 7d189977dfe3b99189ed478c133ef667d831e57f # Parent 59b0c5ee8adf46b82d35f2d1d30b544056919186 lib-sql: Add support for prepared SQL statements. This initial implementation doesn't use prepared statements in drivers, but simply generates the query string internally. diff -r 59b0c5ee8adf -r 7d189977dfe3 src/lib-sql/sql-api-private.h --- a/src/lib-sql/sql-api-private.h Tue Aug 22 11:31:51 2017 +0300 +++ b/src/lib-sql/sql-api-private.h Tue Aug 22 11:37:17 2017 +0300 @@ -81,6 +81,33 @@ void (*transaction_commit2)(struct sql_transaction_context *ctx, sql_commit2_callback_t *callback, void *context); + + struct sql_prepared_statement * + (*prepared_statement_init)(struct sql_db *db, + const char *query_template); + void (*prepared_statement_deinit)(struct sql_prepared_statement *prep_stmt); + + + struct sql_statement * + (*statement_init)(struct sql_db *db, const char *query_template); + struct sql_statement * + (*statement_init_prepared)(struct sql_prepared_statement *prep_stmt); + void (*statement_abort)(struct sql_statement *stmt); + void (*statement_set_timestamp)(struct sql_statement *stmt, + const struct timespec *ts); + void (*statement_bind_str)(struct sql_statement *stmt, + unsigned int column_idx, const char *value); + void (*statement_bind_binary)(struct sql_statement *stmt, + unsigned int column_idx, const void *value, + size_t value_size); + void (*statement_bind_int64)(struct sql_statement *stmt, + unsigned int column_idx, int64_t value); + void (*statement_query)(struct sql_statement *stmt, + sql_query_callback_t *callback, void *context); + struct sql_result *(*statement_query_s)(struct sql_statement *stmt); + void (*update_stmt)(struct sql_transaction_context *ctx, + struct sql_statement *stmt, + unsigned int *affected_rows); }; struct sql_db { @@ -130,6 +157,18 @@ sql_query_callback_t *callback, void *context); }; +struct sql_prepared_statement { + struct sql_db *db; +}; + +struct sql_statement { + struct sql_db *db; + + pool_t pool; + const char *query_template; + ARRAY_TYPE(const_string) args; +}; + struct sql_field_map { enum sql_field_type type; size_t offset; @@ -172,5 +211,6 @@ void sql_transaction_add_query(struct sql_transaction_context *ctx, pool_t pool, const char *query, unsigned int *affected_rows); +const char *sql_statement_get_query(struct sql_statement *stmt); #endif diff -r 59b0c5ee8adf -r 7d189977dfe3 src/lib-sql/sql-api.c --- a/src/lib-sql/sql-api.c Tue Aug 22 11:31:51 2017 +0300 +++ b/src/lib-sql/sql-api.c Tue Aug 22 11:37:17 2017 +0300 @@ -3,10 +3,16 @@ #include "lib.h" #include "array.h" #include "ioloop.h" +#include "str.h" #include "sql-api-private.h" #include +struct default_sql_prepared_statement { + struct sql_prepared_statement prep_stmt; + char *query_template; +}; + struct sql_db_module_register sql_db_module_register = { 0 }; ARRAY_TYPE(sql_drivers) sql_drivers; @@ -148,6 +154,226 @@ return db->v.query_s(db, query); } +static struct sql_prepared_statement * +default_sql_prepared_statement_init(struct sql_db *db, + const char *query_template) +{ + struct default_sql_prepared_statement *prep_stmt; + + prep_stmt = i_new(struct default_sql_prepared_statement, 1); + prep_stmt->prep_stmt.db = db; + prep_stmt->query_template = i_strdup(query_template); + return &prep_stmt->prep_stmt; +} + +static void +default_sql_prepared_statement_deinit(struct sql_prepared_statement *_prep_stmt) +{ + struct default_sql_prepared_statement *prep_stmt = + (struct default_sql_prepared_statement *)_prep_stmt; + + i_free(prep_stmt->query_template); + i_free(prep_stmt); +} + +static struct sql_statement * +default_sql_statement_init_prepared(struct sql_prepared_statement *_stmt) +{ + struct default_sql_prepared_statement *stmt = + (struct default_sql_prepared_statement *)_stmt; + return sql_statement_init(_stmt->db, stmt->query_template); +} + +const char *sql_statement_get_query(struct sql_statement *stmt) +{ + string_t *query = t_str_new(128); + const char *const *args; + unsigned int i, args_count, arg_pos = 0; + + args = array_get(&stmt->args, &args_count); + + for (i = 0; stmt->query_template[i] != '\0'; i++) { + if (stmt->query_template[i] == '?') { + if (arg_pos >= args_count || + args[arg_pos] == NULL) { + i_panic("lib-sql: Missing bind for arg #%u in statement: %s", + arg_pos, stmt->query_template); + } + str_append(query, args[arg_pos++]); + } else { + str_append_c(query, stmt->query_template[i]); + } + } + if (arg_pos != args_count) { + i_panic("lib-sql: Too many bind args (%u) for statement: %s", + args_count, stmt->query_template); + } + return str_c(query); +} + +static void +default_sql_statement_query(struct sql_statement *stmt, + sql_query_callback_t *callback, void *context) +{ + sql_query(stmt->db, sql_statement_get_query(stmt), + callback, context); + pool_unref(&stmt->pool); +} + +static struct sql_result * +default_sql_statement_query_s(struct sql_statement *stmt) +{ + struct sql_result *result = + sql_query_s(stmt->db, sql_statement_get_query(stmt)); + pool_unref(&stmt->pool); + return result; +} + +static void default_sql_update_stmt(struct sql_transaction_context *ctx, + struct sql_statement *stmt, + unsigned int *affected_rows) +{ + ctx->db->v.update(ctx, sql_statement_get_query(stmt), + affected_rows); + pool_unref(&stmt->pool); +} + +struct sql_prepared_statement * +sql_prepared_statement_init(struct sql_db *db, const char *query_template) +{ + if (db->v.prepared_statement_init != NULL) + return db->v.prepared_statement_init(db, query_template); + else + return default_sql_prepared_statement_init(db, query_template); +} + +void sql_prepared_statement_deinit(struct sql_prepared_statement **_prep_stmt) +{ + struct sql_prepared_statement *prep_stmt = *_prep_stmt; + + *_prep_stmt = NULL; + if (prep_stmt->db->v.prepared_statement_deinit != NULL) + prep_stmt->db->v.prepared_statement_deinit(prep_stmt); + else + default_sql_prepared_statement_deinit(prep_stmt); +} + +static void +sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db) +{ + stmt->db = db; + p_array_init(&stmt->args, stmt->pool, 8); +} + +struct sql_statement * +sql_statement_init(struct sql_db *db, const char *query_template) +{ + struct sql_statement *stmt; + + if (db->v.statement_init != NULL) + stmt = db->v.statement_init(db, query_template); + else { + pool_t pool = pool_alloconly_create("sql statement", 1024); + stmt = p_new(pool, struct sql_statement, 1); + stmt->pool = pool; + } + stmt->query_template = p_strdup(stmt->pool, query_template); + sql_statement_init_fields(stmt, db); + return stmt; +} + +struct sql_statement * +sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt) +{ + struct sql_statement *stmt; + + if (prep_stmt->db->v.statement_init_prepared == NULL) + return default_sql_statement_init_prepared(prep_stmt); + + stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt); + sql_statement_init_fields(stmt, prep_stmt->db); + return stmt; +} + +void sql_statement_abort(struct sql_statement **_stmt) +{ + struct sql_statement *stmt = *_stmt; + + *_stmt = NULL; + if (stmt->db->v.statement_abort != NULL) + stmt->db->v.statement_abort(stmt); + pool_unref(&stmt->pool); +} + +void sql_statement_set_timestamp(struct sql_statement *stmt, + const struct timespec *ts) +{ + if (stmt->db->v.statement_set_timestamp != NULL) + stmt->db->v.statement_set_timestamp(stmt, ts); +} + +void sql_statement_bind_str(struct sql_statement *stmt, + unsigned int column_idx, const char *value) +{ + const char *escaped_value = + p_strdup_printf(stmt->pool, "'%s'", + sql_escape_string(stmt->db, value)); + array_idx_set(&stmt->args, column_idx, &escaped_value); + + if (stmt->db->v.statement_bind_str != NULL) + stmt->db->v.statement_bind_str(stmt, column_idx, value); +} + +void sql_statement_bind_binary(struct sql_statement *stmt, + unsigned int column_idx, const void *value, + size_t value_size) +{ + const char *value_str = + p_strdup_printf(stmt->pool, "%s", + sql_escape_blob(stmt->db, value, value_size)); + array_idx_set(&stmt->args, column_idx, &value_str); + + if (stmt->db->v.statement_bind_binary != NULL) { + stmt->db->v.statement_bind_binary(stmt, column_idx, + value, value_size); + } +} + +void sql_statement_bind_int64(struct sql_statement *stmt, + unsigned int column_idx, int64_t value) +{ + const char *value_str = p_strdup_printf(stmt->pool, "%lld", + (long long)value); + array_idx_set(&stmt->args, column_idx, &value_str); + + if (stmt->db->v.statement_bind_int64 != NULL) + stmt->db->v.statement_bind_int64(stmt, column_idx, value); +} + +#undef sql_statement_query +void sql_statement_query(struct sql_statement **_stmt, + sql_query_callback_t *callback, void *context) +{ + struct sql_statement *stmt = *_stmt; + + *_stmt = NULL; + if (stmt->db->v.statement_query != NULL) + stmt->db->v.statement_query(stmt, callback, context); + else + default_sql_statement_query(stmt, callback, context); +} + +struct sql_result *sql_statement_query_s(struct sql_statement **_stmt) +{ + struct sql_statement *stmt = *_stmt; + + *_stmt = NULL; + if (stmt->db->v.statement_query_s != NULL) + return stmt->db->v.statement_query_s(stmt); + else + return default_sql_statement_query_s(stmt); +} + void sql_result_ref(struct sql_result *result) { result->refcount++; @@ -464,12 +690,37 @@ ctx->db->v.update(ctx, query, NULL); } +void sql_update_stmt(struct sql_transaction_context *ctx, + struct sql_statement **_stmt) +{ + struct sql_statement *stmt = *_stmt; + + *_stmt = NULL; + if (ctx->db->v.update_stmt != NULL) + ctx->db->v.update_stmt(ctx, stmt, NULL); + else + default_sql_update_stmt(ctx, stmt, NULL); +} + void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query, unsigned int *affected_rows) { ctx->db->v.update(ctx, query, affected_rows); } +void sql_update_stmt_get_rows(struct sql_transaction_context *ctx, + struct sql_statement **_stmt, + unsigned int *affected_rows) +{ + struct sql_statement *stmt = *_stmt; + + *_stmt = NULL; + if (ctx->db->v.update_stmt != NULL) + ctx->db->v.update_stmt(ctx, stmt, affected_rows); + else + default_sql_update_stmt(ctx, stmt, affected_rows); +} + void sql_db_set_state(struct sql_db *db, enum sql_db_state state) { enum sql_db_state old_state = db->state; diff -r 59b0c5ee8adf -r 7d189977dfe3 src/lib-sql/sql-api.h --- a/src/lib-sql/sql-api.h Tue Aug 22 11:31:51 2017 +0300 +++ b/src/lib-sql/sql-api.h Tue Aug 22 11:37:17 2017 +0300 @@ -8,7 +8,11 @@ /* Set if queries are not executed asynchronously */ SQL_DB_FLAG_BLOCKING = 0x01, /* Set if database wants to use connection pooling */ - SQL_DB_FLAG_POOLED = 0x02 + SQL_DB_FLAG_POOLED = 0x02, + /* Prepared statements are supported by the database. If they aren't, + the functions can still be used, but they're just internally + convered into regular statements. */ + SQL_DB_FLAG_PREP_STATEMENTS = 0x04, }; enum sql_field_type { @@ -113,6 +117,33 @@ /* Execute blocking SQL query and return result. */ struct sql_result *sql_query_s(struct sql_db *db, const char *query); +struct sql_prepared_statement * +sql_prepared_statement_init(struct sql_db *db, const char *query_template); +void sql_prepared_statement_deinit(struct sql_prepared_statement **prep_stmt); + +struct sql_statement * +sql_statement_init(struct sql_db *db, const char *query_template); +struct sql_statement * +sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt); +void sql_statement_abort(struct sql_statement **stmt); +void sql_statement_set_timestamp(struct sql_statement *stmt, + const struct timespec *ts); +void sql_statement_bind_str(struct sql_statement *stmt, + unsigned int column_idx, const char *value); +void sql_statement_bind_binary(struct sql_statement *stmt, + unsigned int column_idx, const void *value, + size_t value_size); +void sql_statement_bind_int64(struct sql_statement *stmt, + unsigned int column_idx, int64_t value); +void sql_statement_query(struct sql_statement **stmt, + sql_query_callback_t *callback, void *context); +#define sql_statement_query(stmt, callback, context) \ + sql_statement_query(stmt, \ + (sql_query_callback_t *)callback, context + \ + CALLBACK_TYPECHECK(callback, void (*)( \ + struct sql_result *, typeof(context)))) +struct sql_result *sql_statement_query_s(struct sql_statement **stmt); + void sql_result_setup_fetch(struct sql_result *result, const struct sql_field_def *fields, void *dest, size_t dest_size); @@ -187,9 +218,14 @@ /* Execute query in given transaction. */ void sql_update(struct sql_transaction_context *ctx, const char *query); +void sql_update_stmt(struct sql_transaction_context *ctx, + struct sql_statement **stmt); /* Save the number of rows updated by this query. The value is set before commit callback is called. */ void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query, unsigned int *affected_rows); +void sql_update_stmt_get_rows(struct sql_transaction_context *ctx, + struct sql_statement **stmt, + unsigned int *affected_rows); #endif