view src/lib-sql/sql-api.c @ 22514:7d189977dfe3

lib-sql: Add support for prepared SQL statements. This initial implementation doesn't use prepared statements in drivers, but simply generates the query string internally.
author Timo Sirainen <timo.sirainen@dovecot.fi>
date Tue, 22 Aug 2017 11:37:17 +0300
parents 26e81cfafb1c
children cb108f786fb4
line wrap: on
line source

/* Copyright (c) 2004-2017 Dovecot authors, see the included COPYING file */

#include "lib.h"
#include "array.h"
#include "ioloop.h"
#include "str.h"
#include "sql-api-private.h"

#include <time.h>

struct default_sql_prepared_statement {
	struct sql_prepared_statement prep_stmt;
	char *query_template;
};

struct sql_db_module_register sql_db_module_register = { 0 };
ARRAY_TYPE(sql_drivers) sql_drivers;

void sql_drivers_init(void)
{
	i_array_init(&sql_drivers, 8);
}

void sql_drivers_deinit(void)
{
	array_free(&sql_drivers);
}

static const struct sql_db *sql_driver_lookup(const char *name)
{
	const struct sql_db *const *drivers;
	unsigned int i, count;

	drivers = array_get(&sql_drivers, &count);
	for (i = 0; i < count; i++) {
		if (strcmp(drivers[i]->name, name) == 0)
			return drivers[i];
	}
	return NULL;
}

void sql_driver_register(const struct sql_db *driver)
{
	if (sql_driver_lookup(driver->name) != NULL) {
		i_fatal("sql_driver_register(%s): Already registered",
			driver->name);
	}
	array_append(&sql_drivers, &driver, 1);
}

void sql_driver_unregister(const struct sql_db *driver)
{
	const struct sql_db *const *drivers;
	unsigned int i, count;

	drivers = array_get(&sql_drivers, &count);
	for (i = 0; i < count; i++) {
		if (drivers[i] == driver) {
			array_delete(&sql_drivers, i, 1);
			break;
		}
	}
}

struct sql_db *sql_init(const char *db_driver, const char *connect_string)
{
	const struct sql_db *driver;
	struct sql_db *db;

	i_assert(connect_string != NULL);

	driver = sql_driver_lookup(db_driver);
	if (driver == NULL)
		i_fatal("Unknown database driver '%s'", db_driver);

	if ((driver->flags & SQL_DB_FLAG_POOLED) == 0)
		db = driver->v.init(connect_string);
	else
		db = driver_sqlpool_init(connect_string, driver);
	i_array_init(&db->module_contexts, 5);
	return db;
}

void sql_deinit(struct sql_db **_db)
{
	struct sql_db *db = *_db;

	*_db = NULL;

	if (db->to_reconnect != NULL)
		timeout_remove(&db->to_reconnect);
	db->v.deinit(db);
}

enum sql_db_flags sql_get_flags(struct sql_db *db)
{
	return db->flags;
}

int sql_connect(struct sql_db *db)
{
	time_t now;

	switch (db->state) {
	case SQL_DB_STATE_DISCONNECTED:
		break;
	case SQL_DB_STATE_CONNECTING:
		return 0;
	default:
		return 1;
	}

	/* don't try reconnecting more than once a second */
	now = time(NULL);
	if (db->last_connect_try + (time_t)db->connect_delay > now)
		return -1;
	db->last_connect_try = now;

	return db->v.connect(db);
}

void sql_disconnect(struct sql_db *db)
{
	if (db->to_reconnect != NULL)
		timeout_remove(&db->to_reconnect);
	db->v.disconnect(db);
}

const char *sql_escape_string(struct sql_db *db, const char *string)
{
	return db->v.escape_string(db, string);
}

const char *sql_escape_blob(struct sql_db *db,
			    const unsigned char *data, size_t size)
{
	return db->v.escape_blob(db, data, size);
}

void sql_exec(struct sql_db *db, const char *query)
{
	db->v.exec(db, query);
}

#undef sql_query
void sql_query(struct sql_db *db, const char *query,
	       sql_query_callback_t *callback, void *context)
{
	db->v.query(db, query, callback, context);
}

struct sql_result *sql_query_s(struct sql_db *db, const char *query)
{
	return db->v.query_s(db, query);
}

static struct sql_prepared_statement *
default_sql_prepared_statement_init(struct sql_db *db,
				    const char *query_template)
{
	struct default_sql_prepared_statement *prep_stmt;

	prep_stmt = i_new(struct default_sql_prepared_statement, 1);
	prep_stmt->prep_stmt.db = db;
	prep_stmt->query_template = i_strdup(query_template);
	return &prep_stmt->prep_stmt;
}

static void
default_sql_prepared_statement_deinit(struct sql_prepared_statement *_prep_stmt)
{
	struct default_sql_prepared_statement *prep_stmt =
		(struct default_sql_prepared_statement *)_prep_stmt;

	i_free(prep_stmt->query_template);
	i_free(prep_stmt);
}

static struct sql_statement *
default_sql_statement_init_prepared(struct sql_prepared_statement *_stmt)
{
	struct default_sql_prepared_statement *stmt =
		(struct default_sql_prepared_statement *)_stmt;
	return sql_statement_init(_stmt->db, stmt->query_template);
}

const char *sql_statement_get_query(struct sql_statement *stmt)
{
	string_t *query = t_str_new(128);
	const char *const *args;
	unsigned int i, args_count, arg_pos = 0;

	args = array_get(&stmt->args, &args_count);

	for (i = 0; stmt->query_template[i] != '\0'; i++) {
		if (stmt->query_template[i] == '?') {
			if (arg_pos >= args_count ||
			    args[arg_pos] == NULL) {
				i_panic("lib-sql: Missing bind for arg #%u in statement: %s",
					arg_pos, stmt->query_template);
			}
			str_append(query, args[arg_pos++]);
		} else {
			str_append_c(query, stmt->query_template[i]);
		}
	}
	if (arg_pos != args_count) {
		i_panic("lib-sql: Too many bind args (%u) for statement: %s",
			args_count, stmt->query_template);
	}
	return str_c(query);
}

static void
default_sql_statement_query(struct sql_statement *stmt,
			    sql_query_callback_t *callback, void *context)
{
	sql_query(stmt->db, sql_statement_get_query(stmt),
		  callback, context);
	pool_unref(&stmt->pool);
}

static struct sql_result *
default_sql_statement_query_s(struct sql_statement *stmt)
{
	struct sql_result *result =
		sql_query_s(stmt->db, sql_statement_get_query(stmt));
	pool_unref(&stmt->pool);
	return result;
}

static void default_sql_update_stmt(struct sql_transaction_context *ctx,
				    struct sql_statement *stmt,
				    unsigned int *affected_rows)
{
	ctx->db->v.update(ctx, sql_statement_get_query(stmt),
			  affected_rows);
	pool_unref(&stmt->pool);
}

struct sql_prepared_statement *
sql_prepared_statement_init(struct sql_db *db, const char *query_template)
{
	if (db->v.prepared_statement_init != NULL)
		return db->v.prepared_statement_init(db, query_template);
	else
		return default_sql_prepared_statement_init(db, query_template);
}

void sql_prepared_statement_deinit(struct sql_prepared_statement **_prep_stmt)
{
	struct sql_prepared_statement *prep_stmt = *_prep_stmt;

	*_prep_stmt = NULL;
	if (prep_stmt->db->v.prepared_statement_deinit != NULL)
		prep_stmt->db->v.prepared_statement_deinit(prep_stmt);
	else
		default_sql_prepared_statement_deinit(prep_stmt);
}

static void
sql_statement_init_fields(struct sql_statement *stmt, struct sql_db *db)
{
	stmt->db = db;
	p_array_init(&stmt->args, stmt->pool, 8);
}

struct sql_statement *
sql_statement_init(struct sql_db *db, const char *query_template)
{
	struct sql_statement *stmt;

	if (db->v.statement_init != NULL)
		stmt = db->v.statement_init(db, query_template);
	else {
		pool_t pool = pool_alloconly_create("sql statement", 1024);
		stmt = p_new(pool, struct sql_statement, 1);
		stmt->pool = pool;
	}
	stmt->query_template = p_strdup(stmt->pool, query_template);
	sql_statement_init_fields(stmt, db);
	return stmt;
}

struct sql_statement *
sql_statement_init_prepared(struct sql_prepared_statement *prep_stmt)
{
	struct sql_statement *stmt;

	if (prep_stmt->db->v.statement_init_prepared == NULL)
		return default_sql_statement_init_prepared(prep_stmt);

	stmt = prep_stmt->db->v.statement_init_prepared(prep_stmt);
	sql_statement_init_fields(stmt, prep_stmt->db);
	return stmt;
}

void sql_statement_abort(struct sql_statement **_stmt)
{
	struct sql_statement *stmt = *_stmt;

	*_stmt = NULL;
	if (stmt->db->v.statement_abort != NULL)
		stmt->db->v.statement_abort(stmt);
	pool_unref(&stmt->pool);
}

void sql_statement_set_timestamp(struct sql_statement *stmt,
				 const struct timespec *ts)
{
	if (stmt->db->v.statement_set_timestamp != NULL)
		stmt->db->v.statement_set_timestamp(stmt, ts);
}

void sql_statement_bind_str(struct sql_statement *stmt,
			    unsigned int column_idx, const char *value)
{
	const char *escaped_value =
		p_strdup_printf(stmt->pool, "'%s'",
				sql_escape_string(stmt->db, value));
	array_idx_set(&stmt->args, column_idx, &escaped_value);

	if (stmt->db->v.statement_bind_str != NULL)
		stmt->db->v.statement_bind_str(stmt, column_idx, value);
}

void sql_statement_bind_binary(struct sql_statement *stmt,
			       unsigned int column_idx, const void *value,
			       size_t value_size)
{
	const char *value_str =
		p_strdup_printf(stmt->pool, "%s",
				sql_escape_blob(stmt->db, value, value_size));
	array_idx_set(&stmt->args, column_idx, &value_str);

	if (stmt->db->v.statement_bind_binary != NULL) {
		stmt->db->v.statement_bind_binary(stmt, column_idx,
						  value, value_size);
	}
}

void sql_statement_bind_int64(struct sql_statement *stmt,
			      unsigned int column_idx, int64_t value)
{
	const char *value_str = p_strdup_printf(stmt->pool, "%lld",
						(long long)value);
	array_idx_set(&stmt->args, column_idx, &value_str);

	if (stmt->db->v.statement_bind_int64 != NULL)
		stmt->db->v.statement_bind_int64(stmt, column_idx, value);
}

#undef sql_statement_query
void sql_statement_query(struct sql_statement **_stmt,
			 sql_query_callback_t *callback, void *context)
{
	struct sql_statement *stmt = *_stmt;

	*_stmt = NULL;
	if (stmt->db->v.statement_query != NULL)
		stmt->db->v.statement_query(stmt, callback, context);
	else
		default_sql_statement_query(stmt, callback, context);
}

struct sql_result *sql_statement_query_s(struct sql_statement **_stmt)
{
	struct sql_statement *stmt = *_stmt;

	*_stmt = NULL;
	if (stmt->db->v.statement_query_s != NULL)
		return stmt->db->v.statement_query_s(stmt);
	else
		return default_sql_statement_query_s(stmt);
}

void sql_result_ref(struct sql_result *result)
{
	result->refcount++;
}

void sql_result_unref(struct sql_result *result)
{
	i_assert(result->refcount > 0);
	if (--result->refcount > 0)
		return;

	i_free(result->map);
	result->v.free(result);
}

static const struct sql_field_def *
sql_field_def_find(const struct sql_field_def *fields, const char *name)
{
	unsigned int i;

	for (i = 0; fields[i].name != NULL; i++) {
		if (strcasecmp(fields[i].name, name) == 0)
			return &fields[i];
	}
	return NULL;
}

static void
sql_result_build_map(struct sql_result *result,
		     const struct sql_field_def *fields, size_t dest_size)
{
	const struct sql_field_def *def;
	const char *name;
	unsigned int i, count, field_size = 0;

	count = sql_result_get_fields_count(result);

	result->map_size = count;
	result->map = i_new(struct sql_field_map, result->map_size);
	for (i = 0; i < count; i++) {
		name = sql_result_get_field_name(result, i);
		def = sql_field_def_find(fields, name);
		if (def != NULL) {
			result->map[i].type = def->type;
			result->map[i].offset = def->offset;
			switch (def->type) {
			case SQL_TYPE_STR:
				field_size = sizeof(const char *);
				break;
			case SQL_TYPE_UINT:
				field_size = sizeof(unsigned int);
				break;
			case SQL_TYPE_ULLONG:
				field_size = sizeof(unsigned long long);
				break;
			case SQL_TYPE_BOOL:
				field_size = sizeof(bool);
				break;
			}
			i_assert(def->offset + field_size <= dest_size);
		} else {
			result->map[i].offset = (size_t)-1;
		}
	}
}

void sql_result_setup_fetch(struct sql_result *result,
			    const struct sql_field_def *fields,
			    void *dest, size_t dest_size)
{
	if (result->map == NULL)
		sql_result_build_map(result, fields, dest_size);
	result->fetch_dest = dest;
	result->fetch_dest_size = dest_size;
}

static void sql_result_fetch(struct sql_result *result)
{
	unsigned int i, count;
	const char *value;
	void *ptr;

	memset(result->fetch_dest, 0, result->fetch_dest_size);
	count = result->map_size;
	for (i = 0; i < count; i++) {
		if (result->map[i].offset == (size_t)-1)
			continue;

		value = sql_result_get_field_value(result, i);
		ptr = STRUCT_MEMBER_P(result->fetch_dest,
				      result->map[i].offset);

		switch (result->map[i].type) {
		case SQL_TYPE_STR: {
			*((const char **)ptr) = value;
			break;
		}
		case SQL_TYPE_UINT: {
			if (value != NULL &&
			    str_to_uint(value, (unsigned int *)ptr) < 0)
				i_error("sql: Value not uint: %s", value);
			break;
		}
		case SQL_TYPE_ULLONG: {
			if (value != NULL &&
			    str_to_ullong(value, (unsigned long long *)ptr) < 0)
				i_error("sql: Value not ullong: %s", value);
			break;
		}
		case SQL_TYPE_BOOL: {
			if (value != NULL && (*value == 't' || *value == '1'))
				*((bool *)ptr) = TRUE;
			break;
		}
		}
	}
}

int sql_result_next_row(struct sql_result *result)
{
	int ret;

	if ((ret = result->v.next_row(result)) <= 0)
		return ret;

	if (result->fetch_dest != NULL)
		sql_result_fetch(result);
	return 1;
}

#undef sql_result_more
void sql_result_more(struct sql_result **result,
		     sql_query_callback_t *callback, void *context)
{
	i_assert((*result)->v.more != NULL);

	(*result)->v.more(result, TRUE, callback, context);
}

static void
sql_result_more_sync_callback(struct sql_result *result, void *context)
{
	struct sql_result **dest_result = context;

	*dest_result = result;
}

void sql_result_more_s(struct sql_result **result)
{
	i_assert((*result)->v.more != NULL);

	(*result)->v.more(result, FALSE, sql_result_more_sync_callback, result);
	/* the callback must have been called */
	i_assert(*result != NULL);
}

unsigned int sql_result_get_fields_count(struct sql_result *result)
{
	return result->v.get_fields_count(result);
}

const char *sql_result_get_field_name(struct sql_result *result,
				      unsigned int idx)
{
	return result->v.get_field_name(result, idx);
}

int sql_result_find_field(struct sql_result *result, const char *field_name)
{
	return result->v.find_field(result, field_name);
}

const char *sql_result_get_field_value(struct sql_result *result,
				       unsigned int idx)
{
	return result->v.get_field_value(result, idx);
}

const unsigned char *
sql_result_get_field_value_binary(struct sql_result *result,
				  unsigned int idx, size_t *size_r)
{
	return result->v.get_field_value_binary(result, idx, size_r);
}

const char *sql_result_find_field_value(struct sql_result *result,
					const char *field_name)
{
	return result->v.find_field_value(result, field_name);
}

const char *const *sql_result_get_values(struct sql_result *result)
{
	return result->v.get_values(result);
}

const char *sql_result_get_error(struct sql_result *result)
{
	return result->v.get_error(result);
}

enum sql_result_error_type sql_result_get_error_type(struct sql_result *result)
{
	return result->error_type;
}

static void
sql_result_not_connected_free(struct sql_result *result ATTR_UNUSED)
{
}

static int
sql_result_not_connected_next_row(struct sql_result *result ATTR_UNUSED)
{
	return -1;
}

static const char *
sql_result_not_connected_get_error(struct sql_result *result ATTR_UNUSED)
{
	return SQL_ERRSTR_NOT_CONNECTED;
}

struct sql_transaction_context *sql_transaction_begin(struct sql_db *db)
{
	return db->v.transaction_begin(db);
}

struct sql_commit1_wrap_ctx {
	sql_commit_callback_t *callback;
	void *context;
};

static void sql_commit1_wrap(const struct sql_commit_result *result,
			     void *context)
{
	struct sql_commit1_wrap_ctx *ctx = context;

	ctx->callback(result->error, ctx->context);
	i_free(ctx);
}

#undef sql_transaction_commit
void sql_transaction_commit(struct sql_transaction_context **_ctx,
			    sql_commit_callback_t *callback, void *context)
{
	struct sql_transaction_context *ctx = *_ctx;

	*_ctx = NULL;
	if (ctx->db->v.transaction_commit != NULL)
		ctx->db->v.transaction_commit(ctx, callback, context);
	else {
		struct sql_commit1_wrap_ctx *wrap;

		wrap = i_new(struct sql_commit1_wrap_ctx, 1);
		wrap->callback = callback;
		wrap->context = context;
		ctx->db->v.transaction_commit2(ctx, sql_commit1_wrap, wrap);
	}
}

struct sql_commit2_wrap_ctx {
	sql_commit2_callback_t *callback;
	void *context;
};

static void sql_commit2_wrap(const char *error, void *context)
{
	struct sql_commit2_wrap_ctx *ctx = context;
	struct sql_commit_result result = { .error = error };

	ctx->callback(&result, ctx->context);
	i_free(ctx);
}

#undef sql_transaction_commit2
void sql_transaction_commit2(struct sql_transaction_context **_ctx,
			     sql_commit2_callback_t *callback, void *context)
{
	struct sql_transaction_context *ctx = *_ctx;

	*_ctx = NULL;
	if (ctx->db->v.transaction_commit2 != NULL)
		ctx->db->v.transaction_commit2(ctx, callback, context);
	else {
		struct sql_commit2_wrap_ctx *wrap;

		wrap = i_new(struct sql_commit2_wrap_ctx, 1);
		wrap->callback = callback;
		wrap->context = context;
		ctx->db->v.transaction_commit(ctx, sql_commit2_wrap, wrap);
	}
}

int sql_transaction_commit_s(struct sql_transaction_context **_ctx,
			     const char **error_r)
{
	struct sql_transaction_context *ctx = *_ctx;

	*_ctx = NULL;
	return ctx->db->v.transaction_commit_s(ctx, error_r);
}

void sql_transaction_rollback(struct sql_transaction_context **_ctx)
{
	struct sql_transaction_context *ctx = *_ctx;

	*_ctx = NULL;
	ctx->db->v.transaction_rollback(ctx);
}

void sql_update(struct sql_transaction_context *ctx, const char *query)
{
	ctx->db->v.update(ctx, query, NULL);
}

void sql_update_stmt(struct sql_transaction_context *ctx,
		     struct sql_statement **_stmt)
{
	struct sql_statement *stmt = *_stmt;

	*_stmt = NULL;
	if (ctx->db->v.update_stmt != NULL)
		ctx->db->v.update_stmt(ctx, stmt, NULL);
	else
		default_sql_update_stmt(ctx, stmt, NULL);
}

void sql_update_get_rows(struct sql_transaction_context *ctx, const char *query,
			 unsigned int *affected_rows)
{
	ctx->db->v.update(ctx, query, affected_rows);
}

void sql_update_stmt_get_rows(struct sql_transaction_context *ctx,
			      struct sql_statement **_stmt,
			      unsigned int *affected_rows)
{
	struct sql_statement *stmt = *_stmt;

	*_stmt = NULL;
	if (ctx->db->v.update_stmt != NULL)
		ctx->db->v.update_stmt(ctx, stmt, affected_rows);
	else
		default_sql_update_stmt(ctx, stmt, affected_rows);
}

void sql_db_set_state(struct sql_db *db, enum sql_db_state state)
{
	enum sql_db_state old_state = db->state;

	if (db->state == state)
		return;

	db->state = state;
	if (db->state_change_callback != NULL) {
		db->state_change_callback(db, old_state,
					  db->state_change_context);
	}
}

void sql_transaction_add_query(struct sql_transaction_context *ctx, pool_t pool,
			       const char *query, unsigned int *affected_rows)
{
	struct sql_transaction_query *tquery;

	tquery = p_new(pool, struct sql_transaction_query, 1);
	tquery->trans = ctx;
	tquery->query = p_strdup(pool, query);
	tquery->affected_rows = affected_rows;

	if (ctx->head == NULL)
		ctx->head = tquery;
	else
		ctx->tail->next = tquery;
	ctx->tail = tquery;
}

struct sql_result sql_not_connected_result = {
	.v = {
		sql_result_not_connected_free,
		sql_result_not_connected_next_row,
		NULL, NULL, NULL, NULL, NULL, NULL, NULL,
		sql_result_not_connected_get_error,
		NULL,
	},
	.failed_try_retry = TRUE
};