view src/doveadm/server-connection.c @ 22595:a3632f4d817f

doveadm: Return after destroying connection. Prevents NULL dereferences, found by Coverity.
author Aki Tuomi <aki.tuomi@dovecot.fi>
date Fri, 06 Oct 2017 08:46:00 +0300
parents 36f70fbf6d2f
children cf66220d281e
line wrap: on
line source

/* Copyright (c) 2010-2017 Dovecot authors, see the included COPYING file */

#include "lib.h"
#include "array.h"
#include "base64.h"
#include "ioloop.h"
#include "net.h"
#include "istream.h"
#include "istream-multiplex.h"
#include "ostream.h"
#include "ostream-dot.h"
#include "str.h"
#include "strescape.h"
#include "iostream-ssl.h"
#include "master-service.h"
#include "master-service-settings.h"
#include "settings-parser.h"
#include "doveadm.h"
#include "doveadm-print.h"
#include "doveadm-util.h"
#include "doveadm-server.h"
#include "doveadm-settings.h"
#include "server-connection.h"

#include <sysexits.h>
#include <unistd.h>

#define DOVEADM_LOG_CHANNEL_ID 'L'

#define MAX_INBUF_SIZE (1024*32)

/* What kind of input is currently expected from the doveadm server. */
enum server_reply_state {
	/* no command is in progress; any input received is an error */
	SERVER_REPLY_STATE_DONE = 0,
	/* a command was sent; incoming data is tab-separated print output */
	SERVER_REPLY_STATE_PRINT,
	/* print output ended with an empty line; the next line is the
	   command's "+" / "-" result */
	SERVER_REPLY_STATE_RET
};

/* State of a single client-side connection to a doveadm server. */
struct server_connection {
	/* server this connection belongs to; the connection is listed in
	   server->connections until it is destroyed */
	struct doveadm_server *server;

	/* pool the connection itself and its settings are allocated from */
	pool_t pool;
	/* doveadm settings duplicated into pool for this connection */
	struct doveadm_settings *set;

	/* socket fd, or -1 after server_connection_extract() */
	int fd;
	/* minor protocol version parsed from the server's VERSION line;
	   multiplexing is enabled when this is > 0 */
	unsigned int minor;

	/* input handler; removed while waiting for another connection to
	   finish printing */
	struct io *io;
	/* handler for the multiplexed log channel */
	struct io *io_log;
	struct istream *input;
	/* log channel of the multiplexed input; NULL unless multiplexing
	   was started */
	struct istream *log_input;
	struct ostream *output;
	struct ssl_iostream *ssl_iostream;
	/* zero-delay timeout used to re-run the input handler after the
	   print turn was released */
	struct timeout *to_input;

	/* optional input stream whose contents are sent after the command */
	struct istream *cmd_input;
	/* dot-stream wrapper around output used for sending cmd_input */
	struct ostream *cmd_output;
	/* command line held back until authentication has finished */
	const char *delayed_cmd;
	/* called once with the command's result; NULL when idle */
	server_cmd_callback_t *callback;
	void *context;

	enum server_reply_state state;

	/* a "-" handshake line has been processed */
	unsigned int handshaked:1;
	/* server has replied "+" to the handshake/authentication */
	unsigned int authenticated:1;
	/* the current print field is being output in pieces */
	unsigned int streaming:1;
};

/* Connection that currently owns the print output, or NULL if nobody is
   printing. Other connections stop reading until it is released. */
static struct server_connection *printing_conn = NULL;
/* Servers that have connections waiting for printing_conn to be released.
   Created on first use and freed when the print turn is released. */
static ARRAY(struct doveadm_server *) print_pending_servers = ARRAY_INIT;

/* forward declarations for functions defined further below */
static void server_connection_input(struct server_connection *conn);
static bool server_connection_input_one(struct server_connection *conn);

/* Remember that a connection of this server is waiting for its turn to
   print. Each server is recorded at most once. */
static void server_set_print_pending(struct doveadm_server *server)
{
	struct doveadm_server *const *pendingp;

	/* the array is created lazily */
	if (!array_is_created(&print_pending_servers))
		i_array_init(&print_pending_servers, 16);

	array_foreach(&print_pending_servers, pendingp) {
		if (server == *pendingp) {
			/* already queued */
			return;
		}
	}
	array_append(&print_pending_servers, &server, 1);
}

/* Resume input handling on every connection of the server that currently
   has no input handler. A zero-delay timeout is added as well, so the input
   handler also gets run without waiting for new data on the fd. */
static void server_print_connection_released(struct doveadm_server *server)
{
	struct server_connection *const *conns;
	struct server_connection *conn;
	unsigned int idx, conn_count;

	conns = array_get(&server->connections, &conn_count);
	for (idx = 0; idx < conn_count; idx++) {
		conn = conns[idx];
		if (conn->io == NULL) {
			conn->io = io_add(conn->fd, IO_READ,
					  server_connection_input, conn);
			conn->to_input = timeout_add_short(0,
				server_connection_input, conn);
		}
	}
}

/* Give up the print turn and wake up the connections of all servers that
   were waiting for it. The pending list is freed afterwards. */
static void print_connection_released(void)
{
	struct doveadm_server *const *pendingp;

	printing_conn = NULL;
	if (array_is_created(&print_pending_servers)) {
		array_foreach(&print_pending_servers, pendingp)
			server_print_connection_released(*pendingp);
		array_free(&print_pending_servers);
	}
}

/* Send more of the command's input stream to the server through the
   dot-stream. Returns 0 if there is still input left to send (a flush is
   marked pending), 1 if sending ended without o_stream_send_istream()
   failing and -1 if it failed. In the last two cases both command streams
   are destroyed. */
static int server_connection_send_cmd_input_more(struct server_connection *conn)
{
	off_t sent;

	/* ostream-dot writes only up to max buffer size, so keep it non-zero */
	o_stream_set_max_buffer_size(conn->cmd_output, IO_BLOCK_SIZE);
	sent = o_stream_send_istream(conn->cmd_output, conn->cmd_input);
	o_stream_set_max_buffer_size(conn->cmd_output, (size_t)-1);

	if (sent >= 0 && i_stream_have_bytes_left(conn->cmd_input)) {
		/* continue once the output has been flushed */
		o_stream_set_flush_pending(conn->cmd_output, TRUE);
		return 0;
	}

	/* finished or failed: log at most one error, input side first */
	if (conn->cmd_input->stream_errno != 0) {
		i_error("read(%s) failed: %s",
			i_stream_get_name(conn->cmd_input),
			i_stream_get_error(conn->cmd_input));
	} else if (conn->cmd_output->stream_errno != 0 ||
		   o_stream_flush(conn->cmd_output) < 0) {
		i_error("write(%s) failed: %s",
			o_stream_get_name(conn->cmd_output),
			o_stream_get_error(conn->cmd_output));
	}

	i_stream_destroy(&conn->cmd_input);
	o_stream_destroy(&conn->cmd_output);
	if (sent < 0)
		return -1;
	return 1;
}

/* If the command has an input stream, wrap the connection's output into a
   dot-stream and start sending the input through it. */
static void server_connection_send_cmd_input(struct server_connection *conn)
{
	if (conn->cmd_input != NULL) {
		conn->cmd_output = o_stream_create_dot(conn->output, TRUE);
		(void)server_connection_send_cmd_input_more(conn);
	}
}

/* Output stream flush callback. Flushes the output and, when everything
   was flushed and the command itself isn't being delayed, continues sending
   the command's input stream. The connection is destroyed on failure.
   Returns the flush / send result. */
static int server_connection_output(struct server_connection *conn)
{
	int flush_ret;

	o_stream_cork(conn->output);
	flush_ret = o_stream_flush(conn->output);
	if (flush_ret > 0 &&
	    conn->delayed_cmd == NULL && conn->cmd_input != NULL)
		flush_ret = server_connection_send_cmd_input_more(conn);
	o_stream_uncork(conn->output);

	if (flush_ret < 0)
		server_connection_destroy(&conn);
	return flush_ret;
}

static void
server_connection_callback(struct server_connection *conn,
			   int exit_code, const char *error)
{
	server_cmd_callback_t *callback = conn->callback;

	conn->callback = NULL;
	callback(exit_code, error, conn->context);
}

/* Tab-unescape the data into str (replacing its old contents) and print
   the result as a piece of a streamed field. */
static void stream_data(string_t *str, const unsigned char *data, size_t size)
{
	/* str is only a scratch buffer, always start from empty */
	str_truncate(str, 0);
	str_append_tabunescaped(str, data, size);

	doveadm_print_stream(str->data, str->used);
}

/* Print the end of a field. If the field's beginning was already streamed,
   the remaining data is streamed and the stream is finished with an empty
   write. Otherwise the whole field is printed in one call. */
static void server_flush_field(struct server_connection *conn, string_t *str,
			       const unsigned char *data, size_t size)
{
	if (!conn->streaming) {
		/* complete field received at once */
		str_truncate(str, 0);
		str_append_tabunescaped(str, data, size);
		doveadm_print(str_c(str));
		return;
	}

	conn->streaming = FALSE;
	if (size > 0)
		stream_data(str, data, size);
	/* empty write ends the streamed field */
	doveadm_print_stream("", 0);
}

/* Handle print output received from the server. Only one connection may
   print at a time: if another connection is printing, this one stops
   reading until the print turn is released. Fields are terminated by tabs
   and an empty line ends the print output, after which the state is changed
   to SERVER_REPLY_STATE_RET. A newline in the middle of a field is a
   protocol error, which destroys the connection. size must be non-zero. */
static void
server_handle_input(struct server_connection *conn,
		    const unsigned char *data, size_t size)
{
	string_t *field;
	size_t pos, field_start;

	if (printing_conn == NULL) {
		/* nobody is printing, take the turn */
		printing_conn = conn;
	} else if (printing_conn != conn) {
		/* someone else is printing. don't continue until it
		   goes away */
		server_set_print_pending(conn->server);
		io_remove(&conn->io);
		return;
	}

	if (data[size-1] == '\001') {
		/* last character is an escape. leave it in the input
		   buffer until the rest of it has been received. */
		size--;
	}

	field = t_str_new(128);
	field_start = 0;
	for (pos = 0; pos < size; pos++) {
		switch (data[pos]) {
		case '\n':
			if (pos != field_start) {
				i_error("doveadm server sent broken print input");
				server_connection_destroy(&conn);
				return;
			}
			/* empty line: print output is finished */
			conn->state = SERVER_REPLY_STATE_RET;
			i_stream_skip(conn->input, pos + 1);

			print_connection_released();
			return;
		case '\t':
			server_flush_field(conn, field, data + field_start,
					   pos - field_start);
			field_start = pos + 1;
			break;
		default:
			break;
		}
	}
	if (field_start != size) {
		/* partial field: print what we have so far */
		conn->streaming = TRUE;
		stream_data(field, data + field_start, size - field_start);
	}
	i_stream_skip(conn->input, size);
}

/* Mark the connection authenticated and send the command that was waiting
   for authentication to finish, if there is one. */
static void server_connection_authenticated(struct server_connection *conn)
{
	conn->authenticated = TRUE;
	if (conn->delayed_cmd == NULL)
		return;

	o_stream_nsend_str(conn->output, conn->delayed_cmd);
	conn->delayed_cmd = NULL;
	server_connection_send_cmd_input(conn);
}

static int
server_connection_authenticate(struct server_connection *conn)
{
	string_t *plain = t_str_new(128);
	string_t *cmd = t_str_new(128);

	if (*conn->set->doveadm_password == '\0') {
		i_error("doveadm_password not set, "
			"can't authenticate to remote server");
		return -1;
	}

	str_append_c(plain, '\0');
	str_append(plain, conn->set->doveadm_username);
	str_append_c(plain, '\0');
	str_append(plain, conn->set->doveadm_password);

	str_append(cmd, "PLAIN\t");
	base64_encode(plain->data, plain->used, cmd);
	str_append_c(cmd, '\n');

	o_stream_nsend(conn->output, cmd->data, cmd->used);
	return 0;
}

/* Log why the server connection was lost. The SSL stream's last error is
   preferred; otherwise the input stream's errno is used, or "EOF" if there
   is none. */
static void server_log_disconnect_error(struct server_connection *conn)
{
	const char *reason = NULL;

	if (conn->ssl_iostream != NULL)
		reason = ssl_iostream_get_last_error(conn->ssl_iostream);
	if (reason == NULL) {
		if (conn->input->stream_errno == 0)
			reason = "EOF";
		else
			reason = strerror(conn->input->stream_errno);
	}
	i_error("doveadm server disconnected before handshake: %s", reason);
}

static void server_connection_print_log(struct server_connection *conn)
{
	const char *line;
	struct failure_context ctx;
	i_zero(&ctx);

	while((line = i_stream_read_next_line(conn->log_input))!=NULL) {
		/* skip empty lines */
		if (*line == '\0') continue;

		if (!doveadm_log_type_from_char(line[0], &ctx.type))
			i_warning("Doveadm server sent invalid log type 0x%02x",
				  line[0]);
		line++;
		i_log_type(&ctx, "remote(%s): %s", conn->server->name, line);
	}
}

/* Switch the connection's input to a multiplexed stream: conn->input is
   replaced with a multiplex stream wrapping the old input, the input
   handler is re-added on top of it and a separate log channel with its own
   handler is created. */
static void server_connection_start_multiplex(struct server_connection *conn)
{
	struct istream *plain_input = conn->input;

	conn->input = i_stream_create_multiplex(plain_input, MAX_INBUF_SIZE);
	/* drop our own reference to the old input stream */
	i_stream_unref(&plain_input);

	/* the old handler was for the plain input */
	io_remove(&conn->io);
	conn->io = io_add_istream(conn->input, server_connection_input, conn);

	conn->log_input = i_stream_multiplex_add_channel(conn->input,
							 DOVEADM_LOG_CHANNEL_ID);
	conn->io_log = io_add_istream(conn->log_input,
				      server_connection_print_log, conn);
	i_stream_set_return_partial_line(conn->log_input, TRUE);
}

/* IO/timeout callback: reads input from the doveadm server. Until both the
   handshake and the authentication are done, input is parsed as handshake
   lines ("VERSION\t...", "+" and "-"). Afterwards the input is handed to
   server_connection_input_one() until it returns FALSE. The connection is
   destroyed on protocol errors and on disconnects. */
static void server_connection_input(struct server_connection *conn)
{
	const char *line;

	if (conn->to_input != NULL)
		timeout_remove(&conn->to_input);

	if (!conn->handshaked || !conn->authenticated) {
		while((line = i_stream_read_next_line(conn->input)) != NULL) {
			if (strncmp(line, "VERSION\t", 8) == 0) {
				if (!version_string_verify_full(line, "doveadm-client",
								DOVEADM_SERVER_PROTOCOL_VERSION_MAJOR,
								&conn->minor)) {
					i_error("doveadm server not compatible with this client "
						"(mixed old and new binaries?)");
					server_connection_destroy(&conn);
					return;
				}
				continue;
			}
			if (strcmp(line, "+") == 0) {
				if (conn->minor > 0)
					server_connection_start_multiplex(conn);
				/* "+" can be the first reply when the server
				   doesn't ask for authentication. The break
				   below skips the assignment at the end of
				   the loop, so mark the handshake done here.
				   Otherwise the command's reply would be
				   parsed as handshake lines on the next
				   call. */
				conn->handshaked = TRUE;
				server_connection_authenticated(conn);
				break;
			} else if (strcmp(line, "-") == 0) {
				if (!conn->handshaked &&
				    server_connection_authenticate(conn) < 0) {
					server_connection_destroy(&conn);
					return;
				} else if (conn->handshaked) {
					/* second "-": our credentials were
					   rejected */
					i_error("doveadm authentication failed (%s)",
						line+1);
					server_connection_destroy(&conn);
					return;
				}
			} else {
				i_error("doveadm server sent invalid handshake: %s",
					line);
				server_connection_destroy(&conn);
				return;
			}
			conn->handshaked = TRUE;
		}

		if (line == NULL) {
			/* ran out of lines; check if it was a disconnect */
			if (conn->input->eof || conn->input->stream_errno != 0) {
				server_log_disconnect_error(conn);
				server_connection_destroy(&conn);
			}
		}
		return;
	}

	if (i_stream_read(conn->input) < 0) {
		/* disconnected */
		server_log_disconnect_error(conn);
		server_connection_destroy(&conn);
		return;
	}

	while (server_connection_input_one(conn)) ;
}

static bool server_connection_input_one(struct server_connection *conn)
{
	const unsigned char *data;
	size_t size;
	const char *line;
	int exit_code;

	data = i_stream_get_data(conn->input, &size);
	if (size == 0)
		return FALSE;

	/* check logs */
	(void)server_connection_print_log(conn);

	switch (conn->state) {
	case SERVER_REPLY_STATE_DONE:
		i_error("doveadm server sent unexpected input");
		server_connection_destroy(&conn);
		return FALSE;
	case SERVER_REPLY_STATE_PRINT:
		server_handle_input(conn, data, size);
		if (conn->state != SERVER_REPLY_STATE_RET)
			return FALSE;
		/* fall through */
	case SERVER_REPLY_STATE_RET:
		line = i_stream_next_line(conn->input);
		if (line == NULL)
			return FALSE;
		if (line[0] == '+')
			server_connection_callback(conn, 0, "");
		else if (line[0] == '-') {
			line++;
			exit_code = doveadm_str_to_exit_code(line);
			if (exit_code == DOVEADM_EX_UNKNOWN &&
			    str_to_int(line, &exit_code) < 0) {
				/* old doveadm-server */
				exit_code = EX_TEMPFAIL;
			}
			server_connection_callback(conn, exit_code, line);
		} else {
			i_error("doveadm server sent broken input "
				"(expected cmd reply): %s", line);
			server_connection_destroy(&conn);
			return FALSE;
		}
		if (conn->callback == NULL) {
			/* we're finished, close the connection */
			server_connection_destroy(&conn);
			return FALSE;
		}
		return TRUE;
	}
	i_unreached();
}

static int server_connection_read_settings(struct server_connection *conn)
{
	const struct setting_parser_info *set_roots[] = {
		&doveadm_setting_parser_info,
		NULL
	};
	struct master_service_settings_input input;
	struct master_service_settings_output output;
	const char *error;
	in_port_t port;
	void *set;

	i_zero(&input);
	input.roots = set_roots;
	input.service = "doveadm";

	(void)net_getsockname(conn->fd, &input.local_ip, &port);
	(void)net_getpeername(conn->fd, &input.remote_ip, &port);

	if (master_service_settings_read(master_service, &input,
					 &output, &error) < 0) {
		i_error("Error reading configuration: %s", error);
		return -1;
	}
	set = master_service_settings_get_others(master_service)[0];
	conn->set = settings_dup(&doveadm_setting_parser_info, set, conn->pool);
	return 0;
}

/* SSL handshake callback: verify the server's certificate against the
   server name with everything from the last ':' onwards removed. Returns 0
   if the certificate is valid, -1 with *error_r set by the validity check
   otherwise. */
static int server_connection_ssl_handshaked(const char **error_r, void *context)
{
	struct server_connection *conn = context;
	const char *host = conn->server->name;
	const char *port_sep = strrchr(host, ':');

	if (port_sep != NULL)
		host = t_strdup_until(host, port_sep);

	if (ssl_iostream_check_cert_validity(conn->ssl_iostream,
					     host, error_r) < 0)
		return -1;

	if (doveadm_debug)
		i_debug("%s: SSL handshake successful", conn->server->name);
	return 0;
}

static int server_connection_init_ssl(struct server_connection *conn)
{
	struct ssl_iostream_settings ssl_set;
	const char *error;

	if (conn->server->ssl_ctx == NULL)
		return 0;

	i_zero(&ssl_set);
	ssl_set.verify_remote_cert = TRUE;
	ssl_set.require_valid_cert = TRUE;
	ssl_set.verbose_invalid_cert = TRUE;

	if (io_stream_create_ssl_client(conn->server->ssl_ctx,
					conn->server->name, &ssl_set,
					&conn->input, &conn->output,
					&conn->ssl_iostream, &error) < 0) {
		i_error("Couldn't initialize SSL client: %s", error);
		return -1;
	}
	ssl_iostream_set_handshake_callback(conn->ssl_iostream,
					    server_connection_ssl_handshaked,
					    conn);
	if (ssl_iostream_handshake(conn->ssl_iostream) < 0) {
		i_error("SSL handshake failed: %s",
			ssl_iostream_get_last_error(conn->ssl_iostream));
		return -1;
	}
	return 0;
}

/* Create a new connection to the given doveadm server: connect, set up the
   streams and the input handler, add the connection to the server's
   connection list, read the settings, optionally start SSL and send the
   protocol version line. Returns 0 and the connection in *conn_r on
   success. Returns -1 if reading settings or initializing SSL failed; the
   connection is destroyed then and *conn_r is left untouched. */
int server_connection_create(struct doveadm_server *server,
			     struct server_connection **conn_r)
{
	struct server_connection *conn;
	pool_t conn_pool;

	conn_pool = pool_alloconly_create("doveadm server connection", 1024*16);
	conn = p_new(conn_pool, struct server_connection, 1);
	conn->pool = conn_pool;
	conn->server = server;

	conn->fd = doveadm_connect_with_default_port(server->name,
						     doveadm_settings->doveadm_port);
	net_set_nonblock(conn->fd, TRUE);
	conn->io = io_add(conn->fd, IO_READ, server_connection_input, conn);

	conn->input = i_stream_create_fd(conn->fd, MAX_INBUF_SIZE, FALSE);
	conn->output = o_stream_create_fd(conn->fd, (size_t)-1, FALSE);
	o_stream_set_flush_callback(conn->output, server_connection_output, conn);

	i_stream_set_name(conn->input, server->name);
	o_stream_set_name(conn->output, server->name);

	/* must be listed before any failure handling, because
	   server_connection_destroy() looks the connection up in the list */
	array_append(&conn->server->connections, &conn, 1);

	if (server_connection_read_settings(conn) < 0) {
		server_connection_destroy(&conn);
		return -1;
	}
	if (server_connection_init_ssl(conn) < 0) {
		server_connection_destroy(&conn);
		return -1;
	}

	o_stream_set_no_error_handling(conn->output, TRUE);
	conn->state = SERVER_REPLY_STATE_DONE;
	o_stream_nsend_str(conn->output, DOVEADM_SERVER_PROTOCOL_VERSION_LINE"\n");

	*conn_r = conn;
	return 0;
}

void server_connection_destroy(struct server_connection **_conn)
{
	struct server_connection *conn = *_conn;
	struct server_connection *const *conns;
	const char *error;
	unsigned int i, count;

	*_conn = NULL;

	conns = array_get(&conn->server->connections, &count);
	for (i = 0; i < count; i++) {
		if (conns[i] == conn) {
			array_delete(&conn->server->connections, i, 1);
			break;
		}
	}

	if (conn->callback != NULL) {
		error = conn->ssl_iostream == NULL ? NULL :
			ssl_iostream_get_last_error(conn->ssl_iostream);
		if (error == NULL) {
			error = conn->input->stream_errno == 0 ? "EOF" :
				strerror(conn->input->stream_errno);
		}
		server_connection_callback(conn, SERVER_EXIT_CODE_DISCONNECTED,
					   error);
	}
	if (printing_conn == conn)
		print_connection_released();

	if (conn->to_input != NULL)
		timeout_remove(&conn->to_input);
	if (conn->input != NULL)
		i_stream_destroy(&conn->input);
	if (conn->output != NULL)
		o_stream_destroy(&conn->output);
	if (conn->cmd_input != NULL)
		i_stream_destroy(&conn->cmd_input);
	/* close cmd_output after its parent, so the "." isn't sent */
	if (conn->cmd_output != NULL)
		o_stream_destroy(&conn->cmd_output);
	if (conn->ssl_iostream != NULL)
		ssl_iostream_unref(&conn->ssl_iostream);
        if (conn->io_log != NULL)
                io_remove(&conn->io_log);
        /* make sure all logs got consumed */
        if (conn->log_input != NULL) {
                server_connection_print_log(conn);
                i_stream_unref(&conn->log_input);
	}
	if (conn->io != NULL)
		io_remove(&conn->io);
	if (conn->fd != -1) {
		if (close(conn->fd) < 0)
			i_error("close(server) failed: %m");
	}

	pool_unref(&conn->pool);
}

/* Returns the server this connection was created for. */
struct doveadm_server *
server_connection_get_server(struct server_connection *conn)
{
	struct doveadm_server *server = conn->server;

	return server;
}

/* Start a command on the connection. line is sent to the server right away
   if the connection is already authenticated; otherwise a copy of it is
   kept and sent once authentication finishes. If cmd_input is given, a
   reference to it is taken and its contents are sent after the command
   line. callback is called with context when the command's result is
   known. No earlier command may be waiting for authentication. */
void server_connection_cmd(struct server_connection *conn, const char *line,
			   struct istream *cmd_input,
			   server_cmd_callback_t *callback, void *context)
{
	i_assert(conn->delayed_cmd == NULL);

	conn->state = SERVER_REPLY_STATE_PRINT;
	if (cmd_input != NULL) {
		i_assert(conn->cmd_input == NULL);
		i_stream_ref(cmd_input);
		conn->cmd_input = cmd_input;
	}

	if (conn->authenticated) {
		o_stream_nsend_str(conn->output, line);
		server_connection_send_cmd_input(conn);
	} else {
		/* sent by server_connection_authenticated() */
		conn->delayed_cmd = p_strdup(conn->pool, line);
	}

	conn->callback = callback;
	conn->context = context;
}

bool server_connection_is_idle(struct server_connection *conn)
{
	return conn->callback == NULL;
}

/* Hand the connection's input, output and SSL streams over to the caller.
   The connection forgets the streams and its fd and stops handling input,
   so destroying it afterwards neither frees these streams nor closes the
   fd. */
void server_connection_extract(struct server_connection *conn,
			       struct istream **istream_r,
			       struct ostream **ostream_r,
			       struct ssl_iostream **ssl_iostream_r)
{
	*istream_r = conn->input;
	conn->input = NULL;

	*ostream_r = conn->output;
	conn->output = NULL;

	*ssl_iostream_r = conn->ssl_iostream;
	conn->ssl_iostream = NULL;

	if (conn->io != NULL)
		io_remove(&conn->io);
	/* -1 keeps server_connection_destroy() from closing the fd */
	conn->fd = -1;
}