view src/lib-auth/auth-server-connection.c @ 2708:f1e9f3ec8135 HEAD

Buffer API change: we no longer support limited sized buffers where writes past limit wouldn't kill the process. They weren't used hardly anywhere, they could have hidden bugs and the code for handling them was too complex. This also changed base64 and hex-binary APIs.
author Timo Sirainen <tss@iki.fi>
date Fri, 08 Oct 2004 20:51:47 +0300
parents d141e1bfdd63
children 0f31778d3c34
line wrap: on
line source

/* Copyright (C) 2003 Timo Sirainen */

#include "lib.h"
#include "buffer.h"
#include "hash.h"
#include "ioloop.h"
#include "istream.h"
#include "ostream.h"
#include "network.h"
#include "auth-client.h"
#include "auth-server-connection.h"
#include "auth-server-request.h"

#include <unistd.h>

/* Maximum size for an auth reply. 50kB should be more than enough. */
#define MAX_INBUF_SIZE (1024*50)

#define MAX_OUTBUF_SIZE \
	(sizeof(struct auth_client_request_continue) + \
	 AUTH_CLIENT_MAX_REQUEST_DATA_SIZE)

static void auth_server_connection_unref(struct auth_server_connection *conn);

static void update_available_auth_mechs(struct auth_server_connection *conn)
{
	struct auth_client *client = conn->client;
	const struct auth_mech_desc *mech;
	struct auth_mech_desc *new_mech;
	unsigned int i;

	mech = conn->available_auth_mechs;
	for (i = 0; i < conn->available_auth_mechs_count; i++) {
		if (auth_client_find_mech(client, mech[i].name) == NULL) {
			new_mech = buffer_append_space_unsafe(
				client->available_auth_mechs, sizeof(*mech));
			*new_mech = mech[i];
			new_mech->name = i_strdup(mech[i].name);
		}
	}
}

static void auth_handle_handshake(struct auth_server_connection *conn,
				  struct auth_client_handshake_reply *handshake,
				  const unsigned char *data)
{
	struct auth_client_handshake_mech_desc handshake_mech_desc;
	struct auth_mech_desc mech_desc;
	buffer_t *buf;
	unsigned int i;

	if (handshake->data_size == 0 || data[handshake->data_size-1] != '\0' ||
	    handshake->mech_count * sizeof(handshake_mech_desc) >=
	    handshake->data_size)  {
		i_error("BUG: Auth server sent corrupted handshake");
		auth_server_connection_destroy(conn, FALSE);
		return;
	}

	buf = buffer_create_dynamic(conn->pool, sizeof(mech_desc) *
				    handshake->mech_count);
	for (i = 0; i < handshake->mech_count; i++) {
		memcpy(&handshake_mech_desc,
		       data + sizeof(handshake_mech_desc) * i,
		       sizeof(handshake_mech_desc));

		if (handshake_mech_desc.name_idx >= handshake->data_size) {
			i_error("BUG: Auth server sent corrupted handshake");
			auth_server_connection_destroy(conn, FALSE);
			return;
		}

		mech_desc.name = p_strdup(conn->pool, (const char *)data +
					  handshake_mech_desc.name_idx);
		mech_desc.plaintext = handshake_mech_desc.plaintext;
		mech_desc.advertise = handshake_mech_desc.advertise;
		buffer_append(buf, &mech_desc, sizeof(mech_desc));

		if (strcmp(mech_desc.name, "PLAIN") == 0)
			conn->has_plain_mech = TRUE;
	}

	conn->server_pid = handshake->server_pid;
	conn->connect_uid = handshake->connect_uid;
	conn->available_auth_mechs_count =
		buffer_get_used_size(buf) / sizeof(mech_desc);
	conn->available_auth_mechs = buffer_free_without_data(buf);
	conn->handshake_received = TRUE;

        conn->client->conn_waiting_handshake_count--;
	update_available_auth_mechs(conn);

	if (conn->client->connect_notify_callback != NULL &&
	    auth_client_is_connected(conn->client)) {
		conn->client->connect_notify_callback(conn->client, TRUE,
				conn->client->connect_notify_context);
	}
}

static void auth_client_input(void *context)
{
	struct auth_server_connection *conn = context;
	struct auth_client_handshake_reply handshake;
	const unsigned char *data;
	size_t size;

	switch (i_stream_read(conn->input)) {
	case 0:
		return;
	case -1:
		/* disconnected */
		auth_server_connection_destroy(conn, TRUE);
		return;
	case -2:
		/* buffer full - can't happen unless auth is buggy */
		i_error("BUG: Auth server sent us more than %d bytes of data",
			MAX_INBUF_SIZE);
		auth_server_connection_destroy(conn, FALSE);
		return;
	}

	if (!conn->handshake_received) {
		data = i_stream_get_data(conn->input, &size);
		if (size < sizeof(handshake))
			return;

		memcpy(&handshake, data, sizeof(handshake));
		if (size < sizeof(handshake) + handshake.data_size)
			return;

		conn->refcount++;
		auth_handle_handshake(conn, &handshake,
				      data + sizeof(handshake));
		i_stream_skip(conn->input, sizeof(handshake) +
			      handshake.data_size);
		auth_server_connection_unref(conn);
		return;
	}

	if (!conn->reply_received) {
		data = i_stream_get_data(conn->input, &size);
		if (size < sizeof(conn->reply))
			return;

		memcpy(&conn->reply, data, sizeof(conn->reply));
		i_stream_skip(conn->input, sizeof(conn->reply));
		conn->reply_received = TRUE;
	}

	data = i_stream_get_data(conn->input, &size);
	if (size < conn->reply.data_size)
		return;

	/* we've got a full reply */
	conn->refcount++;
	conn->reply_received = FALSE;

	auth_server_request_handle_reply(conn, &conn->reply, data);
	i_stream_skip(conn->input, conn->reply.data_size);

	auth_server_connection_unref(conn);
}

struct auth_server_connection *
auth_server_connection_new(struct auth_client *client, const char *path)
{
	struct auth_server_connection *conn;
	struct auth_client_handshake_request handshake;
	pool_t pool;
	int fd;

	fd = net_connect_unix(path);
	if (fd == -1) {
		i_error("Can't connect to auth server at %s: %m", path);
		return NULL;
	}

	/* use blocking connection since we depend on auth server -
	   if it's slow, just wait */

	pool = pool_alloconly_create("Auth connection", 1024);
	conn = p_new(pool, struct auth_server_connection, 1);
	conn->refcount = 1;
	conn->pool = pool;

	conn->client = client;
	conn->path = p_strdup(pool, path);
	conn->fd = fd;
	if (client->ext_input_add == NULL)
		conn->io = io_add(fd, IO_READ, auth_client_input, conn);
	else {
		conn->ext_input_io =
			client->ext_input_add(fd, auth_client_input, conn);
	}
	conn->input = i_stream_create_file(fd, default_pool, MAX_INBUF_SIZE,
					   FALSE);
	conn->output = o_stream_create_file(fd, default_pool, (size_t)-1,
					    FALSE);
	conn->requests = hash_create(default_pool, pool, 100, NULL, NULL);

	conn->next = client->connections;
	client->connections = conn;

	/* send our handshake */
	memset(&handshake, 0, sizeof(handshake));
	handshake.client_pid = client->pid;

        client->conn_waiting_handshake_count++;
	if (o_stream_send(conn->output, &handshake, sizeof(handshake)) < 0) {
		errno = conn->output->stream_errno;
		i_warning("Error sending handshake to auth server: %m");
		auth_server_connection_destroy(conn, TRUE);
		return NULL;
	}
	return conn;
}

void auth_server_connection_destroy(struct auth_server_connection *conn,
				    int reconnect)
{
	struct auth_client *client = conn->client;
	struct auth_server_connection **pos;

	if (conn->fd == -1)
		return;

        pos = &conn->client->connections;
	for (; *pos != NULL; pos = &(*pos)->next) {
		if (*pos == conn) {
			*pos = conn->next;
			break;
		}
	}

	if (!conn->handshake_received)
		client->conn_waiting_handshake_count--;

	if (conn->ext_input_io != NULL) {
		client->ext_input_remove(conn->ext_input_io);
		conn->ext_input_io = NULL;
	}
	if (conn->io != NULL) {
		io_remove(conn->io);
		conn->io = NULL;
	}

	i_stream_close(conn->input);
	o_stream_close(conn->output);

	if (close(conn->fd) < 0)
		i_error("close(auth) failed: %m");
	conn->fd = -1;

	auth_server_requests_remove_all(conn);
        auth_server_connection_unref(conn);

	if (reconnect)
		auth_client_connect_missing_servers(client);
	else if (client->connect_notify_callback != NULL) {
		client->connect_notify_callback(client,
				auth_client_is_connected(client),
				client->connect_notify_context);
	}
}

static void auth_server_connection_unref(struct auth_server_connection *conn)
{
	if (--conn->refcount > 0)
		return;
	i_assert(conn->refcount == 0);

	hash_destroy(conn->requests);

	i_stream_unref(conn->input);
	o_stream_unref(conn->output);
	pool_unref(conn->pool);
}

struct auth_server_connection *
auth_server_connection_find_path(struct auth_client *client, const char *path)
{
	struct auth_server_connection *conn;

	for (conn = client->connections; conn != NULL; conn = conn->next) {
		if (strcmp(conn->path, path) == 0)
			return conn;
	}

	return NULL;
}

struct auth_server_connection *
auth_server_connection_find_mech(struct auth_client *client,
				 const char *name, const char **error_r)
{
	struct auth_server_connection *conn;
	const struct auth_mech_desc *mech;
	unsigned int i;

	for (conn = client->connections; conn != NULL; conn = conn->next) {
		mech = conn->available_auth_mechs;
		for (i = 0; i < conn->available_auth_mechs_count; i++) {
			if (strcasecmp(mech[i].name, name) == 0)
				return conn;
		}
	}

	if (auth_client_find_mech(client, name) == NULL)
		*error_r = "Unsupported authentication mechanism";
	else {
		*error_r = "Authentication server isn't connected, "
			"try again later..";
	}

	return NULL;
}