changeset 19220:fbff1e801300

dict-sql: If value isn't a hexblob as expected, log an error instead of killing the whole process.
author Timo Sirainen <tss@iki.fi>
date Fri, 25 Sep 2015 18:16:56 +0300
parents 8183663ad7c0
children 5e18e1386600
files src/lib-dict/dict-sql.c
diffstat 1 files changed, 120 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/src/lib-dict/dict-sql.c	Thu Sep 24 20:58:04 2015 +0300
+++ b/src/lib-dict/dict-sql.c	Fri Sep 25 18:16:56 2015 +0300
@@ -211,44 +211,48 @@
 	return NULL;
 }
 
-static void
+static int
 sql_dict_value_escape(string_t *str, struct sql_dict *dict,
 		      bool value_is_hexblob, const char *field_name,
-		      const char *value, const char *value_suffix)
+		      const char *value, const char *value_suffix,
+		      const char **error_r)
 {
 	buffer_t *buf;
 
 	if (!value_is_hexblob) {
 		str_printfa(str, "'%s%s'", sql_escape_string(dict->db, value),
 			    value_suffix);
-		return;
+		return 0;
 	}
 
 	buf = buffer_create_dynamic(pool_datastack_create(), strlen(value)/2);
 	if (hex_to_binary(value, buf) < 0) {
 		/* we shouldn't get untrusted input here. it's also a bit
 		   annoying to handle this error. */
-		i_fatal("dict-sql: field %s value isn't hexblob: %s",
-			field_name, value);
+		*error_r = t_strdup_printf("field %s value isn't hexblob: %s",
+					   field_name, value);
+		return -1;
 	}
 	str_append(buf, value_suffix);
 	str_append(str, sql_escape_blob(dict->db, buf->data, buf->used));
+	return 0;
 }
 
-static void
+static int
 sql_dict_field_escape_value(string_t *str, struct sql_dict *dict,
 			    const struct dict_sql_field *field,
-			    const char *value, const char *value_suffix)
+			    const char *value, const char *value_suffix,
+			    const char **error_r)
 {
-	sql_dict_value_escape(str, dict, field->value_is_hexblob,
-			      field->name, value, value_suffix);
+	return sql_dict_value_escape(str, dict, field->value_is_hexblob,
+				     field->name, value, value_suffix, error_r);
 }
 
-static void
+static int
 sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map,
 		     const ARRAY_TYPE(const_string) *values_arr,
 		     char key1, enum sql_recurse_type recurse_type,
-		     string_t *query)
+		     string_t *query, const char **error_r)
 {
 	const struct dict_sql_field *sql_fields;
 	const char *const *values;
@@ -262,7 +266,7 @@
 
 	if (count2 == 0 && !priv) {
 		/* we want everything */
-		return;
+		return 0;
 	}
 
 	str_append(query, " WHERE");
@@ -272,7 +276,9 @@
 		if (i > 0)
 			str_append(query, " AND");
 		str_printfa(query, " %s = ", sql_fields[i].name);
-		sql_dict_field_escape_value(query, dict, &sql_fields[i], values[i], "");
+		if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+						values[i], "", error_r) < 0)
+			return -1;
 	}
 	switch (recurse_type) {
 	case SQL_DICT_RECURSE_NONE:
@@ -282,11 +288,13 @@
 			str_append(query, " AND");
 		if (i < count2) {
 			str_printfa(query, " %s LIKE ", sql_fields[i].name);
-			sql_dict_field_escape_value(query, dict, &sql_fields[i],
-						    values[i], "/%");
+			if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+							values[i], "/%", error_r) < 0)
+				return -1;
 			str_printfa(query, " AND %s NOT LIKE ", sql_fields[i].name);
-			sql_dict_field_escape_value(query, dict, &sql_fields[i],
-						    values[i], "/%/%");
+			if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+							values[i], "/%/%", error_r) < 0)
+				return -1;
 		} else {
 			str_printfa(query, " %s LIKE '%%' AND "
 				    "%s NOT LIKE '%%/%%'",
@@ -299,8 +307,9 @@
 				str_append(query, " AND");
 			str_printfa(query, " %s LIKE ",
 				    sql_fields[i].name);
-			sql_dict_field_escape_value(query, dict, &sql_fields[i],
-						    values[i], "/%");
+			if (sql_dict_field_escape_value(query, dict, &sql_fields[i],
+							values[i], "/%", error_r) < 0)
+				return -1;
 		}
 		break;
 	}
@@ -310,6 +319,7 @@
 		str_printfa(query, " %s = '%s'", map->username_field,
 			    sql_escape_string(dict->db, dict->username));
 	}
+	return 0;
 }
 
 static int
@@ -318,6 +328,7 @@
 {
 	const struct dict_sql_map *map;
 	ARRAY_TYPE(const_string) values;
+	const char *error;
 
 	map = *map_r = sql_dict_find_map(dict, key, &values);
 	if (map == NULL) {
@@ -326,8 +337,11 @@
 	}
 	str_printfa(query, "SELECT %s FROM %s",
 		    map->value_field, map->table);
-	sql_dict_where_build(dict, map, &values, key[0],
-			     SQL_DICT_RECURSE_NONE, query);
+	if (sql_dict_where_build(dict, map, &values, key[0],
+				 SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+		i_error("sql dict lookup: Failed to lookup key %s: %s", key, error);
+		return -1;
+	}
 	return 0;
 }
 
@@ -483,9 +497,9 @@
 	return NULL;
 }
 
-static bool
+static int
 sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
-				  string_t *query)
+				  string_t *query, const char **error_r)
 {
 	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
 	const struct dict_sql_map *map;
@@ -495,8 +509,10 @@
 	unsigned int i, count;
 
 	map = sql_dict_iterate_find_next_map(ctx, &values);
-	if (map == NULL)
-		return FALSE;
+	if (map == NULL) {
+		*error_r = "Invalid/unmapped path";
+		return 0;
+	}
 
 	if (ctx->result != NULL) {
 		sql_result_unref(ctx->result);
@@ -529,9 +545,10 @@
 		recurse_type = SQL_DICT_RECURSE_NONE;
 	else
 		recurse_type = SQL_DICT_RECURSE_ONE;
-	sql_dict_where_build(dict, map, &values,
-			     ctx->paths[ctx->path_idx][0],
-			     recurse_type, query);
+	if (sql_dict_where_build(dict, map, &values,
+				 ctx->paths[ctx->path_idx][0],
+				 recurse_type, query, error_r) < 0)
+		return -1;
 
 	if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
 		str_append(query, " ORDER BY ");
@@ -544,7 +561,7 @@
 		str_printfa(query, " ORDER BY %s", map->value_field);
 
 	ctx->map = map;
-	return TRUE;
+	return 1;
 }
 
 static void sql_dict_iterate_callback(struct sql_result *result,
@@ -556,17 +573,20 @@
 		ctx->ctx.async_callback(ctx->ctx.async_context);
 }
 
-static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
+static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx,
+				       const char **error_r)
 {
 	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
-	bool ret;
+	char *error = NULL;
+	int ret;
 
 	T_BEGIN {
 		string_t *query = t_str_new(256);
 
-		ret = sql_dict_iterate_build_next_query(ctx, query);
-		if (!ret) {
+		ret = sql_dict_iterate_build_next_query(ctx, query, error_r);
+		if (ret <= 0) {
 			/* failed */
+			error = i_strdup(*error_r);
 		} else if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
 			ctx->result = sql_query_s(dict->db, str_c(query));
 		} else {
@@ -575,6 +595,8 @@
 				  sql_dict_iterate_callback, ctx);
 		}
 	} T_END;
+	*error_r = t_strdup(error);
+	i_free(error);
 	return ret;
 }
 
@@ -584,6 +606,7 @@
 {
 	struct sql_dict_iterate_context *ctx;
 	unsigned int i, path_count;
+	const char *error;
 	pool_t pool;
 
 	pool = pool_alloconly_create("sql dict iterate", 512);
@@ -598,9 +621,9 @@
 		ctx->paths[i] = p_strdup(pool, paths[i]);
 
 	ctx->key = str_new(pool, 256);
-	if (!sql_dict_iterate_next_query(ctx)) {
-		i_error("sql dict iterate: Invalid/unmapped path: %s",
-			paths[0]);
+	if (sql_dict_iterate_next_query(ctx, &error) <= 0) {
+		i_error("sql dict iterate failed for %s: %s",
+			paths[0], error);
 		ctx->result = NULL;
 		ctx->failed = TRUE;
 		return &ctx->ctx;
@@ -613,7 +636,7 @@
 {
 	struct sql_dict_iterate_context *ctx =
 		(struct sql_dict_iterate_context *)_ctx;
-	const char *p, *value;
+	const char *p, *value, *error;
 	unsigned int i, sql_field_i, count;
 	int ret;
 
@@ -635,8 +658,9 @@
 		/* see if there are more results in the next map.
 		   don't do it if we're looking for an exact match, since we
 		   already should have handled it. */
-		if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0 ||
-		    !sql_dict_iterate_next_query(ctx))
+		if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0)
+			return FALSE;
+		if ((ret = sql_dict_iterate_next_query(ctx, &error)) == 0)
 			return FALSE;
 	}
 	if (ret < 0) {
@@ -806,7 +830,8 @@
 	bool inc;
 };
 
-static const char *sql_dict_set_query(const struct dict_sql_build_query *build)
+static int sql_dict_set_query(const struct dict_sql_build_query *build,
+			      const char **query_r, const char **error_r)
 {
 	struct sql_dict *dict = build->dict;
 	const struct dict_sql_build_query_field *fields;
@@ -831,9 +856,10 @@
 		if (build->inc)
 			str_append(suffix, fields[i].value);
 		else {
-			sql_dict_value_escape(suffix, dict,
+			if (sql_dict_value_escape(suffix, dict,
 				fields[i].map->value_hexblob,
-				"value", fields[i].value, "");
+				"value", fields[i].value, "", error_r) < 0)
+				return -1;
 		}
 	}
 	if (build->key1 == DICT_PATH_PRIVATE[0]) {
@@ -849,14 +875,17 @@
 	for (i = 0; i < count; i++) {
 		str_printfa(prefix, ",%s", sql_fields[i].name);
 		str_append_c(suffix, ',');
-		sql_dict_field_escape_value(suffix, dict, &sql_fields[i],
-					    extra_values[i], "");
+		if (sql_dict_field_escape_value(suffix, dict, &sql_fields[i],
+						extra_values[i], "", error_r) < 0)
+			return -1;
 	}
 
 	str_append_str(prefix, suffix);
 	str_append_c(prefix, ')');
-	if (!dict->has_on_duplicate_key)
-		return str_c(prefix);
+	if (!dict->has_on_duplicate_key) {
+		*query_r = str_c(prefix);
+		return 0;
+	}
 
 	str_append(prefix, " ON DUPLICATE KEY UPDATE ");
 	for (i = 0; i < field_count; i++) {
@@ -869,16 +898,19 @@
 				    fields[i].map->value_field,
 				    fields[i].value);
 		} else {
-			sql_dict_value_escape(prefix, dict,
+			if (sql_dict_value_escape(prefix, dict,
 				fields[i].map->value_hexblob,
-				"value", fields[i].value, "");
+				"value", fields[i].value, "", error_r) < 0)
+				return -1;
 		}
 	}
-	return str_c(prefix);
+	*query_r = str_c(prefix);
+	return 0;
 }
 
-static const char *
-sql_dict_update_query(const struct dict_sql_build_query *build)
+static int
+sql_dict_update_query(const struct dict_sql_build_query *build,
+		      const char **query_r, const char **error_r)
 {
 	struct sql_dict *dict = build->dict;
 	const struct dict_sql_build_query_field *fields;
@@ -902,9 +934,11 @@
 		str_append(query, fields[i].value);
 	}
 
-	sql_dict_where_build(dict, fields[0].map, build->extra_values,
-			     build->key1, SQL_DICT_RECURSE_NONE, query);
-	return str_c(query);
+	if (sql_dict_where_build(dict, fields[0].map, build->extra_values,
+				 build->key1, SQL_DICT_RECURSE_NONE, query, error_r) < 0)
+		return -1;
+	*query_r = str_c(query);
+	return 0;
 }
 
 static void sql_dict_set(struct dict_transaction_context *_ctx,
@@ -929,7 +963,7 @@
 	T_BEGIN {
 		struct dict_sql_build_query build;
 		struct dict_sql_build_query_field field;
-		const char *query;
+		const char *query, *error;
 
 		field.map = map;
 		field.value = value;
@@ -941,8 +975,13 @@
 		build.extra_values = &values;
 		build.key1 = key[0];
 
-		query = sql_dict_set_query(&build);
-		sql_update(ctx->sql_ctx, query);
+		if (sql_dict_set_query(&build, &query, &error) < 0) {
+			i_error("dict-sql: Failed to set %s=%s: %s",
+				key, value, error);
+			ctx->failed = TRUE;
+		} else {
+			sql_update(ctx->sql_ctx, query);
+		}
 	} T_END;
 }
 
@@ -967,11 +1006,16 @@
 
 	T_BEGIN {
 		string_t *query = t_str_new(256);
+		const char *error;
 
 		str_printfa(query, "DELETE FROM %s", map->table);
-		sql_dict_where_build(dict, map, &values, key[0],
-				     SQL_DICT_RECURSE_NONE, query);
-		sql_update(ctx->sql_ctx, str_c(query));
+		if (sql_dict_where_build(dict, map, &values, key[0],
+					 SQL_DICT_RECURSE_NONE, query, &error) < 0) {
+			i_error("dict-sql: Failed to delete %s: %s", key, error);
+			ctx->failed = TRUE;
+		} else {
+			sql_update(ctx->sql_ctx, str_c(query));
+		}
 	} T_END;
 }
 
@@ -1015,6 +1059,7 @@
 	T_BEGIN {
 		struct dict_sql_build_query build;
 		struct dict_sql_build_query_field field;
+		const char *query, *error;
 
 		field.map = map;
 		field.value = t_strdup_printf("%lld", diff);
@@ -1027,8 +1072,13 @@
 		build.key1 = key[0];
 		build.inc = TRUE;
 
-		sql_update_get_rows(ctx->sql_ctx, sql_dict_update_query(&build),
-				    sql_dict_next_inc_row(ctx));
+		if (sql_dict_update_query(&build, &query, &error) < 0) {
+			i_error("dict-sql: Failed to increase %s: %s", key, error);
+			ctx->failed = TRUE;
+		} else {
+			sql_update_get_rows(ctx->sql_ctx, query,
+					    sql_dict_next_inc_row(ctx));
+		}
 	} T_END;
 }
 
@@ -1107,6 +1157,7 @@
 	} else T_BEGIN {
 		struct dict_sql_build_query build;
 		struct dict_sql_build_query_field *field;
+		const char *query, *error;
 
 		memset(&build, 0, sizeof(build));
 		build.dict = dict;
@@ -1122,8 +1173,13 @@
 		field->map = map;
 		field->value = t_strdup_printf("%lld", diff);
 
-		sql_update_get_rows(ctx->sql_ctx, sql_dict_update_query(&build),
-				    sql_dict_next_inc_row(ctx));
+		if (sql_dict_update_query(&build, &query, &error) < 0) {
+			i_error("dict-sql: Failed to increase %s: %s", key, error);
+			ctx->failed = TRUE;
+		} else {
+			sql_update_get_rows(ctx->sql_ctx, query,
+					    sql_dict_next_inc_row(ctx));
+		}
 
 		i_free_and_null(ctx->prev_inc_key);
 		ctx->prev_inc_map = NULL;