diff src/login-common/ssl-proxy-openssl.c @ 4538:9d9e72374164 HEAD

Fixes to login process handling, especially with login_process_per_connection=no. Removed login_max_logging_users setting since it was somewhat weird in how it worked. Added login_max_connections to replace it with login_process_per_connection=no, and with =yes its functionality is now within login_max_processes_count.
author Timo Sirainen <tss@iki.fi>
date Sun, 06 Aug 2006 23:05:32 +0300
parents 025ffc5a3643
children 93e62d5d12e4
line wrap: on
line diff
--- a/src/login-common/ssl-proxy-openssl.c	Fri Aug 04 20:46:11 2006 +0300
+++ b/src/login-common/ssl-proxy-openssl.c	Sun Aug 06 23:05:32 2006 +0300
@@ -26,11 +26,11 @@
 /* Check every 30 minutes if parameters file has been updated */
 #define SSL_PARAMFILE_CHECK_INTERVAL (60*30)
 
-#define PLAIN_OUTPUT_OPTIMAL_SIZE 2048
-
-enum ssl_want {
-	WANT_INPUT,
-	WANT_OUTPUT
+enum ssl_io_action {
+	SSL_ADD_INPUT,
+	SSL_REMOVE_INPUT,
+	SSL_ADD_OUTPUT,
+	SSL_REMOVE_OUTPUT
 };
 
 struct ssl_proxy {
@@ -40,13 +40,10 @@
 	struct ip_addr ip;
 
 	int fd_ssl, fd_plain;
-	struct io *io_ssl, *io_plain_input;
+	struct io *io_ssl_read, *io_ssl_write, *io_plain_read, *io_plain_write;
 
-	enum ssl_want want;
-	void (*step)(struct ssl_proxy *);
-	unsigned int ssl_want_size;
-
-	struct ostream *plain_output;
+	unsigned char plainout_buf[1024];
+	unsigned int plainout_size;
 
 	unsigned char sslout_buf[1024];
 	unsigned int sslout_size;
@@ -70,8 +67,10 @@
 static struct hash_table *ssl_proxies;
 static struct ssl_parameters ssl_params;
 
-static void ssl_input(struct ssl_proxy *proxy);
-static void ssl_output(struct ssl_proxy *proxy);
+static void plain_read(void *context);
+static void plain_write(void *context);
+static void ssl_read(struct ssl_proxy *proxy);
+static void ssl_write(struct ssl_proxy *proxy);
 static void ssl_step(void *context);
 static void ssl_proxy_destroy(struct ssl_proxy *proxy);
 static void ssl_proxy_unref(struct ssl_proxy *proxy);
@@ -188,6 +187,120 @@
 	}
 }
 
+static void ssl_set_io(struct ssl_proxy *proxy, enum ssl_io_action action)
+{
+	switch (action) {
+	case SSL_ADD_INPUT:
+		if (proxy->io_ssl_read != NULL)
+			break;
+		proxy->io_ssl_read = io_add(proxy->fd_ssl, IO_READ,
+					    ssl_step, proxy);
+		break;
+	case SSL_REMOVE_INPUT:
+		if (proxy->io_ssl_read != NULL)
+			io_remove(&proxy->io_ssl_read);
+		break;
+	case SSL_ADD_OUTPUT:
+		if (proxy->io_ssl_write != NULL)
+			break;
+		proxy->io_ssl_write = io_add(proxy->fd_ssl, IO_WRITE,
+					     ssl_step, proxy);
+		break;
+	case SSL_REMOVE_OUTPUT:
+		if (proxy->io_ssl_write != NULL)
+			io_remove(&proxy->io_ssl_write);
+		break;
+	}
+}
+
+static void plain_block_input(struct ssl_proxy *proxy, bool block)
+{
+	if (block) {
+		if (proxy->io_plain_read != NULL)
+			io_remove(&proxy->io_plain_read);
+	} else {
+		if (proxy->io_plain_read == NULL) {
+			proxy->io_plain_read = io_add(proxy->fd_plain, IO_READ,
+						      plain_read, proxy);
+		}
+	}
+}
+
+static void plain_read(void *context)
+{
+	struct ssl_proxy *proxy = context;
+	ssize_t ret;
+	bool corked = FALSE;
+
+	if (proxy->sslout_size == sizeof(proxy->sslout_buf)) {
+		/* buffer full, block input until it's written */
+		plain_block_input(proxy, TRUE);
+		return;
+	}
+
+	proxy->refcount++;
+
+	while (proxy->sslout_size < sizeof(proxy->sslout_buf) &&
+	       !proxy->destroyed) {
+		ret = net_receive(proxy->fd_plain,
+				  proxy->sslout_buf + proxy->sslout_size,
+				  sizeof(proxy->sslout_buf) -
+				  proxy->sslout_size);
+		if (ret <= 0) {
+			if (ret < 0)
+				ssl_proxy_destroy(proxy);
+			break;
+		} else {
+			proxy->sslout_size += ret;
+			if (!corked) {
+				net_set_cork(proxy->fd_ssl, TRUE);
+				corked = TRUE;
+			}
+			ssl_write(proxy);
+		}
+	}
+
+	if (corked)
+		net_set_cork(proxy->fd_ssl, FALSE);
+
+	ssl_proxy_unref(proxy);
+}
+
+static void plain_write(void *context)
+{
+	struct ssl_proxy *proxy = context;
+	ssize_t ret;
+
+	proxy->refcount++;
+
+	ret = net_transmit(proxy->fd_plain, proxy->plainout_buf,
+			   proxy->plainout_size);
+	if (ret < 0)
+		ssl_proxy_destroy(proxy);
+	else {
+		proxy->plainout_size -= ret;
+		memmove(proxy->plainout_buf, proxy->plainout_buf + ret,
+			proxy->plainout_size);
+
+		if (proxy->plainout_size > 0) {
+			if (proxy->io_plain_write == NULL) {
+				proxy->io_plain_write =
+					io_add(proxy->fd_plain, IO_WRITE,
+					       plain_write, proxy);
+			}
+		} else {
+			if (proxy->io_plain_write != NULL)
+				io_remove(&proxy->io_plain_write);
+		}
+
+		ssl_set_io(proxy, SSL_ADD_INPUT);
+		if (SSL_pending(proxy->ssl) > 0)
+			ssl_read(proxy);
+	}
+
+	ssl_proxy_unref(proxy);
+}
+
 static const char *ssl_last_error(void)
 {
 	unsigned long err;
@@ -204,30 +317,8 @@
 	return buf;
 }
 
-static void ssl_set_io(struct ssl_proxy *proxy, enum ssl_want want)
-{
-	if (proxy->io_ssl != NULL) {
-		if (want == proxy->want)
-			return;
-		io_remove(&proxy->io_ssl);
-	}
-
-	proxy->want = want;
-	switch (want) {
-	case WANT_INPUT:
-		proxy->io_ssl =
-			io_add(proxy->fd_ssl, IO_READ, ssl_step, proxy);
-		break;
-	case WANT_OUTPUT:
-		proxy->io_ssl =
-			io_add(proxy->fd_ssl, IO_WRITE, ssl_step, proxy);
-		break;
-	}
-}
-
-static void
-ssl_handle_error(struct ssl_proxy *proxy, int ret, const char *func_name,
-		 void (*func)(struct ssl_proxy *), unsigned int want_size)
+static void ssl_handle_error(struct ssl_proxy *proxy, int ret,
+			     const char *func_name)
 {
 	const char *errstr;
 	int err;
@@ -236,14 +327,10 @@
 
 	switch (err) {
 	case SSL_ERROR_WANT_READ:
-		proxy->step = func;
-		proxy->ssl_want_size = want_size;
-		ssl_set_io(proxy, WANT_INPUT);
+		ssl_set_io(proxy, SSL_ADD_INPUT);
 		break;
 	case SSL_ERROR_WANT_WRITE:
-		proxy->step = func;
-		proxy->ssl_want_size = want_size;
-		ssl_set_io(proxy, WANT_OUTPUT);
+		ssl_set_io(proxy, SSL_ADD_OUTPUT);
 		break;
 	case SSL_ERROR_SYSCALL:
 		/* eat up the error queue */
@@ -282,150 +369,56 @@
 	}
 }
 
-static void plain_input(void *context)
+static void ssl_handshake(struct ssl_proxy *proxy)
 {
-	struct ssl_proxy *proxy = context;
-	ssize_t ret;
+	int ret;
 
-	if (proxy->sslout_size == sizeof(proxy->sslout_buf)) {
-		/* buffer full, block input until it's written */
-		io_remove(&proxy->io_plain_input);
-		return;
+	ret = SSL_accept(proxy->ssl);
+	if (ret != 1)
+		ssl_handle_error(proxy, ret, "SSL_accept()");
+	else {
+		proxy->handshaked = TRUE;
+
+		ssl_set_io(proxy, SSL_ADD_INPUT);
+		plain_block_input(proxy, FALSE);
 	}
+}
 
-	ret = net_receive(proxy->fd_plain,
-			  proxy->sslout_buf + proxy->sslout_size,
-			  sizeof(proxy->sslout_buf) - proxy->sslout_size);
-	if (ret <= 0) {
-		if (ret < 0)
-			ssl_proxy_destroy(proxy);
-	} else {
-		proxy->sslout_size += ret;
-		if (SSL_want(proxy->ssl) == SSL_NOTHING) {
-			i_assert(proxy->ssl_want_size == 0);
-			ssl_output(proxy);
+static void ssl_read(struct ssl_proxy *proxy)
+{
+	int ret;
+
+	while (proxy->plainout_size < sizeof(proxy->plainout_buf) &&
+	       !proxy->destroyed) {
+		ret = SSL_read(proxy->ssl,
+			       proxy->plainout_buf + proxy->plainout_size,
+			       sizeof(proxy->plainout_buf) -
+			       proxy->plainout_size);
+		if (ret <= 0) {
+			ssl_handle_error(proxy, ret, "SSL_read()");
+			break;
+		} else {
+			proxy->plainout_size += ret;
+			plain_write(proxy);
 		}
 	}
 }
 
-static int plain_output(void *context)
+static void ssl_write(struct ssl_proxy *proxy)
 {
-	struct ssl_proxy *proxy = context;
 	int ret;
 
-	if (proxy->ssl_want_size != 0)
-		return 0;
-
-	if ((ret = o_stream_flush(proxy->plain_output)) < 0) {
-		ssl_proxy_destroy(proxy);
-		return 1;
-	}
-
-	if (o_stream_get_buffer_used_size(proxy->plain_output) <
-	    PLAIN_OUTPUT_OPTIMAL_SIZE &&
-	    proxy->want == WANT_INPUT && proxy->io_ssl == NULL)
-		ssl_set_io(proxy, WANT_INPUT);
-
-	return ret;
-}
-
-static void ssl_handshake(struct ssl_proxy *proxy)
-{
-	int ret, old_errno;
-
-	net_set_cork(proxy->fd_ssl, TRUE);
-	ret = SSL_accept(proxy->ssl);
-
-	old_errno = errno;
-	net_set_cork(proxy->fd_ssl, FALSE);
-	errno = old_errno;
-
-	if (ret != 1)
-		ssl_handle_error(proxy, ret, "SSL_accept()", ssl_handshake, 0);
+	ret = SSL_write(proxy->ssl, proxy->sslout_buf, proxy->sslout_size);
+	if (ret <= 0)
+		ssl_handle_error(proxy, ret, "SSL_write()");
 	else {
-		proxy->handshaked = TRUE;
-		proxy->step = ssl_input;
-		ssl_set_io(proxy, WANT_INPUT);
-
-		proxy->io_plain_input = io_add(proxy->fd_plain, IO_READ,
-					       plain_input, proxy);
-	}
-}
-
-static void ssl_input(struct ssl_proxy *proxy)
-{
-	unsigned char buf[PLAIN_OUTPUT_OPTIMAL_SIZE];
-	size_t size, used;
-	ssize_t ret, ret2;
-
-	used = o_stream_get_buffer_used_size(proxy->plain_output);
-	if (used >= PLAIN_OUTPUT_OPTIMAL_SIZE) {
-		io_remove(&proxy->io_ssl);
-		return;
-	}
+		proxy->sslout_size -= ret;
+		memmove(proxy->sslout_buf, proxy->sslout_buf + ret,
+			proxy->sslout_size);
 
-	size = sizeof(buf) - used;
-	if (proxy->ssl_want_size != 0) {
-		i_assert(proxy->ssl_want_size <= size);
-		size = proxy->ssl_want_size;
-		proxy->ssl_want_size = 0;
-	}
-
-	do {
-		ret = SSL_read(proxy->ssl, buf, size);
-		if (ret <= 0) {
-			ssl_handle_error(proxy, ret, "SSL_read()",
-					 ssl_input, size);
-			return;
-		}
-		o_stream_cork(proxy->plain_output);
-		ret2 = o_stream_send(proxy->plain_output, buf, ret);
-		i_assert(ret2 < 0 || ret2 == ret);
-		o_stream_uncork(proxy->plain_output);
-
-		if (proxy->sslout_size > 0)
-			ssl_output(proxy);
-	} while (SSL_pending(proxy->ssl) > 0);
-}
-
-static void ssl_output(struct ssl_proxy *proxy)
-{
-	unsigned int size;
-	int ret, old_errno;
-
-	if (proxy->ssl_want_size == 0)
-		size = proxy->sslout_size;
-	else {
-		i_assert(proxy->ssl_want_size <= proxy->sslout_size);
-		size = proxy->ssl_want_size;
-		proxy->ssl_want_size = 0;
-	}
-
-	net_set_cork(proxy->fd_ssl, TRUE);
-	ret = SSL_write(proxy->ssl, proxy->sslout_buf, size);
-
-	old_errno = errno;
-	net_set_cork(proxy->fd_ssl, FALSE);
-	errno = old_errno;
-
-	if (ret <= 0) {
-		ssl_handle_error(proxy, ret, "SSL_write()", ssl_output, size);
-		return;
-	}
-
-	proxy->sslout_size -= ret;
-	memmove(proxy->sslout_buf, proxy->sslout_buf + ret, proxy->sslout_size);
-
-	if (proxy->sslout_size > 0) {
-		ssl_set_io(proxy, WANT_OUTPUT);
-		proxy->step = ssl_output;
-	} else {
-		ssl_set_io(proxy, WANT_INPUT);
-		proxy->step = ssl_input;
-	}
-	if (proxy->io_plain_input == NULL) {
-		proxy->io_plain_input = io_add(proxy->fd_plain, IO_READ,
-					       plain_input, proxy);
+		ssl_set_io(proxy, proxy->sslout_size > 0 ?
+			   SSL_ADD_OUTPUT : SSL_REMOVE_OUTPUT);
+		plain_block_input(proxy, FALSE);
 	}
 }
 
@@ -433,7 +426,27 @@
 {
 	struct ssl_proxy *proxy = context;
 
-	proxy->step(proxy);
+	proxy->refcount++;
+
+	if (!proxy->handshaked)
+		ssl_handshake(proxy);
+
+	if (proxy->handshaked) {
+		if (proxy->plainout_size == sizeof(proxy->plainout_buf))
+			ssl_set_io(proxy, SSL_REMOVE_INPUT);
+		else
+			ssl_read(proxy);
+
+		if (proxy->sslout_size == 0)
+			ssl_set_io(proxy, SSL_REMOVE_OUTPUT);
+		else {
+			net_set_cork(proxy->fd_ssl, TRUE);
+			ssl_write(proxy);
+			net_set_cork(proxy->fd_ssl, FALSE);
+		}
+	}
+
+	ssl_proxy_unref(proxy);
 }
 
 int ssl_proxy_new(int fd, struct ip_addr *ip, struct ssl_proxy **proxy_r)
@@ -479,16 +492,10 @@
 	proxy->fd_ssl = fd;
 	proxy->fd_plain = sfd[0];
 	proxy->ip = *ip;
-	proxy->plain_output =
-		o_stream_create_file(proxy->fd_plain, default_pool,
-				     (size_t)-1, FALSE);
-	o_stream_set_flush_callback(proxy->plain_output, plain_output, proxy);
-
         SSL_set_ex_data(ssl, extdata_index, proxy);
 
 	hash_insert(ssl_proxies, proxy, proxy);
 
-	proxy->step = ssl_handshake;
 	ssl_handshake(proxy);
 	main_ref();
 
@@ -549,16 +556,21 @@
 
 	hash_remove(ssl_proxies, proxy);
 
-	if (proxy->io_ssl != NULL)
-		io_remove(&proxy->io_ssl);
-	if (proxy->io_plain_input != NULL)
-		io_remove(&proxy->io_plain_input);
+	if (proxy->io_ssl_read != NULL)
+		io_remove(&proxy->io_ssl_read);
+	if (proxy->io_ssl_write != NULL)
+		io_remove(&proxy->io_ssl_write);
+	if (proxy->io_plain_read != NULL)
+		io_remove(&proxy->io_plain_read);
+	if (proxy->io_plain_write != NULL)
+		io_remove(&proxy->io_plain_write);
 
-	o_stream_unref(&proxy->plain_output);
 	(void)net_disconnect(proxy->fd_ssl);
 	(void)net_disconnect(proxy->fd_plain);
 
 	ssl_proxy_unref(proxy);
+
+	main_listen_start();
 }
 
 static RSA *ssl_gen_rsa_key(SSL *ssl __attr_unused__,
@@ -641,6 +653,11 @@
 	return strlen(buf);
 }
 
+unsigned int ssl_proxy_get_count(void)
+{
+	return hash_size(ssl_proxies);
+}
+
 void ssl_proxy_init(void)
 {
 	const char *cafile, *certfile, *keyfile, *cipher_list;