view src/login/ssl-proxy.c @ 0:3b1985cbc908 HEAD

Initial revision
author Timo Sirainen <tss@iki.fi>
date Fri, 09 Aug 2002 12:15:38 +0300
parents
children 82b7de533f98
line wrap: on
line source

/* Copyright (C) 2002 Timo Sirainen */

#include "common.h"
#include "ioloop.h"
#include "network.h"
#include "ssl-proxy.h"

#ifdef HAVE_SSL

#include <stdlib.h>
#include <gnutls/gnutls.h>

/* Per-connection state for one SSL client <-> plaintext socketpair proxy. */
typedef struct {
	/* GnuTLS session bound to fd_ssl */
	GNUTLS_STATE state;

	/* socket to the SSL client */
	int fd_ssl;
	/* our end of the socketpair; ssl_proxy_new() returns the other end */
	int fd_plain;

	/* watcher used by ssl_input()/ssl_output() */
	IO io_ssl;
	/* watcher used by plain_input()/plain_output() */
	IO io_plain;

	/* data decrypted from the SSL side, waiting to be written to fd_plain */
	unsigned char outbuf_plain[1024];
	/* index of the first byte in outbuf_plain not yet written */
	unsigned int outbuf_pos_plain;

	/* bytes still waiting to be sent to the SSL client */
	unsigned int send_left_ssl;
	/* bytes of outbuf_plain still waiting to be written to fd_plain */
	unsigned int send_left_plain;
} SSLProxy;

/* Size of the Diffie-Hellman prime generated at startup and requested
   from clients. */
#define DH_BITS 1024

/* Algorithm preference lists applied to every new session by
   initialize_state(). They are private to this file, so keep them out of
   the global namespace. Each list is 0-terminated. */
static const int protocol_priority[] =
	{ GNUTLS_TLS1, GNUTLS_SSL3, 0 };
static const int kx_priority[] =
	{ GNUTLS_KX_RSA, GNUTLS_KX_DHE_RSA, 0 };
static const int cipher_priority[] =
	{ GNUTLS_CIPHER_RIJNDAEL_CBC, GNUTLS_CIPHER_3DES_CBC, 0 };
/* TLS-level compression is deliberately not offered. Compressing secrets
   together with attacker-influenced data leaks them through the record
   sizes (CRIME attack, see RFC 7525 section 3.3). NULL compression is
   mandatory for every TLS peer, so interoperability is unaffected. */
static const int comp_priority[] =
	{ GNUTLS_COMP_NULL, 0 };
static const int mac_priority[] =
	{ GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };

/* Server certificate credentials and DH parameters shared by all sessions;
   set up in ssl_proxy_init(). */
static GNUTLS_CERTIFICATE_SERVER_CREDENTIALS x509_cred;
static GNUTLS_DH_PARAMS dh_params;

static void ssl_input(void *user_data, int fd, IO io);
static void plain_input(void *user_data, int fd, IO io);
static void ssl_proxy_destroy(SSLProxy *proxy);

/* Read and decrypt up to size bytes from the SSL client into data.

   Returns:
   - the number of bytes read (> 0) on success;
   - 0 if GnuTLS reported a non-fatal error (nothing usable right now);
   - -1 if the client disconnected or a fatal error occurred. In that case
     the proxy has already been destroyed and the caller must not touch it. */
static int proxy_recv_ssl(SSLProxy *proxy, void *data, unsigned int size)
{
	int ret = gnutls_record_recv(proxy->state, data, size);

	if (ret < 0) {
		if (!gnutls_error_is_fatal(ret))
			return 0;

		/* fatal error occured */
		i_warning("Error reading from SSL client: %s",
			  gnutls_strerror(ret));
		ssl_proxy_destroy(proxy);
		return -1;
	}

	if (ret == 0) {
		/* disconnected */
		ssl_proxy_destroy(proxy);
		return -1;
	}

	return ret;
}

/* Encrypt and send up to size bytes from data to the SSL client.

   Returns the byte count gnutls_record_send() reported (>= 0). Returns 0
   for a non-fatal GnuTLS error. Returns -1 after a fatal error, in which
   case the proxy has already been destroyed. */
static int proxy_send_ssl(SSLProxy *proxy, const void *data, unsigned int size)
{
	int ret = gnutls_record_send(proxy->state, data, size);

	if (ret < 0 && gnutls_error_is_fatal(ret)) {
		/* error occured */
		i_warning("Error sending to SSL client: %s",
			  gnutls_strerror(ret));
		ssl_proxy_destroy(proxy);
		return -1;
	}

	return ret < 0 ? 0 : ret;
}

/* Tear down a proxy. Steps, in order: unregister both I/O watchers, free
   the GnuTLS session, close both sockets and free the structure. proxy is
   invalid after this returns. */
static void ssl_proxy_destroy(SSLProxy *proxy)
{
	/* Unregister the watchers while their fds are still open. Removing a
	   watcher for an already closed fd can fail in some backends; for
	   example, epoll_ctl(EPOLL_CTL_DEL) returns EBADF. */
	io_remove(proxy->io_ssl);
	io_remove(proxy->io_plain);

	gnutls_deinit(proxy->state);

	(void)net_disconnect(proxy->fd_ssl);
	(void)net_disconnect(proxy->fd_plain);

	i_free(proxy);
}

/* I/O handler that flushes the unsent part of outbuf_plain to fd_plain.
   Once everything has been written, SSL-side reading is resumed by
   registering ssl_input() for reading on fd_ssl. If the write fails, the
   proxy is destroyed. */
static void ssl_output(void *user_data, int fd __attr_unused__,
		       IO io __attr_unused__)
{
	SSLProxy *proxy = user_data;
	unsigned char *pending;
	int written;

	pending = proxy->outbuf_plain + proxy->outbuf_pos_plain;
	written = net_transmit(proxy->fd_plain, pending,
			       proxy->send_left_plain);
	if (written < 0) {
		/* disconnected */
		ssl_proxy_destroy(proxy);
		return;
	}

	proxy->outbuf_pos_plain += written;
	proxy->send_left_plain -= written;

	if (proxy->send_left_plain == 0) {
		/* everything is sent, start reading again */
		io_remove(proxy->io_ssl);
		proxy->io_ssl = io_add(proxy->fd_ssl, IO_READ,
				       ssl_input, proxy);
	}
}

/* Read handler for fd_ssl. Decrypts whatever the SSL client sent and
   forwards it to fd_plain. If fd_plain doesn't accept all of it, the rest
   stays in outbuf_plain. Reading from the SSL side is then suspended until
   ssl_output() has flushed that data. */
static void ssl_input(void *user_data, int fd __attr_unused__,
		      IO io __attr_unused__)
{
	SSLProxy *proxy = user_data;
	int rcvd, sent;

	rcvd = proxy_recv_ssl(proxy, proxy->outbuf_plain,
			      sizeof(proxy->outbuf_plain));
	if (rcvd <= 0) {
		/* nothing available yet, or the proxy was already destroyed */
		return;
	}

	sent = net_transmit(proxy->fd_plain, proxy->outbuf_plain,
			    (unsigned int) rcvd);
	if (sent == rcvd)
		return;

	if (sent < 0) {
		/* disconnected */
		ssl_proxy_destroy(proxy);
		return;
	}

	/* Everything wasn't sent, so don't read anything more until we've
	   sent it all. The first `sent' bytes are already written, so
	   ssl_output() must continue from there, not from the start of the
	   buffer. */
	proxy->outbuf_pos_plain = (unsigned int) sent;
	proxy->send_left_plain = rcvd - sent;

	/* ssl_output() writes to fd_plain, so wait for fd_plain (not fd_ssl)
	   to become writable. */
	io_remove(proxy->io_ssl);
	proxy->io_ssl = io_add(proxy->fd_plain, IO_WRITE, ssl_output, proxy);
}

/* Retries a previously incomplete send to the SSL client. When no pending
   bytes remain, reading from fd_plain is resumed via plain_input(). If the
   send fails fatally, proxy_send_ssl() has already destroyed the proxy. */
static void plain_output(void *user_data, int fd __attr_unused__,
			 IO io __attr_unused__)
{
	SSLProxy *proxy = user_data;
	int sent;

	/* FIXME: (void*) 1 is horrible kludge, but there's no need for us
	   to store the data as gnutls does it already, maybe it needes an
	   api change or some clarification how to do it better.. */
	sent = proxy_send_ssl(proxy, (void *) 1, proxy->send_left_ssl);
	if (sent > 0) {
		proxy->send_left_ssl -= sent;
		if (proxy->send_left_ssl == 0) {
			/* everything is sent, start reading again */
			io_remove(proxy->io_plain);
			proxy->io_plain = io_add(proxy->fd_plain, IO_READ,
						 plain_input, proxy);
		}
	}
}

/* Read handler for fd_plain. Reads plaintext from our end of the socketpair
   and sends it encrypted to the SSL client. If GnuTLS couldn't send all of
   it, reading from fd_plain is suspended. plain_output() is then registered
   to retry once fd_ssl becomes writable. */
static void plain_input(void *user_data, int fd __attr_unused__,
			IO io __attr_unused__)
{
	SSLProxy *proxy = user_data;
	char buf[1024];
	int rcvd, sent;

	rcvd = net_receive(proxy->fd_plain, buf, sizeof(buf));
	if (rcvd < 0) {
		/* Disconnected. Send close_notify to the client without
		   waiting for its reply (1 == GNUTLS_SHUT_WR), then tear
		   everything down. */
		gnutls_bye(proxy->state, 1);
		ssl_proxy_destroy(proxy);
		return;
	}
	if (rcvd == 0) {
		/* nothing was read; don't call gnutls_record_send() with an
		   empty buffer */
		return;
	}

	sent = proxy_send_ssl(proxy, buf, (unsigned int) rcvd);
	if (sent < 0 || sent == rcvd)
		return;

	/* everything wasn't sent - don't read anything until we've
	   sent it all */
	proxy->send_left_ssl = rcvd - sent;

	/* plain_output() sends to the SSL client, so wait for fd_ssl to
	   become writable */
	io_remove(proxy->io_plain);
	proxy->io_plain = io_add(proxy->fd_ssl, IO_WRITE, plain_output, proxy);
}

/* Create a new server-side GnuTLS session. It is configured with this
   file's algorithm priority lists, the shared X.509 credentials and a DH
   prime size of DH_BITS. A failure of gnutls_init() is fatal, because the
   returned state would otherwise be uninitialized. */
static GNUTLS_STATE initialize_state(void)
{
	GNUTLS_STATE state;
	int ret;

	ret = gnutls_init(&state, GNUTLS_SERVER);
	if (ret < 0)
		i_fatal("gnutls_init() failed: %s", gnutls_strerror(ret));

	gnutls_protocol_set_priority(state, protocol_priority);
	gnutls_cipher_set_priority(state, cipher_priority);
	gnutls_compression_set_priority(state, comp_priority);
	gnutls_kx_set_priority(state, kx_priority);
	gnutls_mac_set_priority(state, mac_priority);

	gnutls_cred_set(state, GNUTLS_CRD_CERTIFICATE, x509_cred);

	/*gnutls_certificate_server_set_request(state, GNUTLS_CERT_REQUEST);*/

	gnutls_dh_set_prime_bits(state, DH_BITS);
	return state;
}

/* Perform the SSL handshake on fd and start proxying it through a
   socketpair. Returns the caller's end of the socketpair, over which
   plaintext can be exchanged. Returns -1 if the handshake or socketpair()
   failed; in that case fd is not closed here. */
int ssl_proxy_new(int fd)
{
	SSLProxy *proxy;
	GNUTLS_STATE state;
	int ret, sfd[2];

	state = initialize_state();
	gnutls_transport_set_ptr(state, fd);

	net_set_nonblock(fd, FALSE); /* FIXME: blocks! */
	ret = gnutls_handshake(state);
	/* restore nonblocking mode on both the success and failure paths */
	net_set_nonblock(fd, TRUE);
	if (ret < 0) {
		gnutls_deinit(state);
		return -1;
	}

	if (socketpair(AF_UNIX, SOCK_STREAM, 0, sfd) == -1) {
		i_error("socketpair() failed: %m");
		gnutls_deinit(state);
		return -1;
	}

	/* The proxy is driven by the I/O loop and handles short writes itself
	   (send_left_plain), so its end of the pair must never block. */
	net_set_nonblock(sfd[0], TRUE);

	proxy = i_new(SSLProxy, 1);
	proxy->state = state;
	proxy->fd_ssl = fd;
	proxy->fd_plain = sfd[0];

	proxy->io_ssl = io_add(proxy->fd_ssl, IO_READ, ssl_input, proxy);
	proxy->io_plain = io_add(proxy->fd_plain, IO_READ, plain_input, proxy);

	return sfd[1];
}

static void generate_dh_primes(void)
{
	gnutls_datum prime, generator;
	int ret;

	/* Generate Diffie Hellman parameters - for use with DHE
	   kx algorithms. These should be discarded and regenerated
	   once a day, once a week or once a month. Depends on the
	   security requirements. */
	if ((ret = gnutls_dh_params_init(&dh_params)) < 0) {
		i_fatal("gnutls_dh_params_init() failed: %s",
			gnutls_strerror(ret));
	}

	ret = gnutls_dh_params_generate(&prime, &generator, DH_BITS);
	if (ret < 0) {
		i_fatal("gnutls_dh_params_generate() failed: %s",
			gnutls_strerror(ret));
	}

	ret = gnutls_dh_params_set(dh_params, prime, generator, DH_BITS);
	if (ret < 0) {
		i_fatal("gnutls_dh_params_set() failed: %s",
			gnutls_strerror(ret));
	}

	free(prime.data);
	free(generator.data);
}

/* Global SSL setup:
   - initialize GnuTLS;
   - load the server certificate and key named by the SSL_CERT_FILE and
     SSL_KEY_FILE environment variables;
   - generate the DH parameters used for DHE key exchange.
   Every failure is fatal. */
void ssl_proxy_init(void)
{
	const char *certfile, *keyfile;
	int ret;

	certfile = getenv("SSL_CERT_FILE");
	keyfile = getenv("SSL_KEY_FILE");

	if (certfile == NULL)
		i_fatal("SSL_CERT_FILE environment not set");
	if (keyfile == NULL)
		i_fatal("SSL_KEY_FILE environment not set");

	/* Parenthesize the assignment. ret must hold the GnuTLS error code,
	   not the 0/1 result of the comparison, or gnutls_strerror() would
	   describe the wrong error. */
	if ((ret = gnutls_global_init()) < 0) {
		i_fatal("gnutls_global_init() failed: %s",
			gnutls_strerror(ret));
	}

	if ((ret = gnutls_certificate_allocate_cred(&x509_cred)) < 0) {
		i_fatal("gnutls_certificate_allocate_cred() failed: %s",
			gnutls_strerror(ret));
	}

	ret = gnutls_certificate_set_x509_key_file(x509_cred, certfile, keyfile,
						   GNUTLS_X509_FMT_PEM);
	if (ret < 0) {
		i_fatal("Can't load certificate files %s and %s: %s",
			certfile, keyfile, gnutls_strerror(ret));
	}

	generate_dh_primes();
	gnutls_certificate_set_dh_params(x509_cred, dh_params);
}

void ssl_proxy_deinit(void)
{
	gnutls_certificate_free_cred(x509_cred);
	gnutls_global_deinit();
}

#else

/* no SSL support */

/* Without SSL support no proxy can ever be created. */
int ssl_proxy_new(int fd)
{
	(void)fd;
	return -1;
}

/* Nothing to initialize without SSL support. */
void ssl_proxy_init(void)
{
}

/* Nothing to release without SSL support. */
void ssl_proxy_deinit(void)
{
}

#endif