Mercurial > dovecot > original-hg > dovecot-1.2
changeset 8113:2d3e942a4cde HEAD
dict sql: Added configuration for mapping dict paths to SQL fields.
author | Timo Sirainen <tss@iki.fi> |
---|---|
date | Wed, 27 Aug 2008 12:10:21 +0300 |
parents | d630b13ccd60 |
children | 74153b8e63a6 |
files | doc/dovecot-dict-sql-example.conf src/dict/Makefile.am src/lib-dict/Makefile.am src/lib-dict/dict-sql-settings.c src/lib-dict/dict-sql-settings.h src/lib-dict/dict-sql.c src/lib-dict/dict.c |
diffstat | 7 files changed, 576 insertions(+), 247 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/dovecot-dict-sql-example.conf Wed Aug 27 12:10:21 2008 +0300 @@ -0,0 +1,39 @@ +#connect = host=localhost dbname=mails user=testuser password=pass + +# CREATE TABLE quota ( +# username varchar(100) not null, +# bytes bigint not null default 0, +# messages integer not null default 0, +# primary key (username) +# ); + +map { + pattern = priv/quota/storage + table = quota + username_field = username + value_field = bytes +} +map { + pattern = priv/quota/messages + table = quota + username_field = username + value_field = messages +} + +# CREATE TABLE expires ( +# username varchar(100) not null, +# mailbox varchar(255) not null, +# expire_stamp integer not null, +# primary key (username, mailbox) +# ); + +map { + pattern = shared/expire/$user/$mailbox + table = expires + value_field = expire_stamp + + fields { + username = $user + mailbox = $mailbox + } +}
--- a/src/dict/Makefile.am Wed Aug 27 12:08:14 2008 +0300 +++ b/src/dict/Makefile.am Wed Aug 27 12:10:21 2008 +0300 @@ -15,6 +15,7 @@ ../lib-dict/libdict_backend.a \ ../lib-dict/libdict.a \ ../lib-sql/libsql.a \ + ../lib-settings/libsettings.a \ ../lib/liblib.a dict_LDADD = \
--- a/src/lib-dict/Makefile.am Wed Aug 27 12:08:14 2008 +0300 +++ b/src/lib-dict/Makefile.am Wed Aug 27 12:10:21 2008 +0300 @@ -5,6 +5,7 @@ AM_CPPFLAGS = \ -I$(top_srcdir)/src/lib \ -I$(top_srcdir)/src/lib-sql \ + -I$(top_srcdir)/src/lib-settings \ -DPKG_RUNDIR=\""$(rundir)"\" \ $(SQL_CFLAGS) @@ -14,7 +15,8 @@ backend_sources = \ dict-db.c \ - dict-sql.c + dict-sql.c \ + dict-sql-settings.c libdict_a_SOURCES = \ $(base_sources) @@ -27,7 +29,8 @@ dict.h \ dict-client.h \ dict-private.h \ - dict-sql.h + dict-sql.h \ + dict-sql-settings.h if INSTALL_HEADERS pkginc_libdir=$(pkgincludedir)/src/lib-dict
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/lib-dict/dict-sql-settings.c Wed Aug 27 12:10:21 2008 +0300 @@ -0,0 +1,215 @@ +/* Copyright (c) 2008 Dovecot authors, see the included COPYING file */ + +#include "lib.h" +#include "array.h" +#include "str.h" +#include "settings.h" +#include "dict-sql-settings.h" + +#include <ctype.h> + +enum section_type { + SECTION_ROOT = 0, + SECTION_MAP, + SECTION_FIELDS +}; + +struct dict_sql_map_field { + const char *sql_field, *variable; +}; + +struct setting_parser_ctx { + pool_t pool; + struct dict_sql_settings *set; + enum section_type type; + + struct dict_sql_map cur_map; + ARRAY_DEFINE(cur_fields, struct dict_sql_map_field); +}; + +#define DEF_STR(name) DEF_STRUCT_STR(name, dict_sql_map) + +static struct setting_def dict_sql_map_setting_defs[] = { + DEF_STR(pattern), + DEF_STR(table), + DEF_STR(username_field), + DEF_STR(value_field), + + { 0, NULL, 0 } +}; + +static const char *pattern_read_name(const char **pattern) +{ + const char *p = *pattern, *name; + + if (*p == '{') { + /* ${name} */ + name = ++p; + p = strchr(p, '}'); + if (p == NULL) { + /* error, but allow anyway */ + *pattern += strlen(*pattern); + return ""; + } + *pattern = p + 1; + } else { + /* $name - ends at the first non-alnum_ character */ + name = p; + for (; *p != '\0'; p++) { + if (!i_isalnum(*p) && *p != '_') + break; + } + *pattern = p; + } + name = t_strdup_until(name, p); + return name; +} + +static const char *dict_sql_fields_map(struct setting_parser_ctx *ctx) +{ + struct dict_sql_map_field *fields; + string_t *pattern; + const char *p, *name; + unsigned int i, count; + + p_array_init(&ctx->cur_map.sql_fields, ctx->pool, count); + + /* go through the variables in the pattern, replace them with plain + '$' character and add its sql field */ + pattern = t_str_new(strlen(ctx->cur_map.pattern) + 1); + fields = array_get_modifiable(&ctx->cur_fields, &count); + for (p = ctx->cur_map.pattern; *p != '\0';) { + if (*p != '$') { + str_append_c(pattern, *p); + p++; + continue; + } + p++; + str_append_c(pattern, '$'); + + name = pattern_read_name(&p); + for (i = 0; i < count; i++) { + if (fields[i].variable != NULL && + strcmp(fields[i].variable, name) == 0) + break; + } + if (i == count) { + return t_strconcat("Missing SQL field for variable: ", + name, NULL); + } + + /* mark this field as used */ + fields[i].variable = NULL; + array_append(&ctx->cur_map.sql_fields, + &fields[i].sql_field, 1); + } + + /* make sure there aren't any unused fields */ + for (i = 0; i < count; i++) { + if (fields[i].variable != NULL) { + return t_strconcat("Unused variable: ", + fields[i].variable, NULL); + } + } + + if (ctx->set->max_field_count < count) + ctx->set->max_field_count = count; + ctx->cur_map.pattern = p_strdup(ctx->pool, str_c(pattern)); + return NULL; +} + +static const char *dict_sql_map_finish(struct setting_parser_ctx *ctx) +{ + if (!array_is_created(&ctx->cur_map.sql_fields)) { + /* no fields besides value. allocate the array anyway. */ + p_array_init(&ctx->cur_map.sql_fields, ctx->pool, 1); + if (strchr(ctx->cur_map.pattern, '$') != NULL) + return "Missing fields for pattern variables"; + } + array_append(&ctx->set->maps, &ctx->cur_map, 1); + memset(&ctx->cur_map, 0, sizeof(ctx->cur_map)); + return NULL; +} + +static const char * +parse_setting(const char *key, const char *value, + struct setting_parser_ctx *ctx) +{ + struct dict_sql_map_field *field; + + switch (ctx->type) { + case SECTION_ROOT: + if (strcmp(key, "connect") == 0) { + ctx->set->connect = p_strdup(ctx->pool, value); + return NULL; + } + break; + case SECTION_MAP: + return parse_setting_from_defs(ctx->pool, + dict_sql_map_setting_defs, + &ctx->cur_map, key, value); + case SECTION_FIELDS: + if (*value != '$') { + return t_strconcat("Value is missing '$' for field: ", + key, NULL); + } + field = array_append_space(&ctx->cur_fields); + field->sql_field = p_strdup(ctx->pool, key); + field->variable = p_strdup(ctx->pool, value + 1); + return NULL; + } + return t_strconcat("Unknown setting: ", key, NULL); +} + +static bool +parse_section(const char *type, const char *name ATTR_UNUSED, + struct setting_parser_ctx *ctx, const char **error_r) +{ + switch (ctx->type) { + case SECTION_ROOT: + if (type == NULL) + return FALSE; + if (strcmp(type, "map") == 0) { + array_clear(&ctx->cur_fields); + ctx->type = SECTION_MAP; + return TRUE; + } + break; + case SECTION_MAP: + if (type == NULL) { + ctx->type = SECTION_ROOT; + *error_r = dict_sql_map_finish(ctx); + return FALSE; + } + if (strcmp(type, "fields") == 0) { + ctx->type = SECTION_FIELDS; + return TRUE; + } + break; + case SECTION_FIELDS: + if (type == NULL) { + ctx->type = SECTION_MAP; + *error_r = dict_sql_fields_map(ctx); + return FALSE; + } + break; + } + *error_r = t_strconcat("Unknown section: ", type, NULL); + return FALSE; +} + +struct dict_sql_settings *dict_sql_settings_read(pool_t pool, const char *path) +{ + struct setting_parser_ctx ctx; + + memset(&ctx, 0, sizeof(ctx)); + ctx.pool = pool; + ctx.set = p_new(pool, struct dict_sql_settings, 1); + t_array_init(&ctx.cur_fields, 16); + p_array_init(&ctx.set->maps, pool, 8); + + if (!settings_read(path, NULL, parse_setting, parse_section, &ctx)) + return NULL; + + return ctx.set; +}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/lib-dict/dict-sql-settings.h Wed Aug 27 12:10:21 2008 +0300 @@ -0,0 +1,24 @@ +#ifndef DICT_SQL_SETTINGS_H +#define DICT_SQL_SETTINGS_H + +struct dict_sql_map { + /* pattern is in simplified form: all variables are stored as simple + '$' character. fields array is sorted by the variable index. */ + const char *pattern; + const char *table; + const char *username_field; + const char *value_field; + + ARRAY_TYPE(const_string) sql_fields; +}; + +struct dict_sql_settings { + const char *connect; + + unsigned int max_field_count; + ARRAY_DEFINE(maps, struct dict_sql_map); +}; + +struct dict_sql_settings *dict_sql_settings_read(pool_t pool, const char *path); + +#endif
--- a/src/lib-dict/dict-sql.c Wed Aug 27 12:08:14 2008 +0300 +++ b/src/lib-dict/dict-sql.c Wed Aug 27 12:10:21 2008 +0300 @@ -7,6 +7,7 @@ #include "sql-api-private.h" #include "sql-pool.h" #include "dict-private.h" +#include "dict-sql-settings.h" #include "dict-sql.h" #include <unistd.h> @@ -19,19 +20,22 @@ pool_t pool; struct sql_db *db; - - const char *connect_string, *username; - const char *table, *select_field, *where_field, *username_field; + const char *username; + const struct dict_sql_settings *set; + unsigned int prev_map_match_idx; unsigned int has_on_duplicate_key:1; }; struct sql_dict_iterate_context { struct dict_iterate_context ctx; + enum dict_iterate_flags flags; + char *path; struct sql_result *result; - char *prev_key; - bool priv; + string_t *key; + const struct dict_sql_map *map; + unsigned int key_prefix_len, next_map_idx; }; struct sql_dict_transaction_context { @@ -45,77 +49,6 @@ static struct sql_pool *dict_sql_pool; -static void sql_dict_config_parse_line(struct sql_dict *dict, const char *line) -{ - const char *p, *value; - - while (*line == ' ') line++; - value = strchr(line, '='); - if (value == NULL) - return; - - for (p = value; p[-1] == ' ' && p != line; p--) ; - line = t_strdup_until(line, p); - value++; - while (*value == ' ') value++; - - if (strcmp(line, "connect") == 0) - dict->connect_string = p_strdup(dict->pool, value); - else if (strcmp(line, "table") == 0) - dict->table = p_strdup(dict->pool, value); - else if (strcmp(line, "select_field") == 0) - dict->select_field = p_strdup(dict->pool, value); - else if (strcmp(line, "where_field") == 0) - dict->where_field = p_strdup(dict->pool, value); - else if (strcmp(line, "username_field") == 0) - dict->username_field = p_strdup(dict->pool, value); -} - -static int sql_dict_read_config(struct sql_dict *dict, const char *path) -{ - struct istream *input; - const char *line; - int fd; - - fd = open(path, O_RDONLY); - if (fd == -1) { - i_error("open(%s) failed: %m", path); - return -1; - } - - input = i_stream_create_fd(fd, (size_t)-1, FALSE); - while ((line = i_stream_read_next_line(input)) != NULL) { - T_BEGIN { - sql_dict_config_parse_line(dict, line); - } T_END; - } - i_stream_destroy(&input); - (void)close(fd); - - if (dict->connect_string == NULL) { - i_error("%s: 'connect' missing", path); - return -1; - } - if (dict->table == NULL) { - i_error("%s: 'table' missing", path); - return -1; - } - if (dict->select_field == NULL) { - i_error("%s: 'select_field' missing", path); - return -1; - } - if (dict->where_field == NULL) { - i_error("%s: 'where_field' missing", path); - return -1; - } - if (dict->username_field == NULL) { - i_error("%s: 'username_field' missing", path); - return -1; - } - - return 0; -} - static struct dict * sql_dict_init(struct dict *driver, const char *uri, enum dict_data_type value_type ATTR_UNUSED, @@ -129,8 +62,8 @@ dict->pool = pool; dict->dict = *driver; dict->username = p_strdup(pool, username); - - if (sql_dict_read_config(dict, uri) < 0) { + dict->set = dict_sql_settings_read(pool, uri); + if (dict->set == NULL) { pool_unref(&pool); return NULL; } @@ -139,7 +72,7 @@ dict->has_on_duplicate_key = strcmp(driver->name, "mysql") == 0; dict->db = sql_pool_new(dict_sql_pool, driver->name, - dict->connect_string); + dict->set->connect); return &dict->dict; } @@ -151,58 +84,135 @@ pool_unref(&dict->pool); } -static int sql_path_fix(const char **path, bool *private_r) +static bool +dict_sql_map_match(const struct dict_sql_map *map, const char *path, + ARRAY_TYPE(const_string) *values, bool partial_ok) { - const char *p; - size_t len; + const char *pat, *field, *p; - p = strchr(*path, '/'); - if (p == NULL) - return -1; - len = p - *path; + array_clear(values); + pat = map->pattern; + while (*pat != '\0' && *path != '\0') { + if (*pat == '$') { + /* variable */ + pat++; + if (*pat == '\0') { + /* pattern ended with this variable, + it'll match the rest of the path */ + array_append(values, &path, 1); + return TRUE; + } + /* pattern matches until the next '/' in path */ + p = strchr(path, '/'); + if (p == NULL) + return FALSE; + field = t_strdup_until(path, p); + array_append(values, &field, 1); + path = p; + } else if (*pat == *path) { + pat++; + path++; + } else { + return FALSE; + } + } + if (*pat == '\0') + return *path == '\0'; + else if (!partial_ok) + return FALSE; + else { + /* partial matches must end with '/' */ + return pat == map->pattern || pat[-1] == '/'; + } +} - if (strncmp(*path, DICT_PATH_PRIVATE, len) == 0) - *private_r = TRUE; - else if (strncmp(*path, DICT_PATH_SHARED, len) == 0) - *private_r = FALSE; - else - return -1; +static const struct dict_sql_map * +sql_dict_find_map(struct sql_dict *dict, const char *path, + ARRAY_TYPE(const_string) *values) +{ + const struct dict_sql_map *maps; + unsigned int i, idx, count; + + t_array_init(values, dict->set->max_field_count); + maps = array_get(&dict->set->maps, &count); + for (i = 0; i < count; i++) { + /* start matching from the previously successful match */ + idx = (dict->prev_map_match_idx + i) % count; + if (dict_sql_map_match(&maps[idx], path, values, FALSE)) { + dict->prev_map_match_idx = idx; + return &maps[idx]; + } + } + return NULL; +} - *path += len + 1; - return 0; +static void +sql_dict_where_build(struct sql_dict *dict, const struct dict_sql_map *map, + const ARRAY_TYPE(const_string) *values_arr, + const char *key, string_t *query) +{ + const char *const *sql_fields, *const *values; + unsigned int i, count, count2; + bool priv = *key == DICT_PATH_PRIVATE[0]; + + sql_fields = array_get(&map->sql_fields, &count); + values = array_get(values_arr, &count2); + /* if we came here from iteration code there may be less values */ + i_assert(count2 <= count); + + if (count2 == 0 && !priv) { + /* we want everything */ + return; + } + + str_append(query, "WHERE"); + for (i = 0; i < count2; i++) { + if (i > 0) + str_append(query, " AND"); + str_printfa(query, " %s = '%s'", sql_fields[i], + sql_escape_string(dict->db, values[i])); + } + if (priv) { + if (count2 > 0) + str_append(query, " AND"); + str_printfa(query, " %s = '%s'", map->username_field, + sql_escape_string(dict->db, dict->username)); + } } static int sql_dict_lookup(struct dict *_dict, pool_t pool, const char *key, const char **value_r) { struct sql_dict *dict = (struct sql_dict *)_dict; + const struct dict_sql_map *map; + ARRAY_TYPE(const_string) values; struct sql_result *result; int ret; - bool priv; - if (sql_path_fix(&key, &priv) < 0) { + map = sql_dict_find_map(dict, key, &values); + if (map == NULL) { + i_error("sql dict lookup: Invalid/unmapped key: %s", key); *value_r = NULL; - return -1; + return 0; } T_BEGIN { string_t *query = t_str_new(256); - str_printfa(query, "SELECT %s FROM %s WHERE %s = '%s'", - dict->select_field, dict->table, - dict->where_field, - sql_escape_string(dict->db, key)); - if (priv) { - str_printfa(query, " AND %s = '%s'", - dict->username_field, - sql_escape_string(dict->db, dict->username)); - } + + str_printfa(query, "SELECT %s FROM %s ", + map->value_field, map->table); + sql_dict_where_build(dict, map, &values, key, query); result = sql_query_s(dict->db, str_c(query)); } T_END; ret = sql_result_next_row(result); - if (ret <= 0) + if (ret <= 0) { + if (ret < 0) { + i_error("dict sql lookup failed: %s", + sql_result_get_error(result)); + } *value_r = NULL; - else { + } else { *value_r = p_strdup(pool, sql_result_get_field_value(result, 0)); } @@ -211,64 +221,85 @@ return ret; } +static const struct dict_sql_map * +sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx, + ARRAY_TYPE(const_string) *values) +{ + struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict; + const struct dict_sql_map *maps; + unsigned int i, count; + + t_array_init(values, dict->set->max_field_count); + maps = array_get(&dict->set->maps, &count); + for (i = ctx->next_map_idx; i < count; i++) { + if (dict_sql_map_match(&maps[i], ctx->path, values, TRUE) && + ((ctx->flags & DICT_ITERATE_FLAG_RECURSE) != 0 || + array_count(values)+1 == array_count(&maps[i].sql_fields))) { + ctx->next_map_idx = i + 1; + return &maps[i]; + } + } + return NULL; +} + +static bool sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx) +{ + struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict; + const struct dict_sql_map *map; + ARRAY_TYPE(const_string) values; + const char *const *sql_fields; + unsigned int i, count; + + map = sql_dict_iterate_find_next_map(ctx, &values); + if (map == NULL) + return FALSE; + + T_BEGIN { + string_t *query = t_str_new(256); + + str_printfa(query, "SELECT %s", map->value_field); + /* get all missing fields */ + sql_fields = array_get(&map->sql_fields, &count); + for (i = array_count(&values); i < count; i++) + str_printfa(query, ",%s", sql_fields[i]); + str_printfa(query, " FROM %s ", map->table); + sql_dict_where_build(dict, map, &values, ctx->path, query); + + if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) { + str_append(query, "ORDER BY "); + for (i = array_count(&values); i < count; i++) { + str_printfa(query, "%s", sql_fields[i]); + if (i < count-1) + str_append_c(query, ','); + } + } else if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0) + str_printfa(query, "ORDER BY %s", map->value_field); + ctx->result = sql_query_s(dict->db, str_c(query)); + } T_END; + + ctx->map = map; + return TRUE; +} + static struct dict_iterate_context * sql_dict_iterate_init(struct dict *_dict, const char *path, enum dict_iterate_flags flags) { - struct sql_dict *dict = (struct sql_dict *)_dict; struct sql_dict_iterate_context *ctx; - unsigned int len; - bool priv; ctx = i_new(struct sql_dict_iterate_context, 1); ctx->ctx.dict = _dict; + ctx->path = i_strdup(path); + ctx->flags = flags; + ctx->key = str_new(default_pool, 256); + str_append(ctx->key, path); + ctx->key_prefix_len = str_len(ctx->key); - if (sql_path_fix(&path, &priv) < 0) { + if (!sql_dict_iterate_next_query(ctx)) { + i_error("sql dict iterate: Invalid/unmapped path: %s", path); ctx->result = NULL; return &ctx->ctx; } - ctx->priv = priv; - - T_BEGIN { - string_t *query = t_str_new(256); - str_printfa(query, "SELECT %s, %s FROM %s WHERE ", - dict->where_field, dict->select_field, - dict->table); - len = str_len(query); - - if (*path != '\0') { - str_printfa(query, "%s LIKE '%s/%%' AND ", - dict->where_field, - sql_escape_string(dict->db, path)); - } - if (priv) { - str_printfa(query, "%s = '%s' AND ", - dict->username_field, - sql_escape_string(dict->db, - dict->username)); - } - if ((flags & DICT_ITERATE_FLAG_RECURSE) != 0) { - /* get everything */ - } else if (*path == '\0') { - str_printfa(query, "%s NOT LIKE '%%/%%' AND ", - dict->where_field); - } else { - str_printfa(query, "%s NOT LIKE '%s/%%/%%' AND ", - dict->where_field, - sql_escape_string(dict->db, path)); - } - if (str_len(query) == len) - str_truncate(query, str_len(query) - 6); - else - str_truncate(query, str_len(query) - 4); - - if ((flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) - str_printfa(query, "ORDER BY %s", dict->where_field); - else if ((flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0) - str_printfa(query, "ORDER BY %s", dict->select_field); - ctx->result = sql_query_s(dict->db, str_c(query)); - } T_END; - return &ctx->ctx; } @@ -277,24 +308,41 @@ { struct sql_dict_iterate_context *ctx = (struct sql_dict_iterate_context *)_ctx; - const char *key; + const char *p; + unsigned int i, count; int ret; if (ctx->result == NULL) return -1; - if ((ret = sql_result_next_row(ctx->result)) <= 0) + while ((ret = sql_result_next_row(ctx->result)) == 0) { + /* see if there are more results in the next map */ + if (!sql_dict_iterate_next_query(ctx)) + return 0; + } + if (ret < 0) { + i_error("dict sql iterate failed: %s", + sql_result_get_error(ctx->result)); return ret; + } - key = sql_result_get_field_value(ctx->result, 0); - i_free(ctx->prev_key); - if (ctx->priv) - ctx->prev_key = i_strconcat(DICT_PATH_PRIVATE, key, NULL); - else - ctx->prev_key = i_strconcat(DICT_PATH_SHARED, key, NULL); + /* convert fetched row to dict key */ + str_truncate(ctx->key, ctx->key_prefix_len); + count = sql_result_get_fields_count(ctx->result); + i = 1; + for (p = ctx->map->pattern + ctx->key_prefix_len; *p != '\0'; p++) { + if (*p != '$') + str_append_c(ctx->key, *p); + else { + i_assert(i < count); + str_append(ctx->key, + sql_result_get_field_value(ctx->result, i)); + i++; + } + } - *key_r = ctx->prev_key; - *value_r = sql_result_get_field_value(ctx->result, 1); + *key_r = str_c(ctx->key); + *value_r = sql_result_get_field_value(ctx->result, 0); return 1; } @@ -304,7 +352,8 @@ (struct sql_dict_iterate_context *)_ctx; sql_result_free(ctx->result); - i_free(ctx->prev_key); + str_free(&ctx->key); + i_free(ctx->path); i_free(ctx); } @@ -354,32 +403,51 @@ } static const char * -sql_dict_set_query(struct sql_dict *dict, const char *key, const char *value, - bool priv) +sql_dict_set_query(struct sql_dict *dict, const struct dict_sql_map *map, + const ARRAY_TYPE(const_string) *values_arr, + const char *key, const char *value, bool inc) { - string_t *str; + const char *const *sql_fields, *const *values; + unsigned int i, count, count2; + string_t *prefix, *suffix; + + prefix = t_str_new(64); + suffix = t_str_new(256); + str_printfa(prefix, "INSERT INTO %s (%s", map->table, map->value_field); + str_append(suffix, ") VALUES ("); + if (inc) + str_append(suffix, value); + else + str_printfa(suffix, "'%s'", sql_escape_string(dict->db, value)); + if (*key == DICT_PATH_PRIVATE[0]) { + str_printfa(prefix, ",%s", map->username_field); + str_printfa(suffix, ",'%s'", + sql_escape_string(dict->db, dict->username)); + } - str = t_str_new(256); - if (priv) { - str_printfa(str, "INSERT INTO %s (%s, %s, %s) " - "VALUES ('%s', '%s', '%s')", - dict->table, dict->select_field, dict->where_field, - dict->username_field, - sql_escape_string(dict->db, value), - sql_escape_string(dict->db, key), - sql_escape_string(dict->db, dict->username)); - } else { - str_printfa(str, "INSERT INTO %s (%s, %s) VALUES ('%s', '%s')", - dict->table, dict->select_field, dict->where_field, - sql_escape_string(dict->db, value), - sql_escape_string(dict->db, key)); + /* add the other fields from the key */ + sql_fields = array_get(&map->sql_fields, &count); + values = array_get(values_arr, &count2); + i_assert(count == count2); + for (i = 0; i < count; i++) { + str_printfa(prefix, ",%s", sql_fields[i]); + str_printfa(suffix, ",'%s'", + sql_escape_string(dict->db, values[i])); } + + str_append_str(prefix, suffix); + str_append_c(prefix, ')'); if (dict->has_on_duplicate_key) { - str_printfa(str, " ON DUPLICATE KEY UPDATE %s = '%s'", - dict->select_field, - sql_escape_string(dict->db, value)); + str_printfa(prefix, " ON DUPLICATE KEY UPDATE %s =", + map->value_field); + if (inc) + str_printfa(prefix, "%s+%s", map->value_field, value); + else { + str_printfa(prefix, "'%s'", + sql_escape_string(dict->db, value)); + } } - return str_c(str); + return str_c(prefix); } static void sql_dict_set(struct dict_transaction_context *_ctx, @@ -388,10 +456,12 @@ struct sql_dict_transaction_context *ctx = (struct sql_dict_transaction_context *)_ctx; struct sql_dict *dict = (struct sql_dict *)_ctx->dict; - bool priv; + const struct dict_sql_map *map; + ARRAY_TYPE(const_string) values; - if (sql_path_fix(&key, &priv) < 0) { - i_error("sql dict: Invalid key: %s", key); + map = sql_dict_find_map(dict, key, &values); + if (map == NULL) { + i_error("sql dict set: Invalid/unmapped key: %s", key); ctx->failed = TRUE; return; } @@ -399,75 +469,35 @@ T_BEGIN { const char *query; - query = sql_dict_set_query(dict, key, value, priv); + query = sql_dict_set_query(dict, map, &values, key, value, + FALSE); sql_update(ctx->sql_ctx, query); } T_END; } -static const char * -sql_dict_unset_query(struct sql_dict *dict, const char *key, bool priv) -{ - if (priv) { - return t_strdup_printf( - "DELETE FROM %s WHERE %s = '%s' AND %s = '%s'", - dict->table, dict->where_field, - sql_escape_string(dict->db, key), - dict->username_field, - sql_escape_string(dict->db, dict->username)); - } else { - return t_strdup_printf( - "DELETE FROM %s WHERE %s = '%s'", - dict->table, dict->where_field, - sql_escape_string(dict->db, key)); - } -} - static void sql_dict_unset(struct dict_transaction_context *_ctx, const char *key) { struct sql_dict_transaction_context *ctx = (struct sql_dict_transaction_context *)_ctx; struct sql_dict *dict = (struct sql_dict *)_ctx->dict; - bool priv; + const struct dict_sql_map *map; + ARRAY_TYPE(const_string) values; - if (sql_path_fix(&key, &priv) < 0) { - i_error("sql dict: Invalid key: %s", key); + map = sql_dict_find_map(dict, key, &values); + if (map == NULL) { + i_error("sql dict unset: Invalid/unmapped key: %s", key); ctx->failed = TRUE; return; } T_BEGIN { - const char *query; - - query = sql_dict_unset_query(dict, key, priv); - sql_update(ctx->sql_ctx, query); - } T_END; -} - -static const char * -sql_dict_atomic_inc_query(struct sql_dict *dict, const char *key, - long long diff, bool priv) -{ - string_t *str; + string_t *query = t_str_new(256); - str = t_str_new(256); - if (priv) { - str_printfa(str, "INSERT INTO %s (%s, %s, %s) " - "VALUES (%lld, '%s', '%s')", - dict->table, dict->select_field, dict->where_field, - dict->username_field, - diff, sql_escape_string(dict->db, key), - sql_escape_string(dict->db, dict->username)); - } else { - str_printfa(str, "INSERT INTO %s (%s, %s) VALUES (%lld, '%s')", - dict->table, dict->select_field, dict->where_field, - diff, sql_escape_string(dict->db, key)); - } - if (dict->has_on_duplicate_key) { - str_printfa(str, " ON DUPLICATE KEY UPDATE %s = %s + %lld", - dict->select_field, dict->select_field, diff); - } - return str_c(str); + str_printfa(query, "DELETE FROM %s ", map->table); + sql_dict_where_build(dict, map, &values, key, query); + sql_update(ctx->sql_ctx, str_c(query)); + } T_END; } static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx, @@ -476,10 +506,12 @@ struct sql_dict_transaction_context *ctx = (struct sql_dict_transaction_context *)_ctx; struct sql_dict *dict = (struct sql_dict *)_ctx->dict; - bool priv; + const struct dict_sql_map *map; + ARRAY_TYPE(const_string) values; - if (sql_path_fix(&key, &priv) < 0) { - i_error("sql dict: Invalid key: %s", key); + map = sql_dict_find_map(dict, key, &values); + if (map == NULL) { + i_error("sql dict atomic inc: Invalid/unmapped key: %s", key); ctx->failed = TRUE; return; } @@ -487,7 +519,8 @@ T_BEGIN { const char *query; - query = sql_dict_atomic_inc_query(dict, key, diff, priv); + query = sql_dict_set_query(dict, map, &values, key, + dec2str(diff), TRUE); sql_update(ctx->sql_ctx, query); } T_END; }
--- a/src/lib-dict/dict.c Wed Aug 27 12:08:14 2008 +0300 +++ b/src/lib-dict/dict.c Wed Aug 27 12:10:21 2008 +0300 @@ -84,9 +84,16 @@ dict->v.deinit(dict); } +static bool dict_key_prefix_is_valid(const char *key) +{ + return strncmp(key, DICT_PATH_SHARED, strlen(DICT_PATH_SHARED)) == 0 || + strncmp(key, DICT_PATH_PRIVATE, strlen(DICT_PATH_PRIVATE)) == 0; +} + int dict_lookup(struct dict *dict, pool_t pool, const char *key, const char **value_r) { + i_assert(dict_key_prefix_is_valid(key)); return dict->v.lookup(dict, pool, key, value_r); } @@ -94,6 +101,7 @@ dict_iterate_init(struct dict *dict, const char *path, enum dict_iterate_flags flags) { + i_assert(dict_key_prefix_is_valid(path)); return dict->v.iterate_init(dict, path, flags); } @@ -135,6 +143,8 @@ void dict_set(struct dict_transaction_context *ctx, const char *key, const char *value) { + i_assert(dict_key_prefix_is_valid(key)); + ctx->dict->v.set(ctx, key, value); ctx->changed = TRUE; } @@ -142,6 +152,8 @@ void dict_unset(struct dict_transaction_context *ctx, const char *key) { + i_assert(dict_key_prefix_is_valid(key)); + ctx->dict->v.unset(ctx, key); ctx->changed = TRUE; } @@ -149,6 +161,8 @@ void dict_atomic_inc(struct dict_transaction_context *ctx, const char *key, long long diff) { + i_assert(dict_key_prefix_is_valid(key)); + if (diff != 0) { ctx->dict->v.atomic_inc(ctx, key, diff); ctx->changed = TRUE;