changeset 22521:6be650aee0b5

dict-sql: Use sql_statement_bind_*()
author Timo Sirainen <timo.sirainen@dovecot.fi>
date Tue, 15 Aug 2017 16:38:33 +0300
parents 124286dedf19
children 486ddf6a61eb
files src/lib-dict/dict-sql.c
diffstat 1 files changed, 162 insertions(+), 73 deletions(-) [+]
line wrap: on
line diff
--- a/src/lib-dict/dict-sql.c	Tue Aug 15 15:37:01 2017 +0300
+++ b/src/lib-dict/dict-sql.c	Tue Aug 15 16:38:33 2017 +0300
@@ -33,6 +33,16 @@
 	unsigned int has_on_duplicate_key:1;
 };
 
+struct sql_dict_param {
+	enum dict_sql_type value_type;
+
+	const char *value_str;
+	int64_t value_int64;
+	const void *value_binary;
+	size_t value_binary_size;
+};
+ARRAY_DEFINE_TYPE(sql_dict_param, struct sql_dict_param);
+
 struct sql_dict_iterate_context {
 	struct dict_iterate_context ctx;
 	pool_t pool;
@@ -220,39 +230,74 @@
 	return NULL;
 }
 
+static void
+sql_dict_statement_bind(struct sql_statement *stmt, unsigned int column_idx,
+			const struct sql_dict_param *param)
+{
+	switch (param->value_type) {
+	case DICT_SQL_TYPE_STRING:
+		sql_statement_bind_str(stmt, column_idx, param->value_str);
+		break;
+	case DICT_SQL_TYPE_INT:
+	case DICT_SQL_TYPE_UINT:
+		sql_statement_bind_int64(stmt, column_idx, param->value_int64);
+		break;
+	case DICT_SQL_TYPE_HEXBLOB:
+		sql_statement_bind_binary(stmt, column_idx, param->value_binary,
+					  param->value_binary_size);
+		break;
+	}
+}
+
+static struct sql_statement *
+sql_dict_statement_init(struct sql_db *db, const char *query,
+			const ARRAY_TYPE(sql_dict_param) *params)
+{
+	struct sql_statement *stmt = sql_statement_init(db, query);
+	const struct sql_dict_param *param;
+
+	array_foreach(params, param) {
+		sql_dict_statement_bind(stmt, array_foreach_idx(params, param),
+					param);
+	}
+	return stmt;
+}
+
 static int
-sql_dict_value_escape(string_t *str, struct sql_dict *dict,
-		      const struct dict_sql_map *map,
-		      enum dict_sql_type value_type, const char *field_name,
-		      const char *value, const char *value_suffix,
-		      const char **error_r)
+sql_dict_value_get(const struct dict_sql_map *map,
+		   enum dict_sql_type value_type, const char *field_name,
+		   const char *value, const char *value_suffix,
+		   ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
+	struct sql_dict_param *param;
 	buffer_t *buf;
-	int64_t snum;
-	uint64_t num;
+
+	param = array_append_space(params);
+	param->value_type = value_type;
 
 	switch (value_type) {
 	case DICT_SQL_TYPE_STRING:
-		str_printfa(str, "'%s%s'", sql_escape_string(dict->db, value),
-			    value_suffix);
+		if (value_suffix[0] != '\0')
+			value = t_strconcat(value, value_suffix, NULL);
+		param->value_str = value;
 		return 0;
 	case DICT_SQL_TYPE_INT:
-		if (value_suffix[0] != '\0' || str_to_int64(value, &snum) < 0) {
+		if (value_suffix[0] != '\0' ||
+		    str_to_int64(value, &param->value_int64) < 0) {
 			*error_r = t_strdup_printf(
 				"%s field's value isn't 64bit signed integer: %s%s (in pattern: %s)",
 				field_name, value, value_suffix, map->pattern);
 			return -1;
 		}
-		str_printfa(str, "%"PRId64, snum);
 		return 0;
 	case DICT_SQL_TYPE_UINT:
-		if (value_suffix[0] != '\0' || str_to_uint64(value, &num) < 0) {
+		if (value_suffix[0] != '\0' || value[0] == '-' ||
+		    str_to_int64(value, &param->value_int64) < 0) {
 			*error_r = t_strdup_printf(
 				"%s field's value isn't 64bit unsigned integer: %s%s (in pattern: %s)",
 				field_name, value, value_suffix, map->pattern);
 			return -1;
 		}
-		str_printfa(str, "%llu", (unsigned long long)num);
 		return 0;
 	case DICT_SQL_TYPE_HEXBLOB:
 		break;
@@ -267,26 +312,28 @@
 		return -1;
 	}
 	str_append(buf, value_suffix);
-	str_append(str, sql_escape_blob(dict->db, buf->data, buf->used));
+	param->value_binary = buf->data;
+	param->value_binary_size = buf->used;
 	return 0;
 }
 
 static int
-sql_dict_field_escape_value(string_t *str, struct sql_dict *dict,
-			    const struct dict_sql_map *map,
-			    const struct dict_sql_field *field,
-			    const char *value, const char *value_suffix,
-			    const char **error_r)
+sql_dict_field_get_value(const struct dict_sql_map *map,
+			 const struct dict_sql_field *field,
+			 const char *value, const char *value_suffix,
+			 ARRAY_TYPE(sql_dict_param) *params,
+			 const char **error_r)
 {
-	return sql_dict_value_escape(str, dict, map, field->value_type,
-				     field->name, value, value_suffix, error_r);
+	return sql_dict_value_get(map, field->value_type, field->name,
+				  value, value_suffix, params, error_r);
 }
 
 static int
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
 		     const ARRAY_TYPE(const_string) *values_arr,
 		     char key1, enum sql_recurse_type recurse_type,
-		     string_t *query, const char **error_r)
+		     string_t *query, ARRAY_TYPE(sql_dict_param) *params,
+		     const char **error_r)
 {
 	const struct dict_sql_field *sql_fields;
 	const char *const *values;
@@ -309,9 +356,9 @@
 	for (i = 0; i < exact_count; i++) {
 		if (i > 0)
 			str_append(query, " AND");
-		str_printfa(query, " %s = ", sql_fields[i].name);
-		if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-						values[i], "", error_r) < 0)
+		str_printfa(query, " %s = ?", sql_fields[i].name);
+		if (sql_dict_field_get_value(map, &sql_fields[i], values[i], "",
+					     params, error_r) < 0)
 			return -1;
 	}
 	switch (recurse_type) {
@@ -321,13 +368,15 @@
 		if (i > 0)
 			str_append(query, " AND");
 		if (i < count2) {
-			str_printfa(query, " %s LIKE ", sql_fields[i].name);
-			if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-							values[i], "/%", error_r) < 0)
+			str_printfa(query, " %s LIKE ?", sql_fields[i].name);
+			if (sql_dict_field_get_value(map, &sql_fields[i],
+						     values[i], "/%",
+						     params, error_r) < 0)
 				return -1;
-			str_printfa(query, " AND %s NOT LIKE ", sql_fields[i].name);
-			if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-							values[i], "/%/%", error_r) < 0)
+			str_printfa(query, " AND %s NOT LIKE ?", sql_fields[i].name);
+			if (sql_dict_field_get_value(map, &sql_fields[i],
+						     values[i], "/%/%",
+						     params, error_r) < 0)
 				return -1;
 		} else {
 			str_printfa(query, " %s LIKE '%%' AND "
@@ -341,8 +390,9 @@
 				str_append(query, " AND");
 			str_printfa(query, " %s LIKE ",
 				    sql_fields[i].name);
-			if (sql_dict_field_escape_value(query, dict, map, &sql_fields[i],
-							values[i], "/%", error_r) < 0)
+			if (sql_dict_field_get_value(map, &sql_fields[i],
+						     values[i], "/%",
+						     params, error_r) < 0)
 				return -1;
 		}
 		break;
@@ -359,7 +409,7 @@
 static int
 sql_lookup_get_query(struct sql_dict *dict, const char *key,
 		     string_t *query, const struct dict_sql_map **map_r,
-		     const char **error_r)
+		     ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
 {
 	const struct dict_sql_map *map;
 	ARRAY_TYPE(const_string) values;
@@ -374,7 +424,8 @@
 	str_printfa(query, "SELECT %s FROM %s",
 		    map->value_field, map->table);
 	if (sql_dict_where_build(dict, map, &values, key[0],
-				 SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+				 SQL_DICT_RECURSE_NONE, query,
+				 params, &error) < 0) {
 		*error_r = t_strdup_printf(
 			"sql dict lookup: Failed to lookup key %s: %s", key, error);
 		return -1;
@@ -446,18 +497,20 @@
 	const struct dict_sql_map *map;
 	struct sql_result *result = NULL;
 	string_t *query = t_str_new(256);
+	ARRAY_TYPE(sql_dict_param) params;
 	const char *error;
 	int ret;
 
-	ret = sql_lookup_get_query(dict, key, query, &map, &error);
-	if (ret < 0) {
+	*value_r = NULL;
+
+	t_array_init(&params, 4);
+	if (sql_lookup_get_query(dict, key, query, &map, &params, &error) < 0) {
 		i_error("%s", error);
-		*value_r = NULL;
 		return -1;
 	}
 
 	struct sql_statement *stmt =
-		sql_statement_init(dict->db, str_c(query));
+		sql_dict_statement_init(dict->db, str_c(query), &params);
 	result = sql_statement_query_s(&stmt);
 	ret = sql_result_next_row(result);
 	if (ret <= 0) {
@@ -514,9 +567,11 @@
 	const struct dict_sql_map *map;
 	struct sql_dict_lookup_context *ctx;
 	string_t *query = t_str_new(256);
+	ARRAY_TYPE(sql_dict_param) params;
 	const char *error;
 
-	if (sql_lookup_get_query(dict, key, query, &map, &error) < 0) {
+	t_array_init(&params, 4);
+	if (sql_lookup_get_query(dict, key, query, &map, &params, &error) < 0) {
 		struct dict_lookup_result result;
 
 		i_zero(&result);
@@ -529,7 +584,7 @@
 		ctx->context = context;
 		ctx->map = map;
 		struct sql_statement *stmt =
-			sql_statement_init(dict->db, str_c(query));
+			sql_dict_statement_init(dict->db, str_c(query), &params);
 		sql_statement_query(&stmt, sql_dict_lookup_async_callback, ctx);
 	}
 }
@@ -570,7 +625,9 @@
 
 static int
 sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
-				  string_t *query, const char **error_r)
+				  string_t *query,
+				  ARRAY_TYPE(sql_dict_param) *params,
+				  const char **error_r)
 {
 	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
 	const struct dict_sql_map *map;
@@ -622,7 +679,7 @@
 		recurse_type = SQL_DICT_RECURSE_ONE;
 	if (sql_dict_where_build(dict, map, &values,
 				 ctx->paths[ctx->path_idx][0],
-				 recurse_type, query, error_r) < 0)
+				 recurse_type, query, params, error_r) < 0)
 		return -1;
 
 	if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
@@ -658,11 +715,13 @@
 {
 	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
 	string_t *query = t_str_new(256);
+	ARRAY_TYPE(sql_dict_param) params;
 	const char *error;
 	unsigned int path_idx = ctx->path_idx;
 	int ret;
 
-	ret = sql_dict_iterate_build_next_query(ctx, query, &error);
+	t_array_init(&params, 4);
+	ret = sql_dict_iterate_build_next_query(ctx, query, &params, &error);
 	if (ret < 0) {
 		/* failed */
 		i_error("sql dict iterate failed for %s: %s",
@@ -674,7 +733,7 @@
 	}
 
 	struct sql_statement *stmt =
-		sql_statement_init(dict->db, str_c(query));
+		sql_dict_statement_init(dict->db, str_c(query), &params);
 	if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
 		ctx->result = sql_statement_query_s(&stmt);
 	} else {
@@ -928,10 +987,12 @@
 
 static struct sql_statement *
 sql_dict_transaction_stmt_init(struct sql_dict_transaction_context *ctx,
-			       const char *query)
+			       const char *query,
+			       const ARRAY_TYPE(sql_dict_param) *params)
 {
 	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
-	struct sql_statement *stmt = sql_statement_init(dict->db, query);
+	struct sql_statement *stmt =
+		sql_dict_statement_init(dict->db, query, params);
 
 	if (ctx->ctx.timestamp.tv_sec != 0)
 		sql_statement_set_timestamp(stmt, &ctx->ctx.timestamp);
@@ -952,7 +1013,9 @@
 };
 
 static int sql_dict_set_query(const struct dict_sql_build_query *build,
-			      const char **query_r, const char **error_r)
+			      const char **query_r,
+			      ARRAY_TYPE(sql_dict_param) *params,
+			      const char **error_r)
 {
 	struct sql_dict *dict = build->dict;
 	const struct dict_sql_build_query_field *fields;
@@ -978,9 +1041,10 @@
 
 		enum dict_sql_type value_type =
 			fields[i].map->value_types[0];
-		if (sql_dict_value_escape(suffix, dict, fields[i].map,
-					  value_type, "value", fields[i].value,
-					  "", error_r) < 0)
+		str_append_c(suffix, '?');
+		if (sql_dict_value_get(fields[i].map,
+				       value_type, "value", fields[i].value,
+				       "", params, error_r) < 0)
 			return -1;
 	}
 	if (build->key1 == DICT_PATH_PRIVATE[0]) {
@@ -995,9 +1059,10 @@
 	i_assert(count == count2);
 	for (i = 0; i < count; i++) {
 		str_printfa(prefix, ",%s", sql_fields[i].name);
-		str_append_c(suffix, ',');
-		if (sql_dict_field_escape_value(suffix, dict, fields[0].map, &sql_fields[i],
-						extra_values[i], "", error_r) < 0)
+		str_append(suffix, ",?");
+		if (sql_dict_field_get_value(fields[0].map, &sql_fields[i],
+					     extra_values[i], "",
+					     params, error_r) < 0)
 			return -1;
 	}
 
@@ -1019,9 +1084,10 @@
 
 		enum dict_sql_type value_type =
 			fields[i].map->value_types[0];
-		if (sql_dict_value_escape(prefix, dict, fields[i].map,
-					  value_type, "value", fields[i].value,
-					  "", error_r) < 0)
+		str_append_c(prefix, '?');
+		if (sql_dict_value_get(fields[i].map,
+				       value_type, "value", fields[i].value,
+				       "", params, error_r) < 0)
 			return -1;
 	}
 	*query_r = str_c(prefix);
@@ -1030,7 +1096,8 @@
 
 static int
 sql_dict_update_query(const struct dict_sql_build_query *build,
-		      const char **query_r, const char **error_r)
+		      const char **query_r, ARRAY_TYPE(sql_dict_param) *params,
+		      const char **error_r)
 {
 	struct sql_dict *dict = build->dict;
 	const struct dict_sql_build_query_field *fields;
@@ -1047,15 +1114,13 @@
 			t_strcut(fields[i].map->value_field, ',');
 		if (i > 0)
 			str_append_c(query, ',');
-		str_printfa(query, "%s=%s", first_value_field,
+		str_printfa(query, "%s=%s+?", first_value_field,
 			    first_value_field);
-		if (fields[i].value[0] != '-')
-			str_append_c(query, '+');
-		str_append(query, fields[i].value);
 	}
 
 	if (sql_dict_where_build(dict, fields[0].map, build->extra_values,
-				 build->key1, SQL_DICT_RECURSE_NONE, query, error_r) < 0)
+				 build->key1, SQL_DICT_RECURSE_NONE, query,
+				 params, error_r) < 0)
 		return -1;
 	*query_r = str_c(query);
 	return 0;
@@ -1071,6 +1136,7 @@
 	ARRAY_TYPE(const_string) values;
 	struct dict_sql_build_query build;
 	struct dict_sql_build_query_field field;
+	ARRAY_TYPE(sql_dict_param) params;
 	const char *query, *error;
 
 	map = sql_dict_find_map(dict, key, &values);
@@ -1090,13 +1156,13 @@
 	build.extra_values = &values;
 	build.key1 = key[0];
 
-	if (sql_dict_set_query(&build, &query, &error) < 0) {
+	if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
 		i_error("dict-sql: Failed to set %s=%s: %s",
 			key, value, error);
 		ctx->failed = TRUE;
 	} else {
 		struct sql_statement *stmt =
-			sql_dict_transaction_stmt_init(ctx, query);
+			sql_dict_transaction_stmt_init(ctx, query, &params);
 		sql_update_stmt(ctx->sql_ctx, &stmt);
 	}
 }
@@ -1110,6 +1176,7 @@
 	const struct dict_sql_map *map;
 	ARRAY_TYPE(const_string) values;
 	string_t *query = t_str_new(256);
+	ARRAY_TYPE(sql_dict_param) params;
 	const char *error;
 
 	if (ctx->prev_inc_map != NULL)
@@ -1125,13 +1192,15 @@
 	}
 
 	str_printfa(query, "DELETE FROM %s", map->table);
+	t_array_init(&params, 4);
 	if (sql_dict_where_build(dict, map, &values, key[0],
-				 SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+				 SQL_DICT_RECURSE_NONE, query,
+				 &params, &error) < 0) {
 		i_error("dict-sql: Failed to delete %s: %s", key, error);
 		ctx->failed = TRUE;
 	} else {
 		struct sql_statement *stmt =
-			sql_dict_transaction_stmt_init(ctx, str_c(query));
+			sql_dict_transaction_stmt_init(ctx, str_c(query), &params);
 		sql_update_stmt(ctx->sql_ctx, &stmt);
 	}
 }
@@ -1171,6 +1240,8 @@
 	ARRAY_TYPE(const_string) values;
 	struct dict_sql_build_query build;
 	struct dict_sql_build_query_field field;
+	ARRAY_TYPE(sql_dict_param) params;
+	struct sql_dict_param *param;
 	const char *query, *error;
 
 	map = sql_dict_find_map(dict, key, &values);
@@ -1186,12 +1257,17 @@
 	build.extra_values = &values;
 	build.key1 = key[0];
 
-	if (sql_dict_update_query(&build, &query, &error) < 0) {
+	t_array_init(&params, 4);
+	param = array_append_space(&params);
+	param->value_type = DICT_SQL_TYPE_INT;
+	param->value_int64 = diff;
+
+	if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
 		i_error("dict-sql: Failed to increase %s: %s", key, error);
 		ctx->failed = TRUE;
 	} else {
 		struct sql_statement *stmt =
-			sql_dict_transaction_stmt_init(ctx, query);
+			sql_dict_transaction_stmt_init(ctx, query, &params);
 		sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
 					 sql_dict_next_inc_row(ctx));
 	}
@@ -1283,6 +1359,7 @@
 	} else {
 		struct dict_sql_build_query build;
 		struct dict_sql_build_query_field *field;
+		ARRAY_TYPE(sql_dict_param) params;
 		const char *query, *error;
 
 		i_zero(&build);
@@ -1298,12 +1375,13 @@
 		field->map = map;
 		field->value = value;
 
-		if (sql_dict_set_query(&build, &query, &error) < 0) {
+		t_array_init(&params, 4);
+		if (sql_dict_set_query(&build, &query, &params, &error) < 0) {
 			i_error("dict-sql: Failed to set %s: %s", key, error);
 			ctx->failed = TRUE;
 		} else {
 			struct sql_statement *stmt =
-				sql_dict_transaction_stmt_init(ctx, query);
+				sql_dict_transaction_stmt_init(ctx, query, &params);
 			sql_update_stmt(ctx->sql_ctx, &stmt);
 		}
 		i_free_and_null(ctx->prev_set_value);
@@ -1344,6 +1422,8 @@
 	} else {
 		struct dict_sql_build_query build;
 		struct dict_sql_build_query_field *field;
+		ARRAY_TYPE(sql_dict_param) params;
+		struct sql_dict_param *param;
 		const char *query, *error;
 
 		i_zero(&build);
@@ -1359,12 +1439,21 @@
 		field->map = map;
 		field->value = t_strdup_printf("%lld", diff);
 
-		if (sql_dict_update_query(&build, &query, &error) < 0) {
+		t_array_init(&params, 4);
+		param = array_append_space(&params);
+		param->value_type = DICT_SQL_TYPE_INT;
+		param->value_int64 = ctx->prev_inc_diff;
+
+		param = array_append_space(&params);
+		param->value_type = DICT_SQL_TYPE_INT;
+		param->value_int64 = diff;
+
+		if (sql_dict_update_query(&build, &query, &params, &error) < 0) {
 			i_error("dict-sql: Failed to increase %s: %s", key, error);
 			ctx->failed = TRUE;
 		} else {
 			struct sql_statement *stmt =
-				sql_dict_transaction_stmt_init(ctx, query);
+				sql_dict_transaction_stmt_init(ctx, query, &params);
 			sql_update_stmt_get_rows(ctx->sql_ctx, &stmt,
 						 sql_dict_next_inc_row(ctx));
 		}