changeset 20564:dced7deca4d6

dcrypt: Use refcounting on keys
author Aki Tuomi <aki.tuomi@dovecot.fi>
date Fri, 15 Jul 2016 16:27:03 +0300
parents e6647016bc12
children b550a486b502
files src/lib-dcrypt/dcrypt-openssl.c src/lib-dcrypt/dcrypt-private.h src/lib-dcrypt/dcrypt.c src/lib-dcrypt/dcrypt.h
diffstat 4 files changed, 120 insertions(+), 68 deletions(-) [+]
line wrap: on
line diff
--- a/src/lib-dcrypt/dcrypt-openssl.c	Thu Aug 04 08:55:12 2016 +0300
+++ b/src/lib-dcrypt/dcrypt-openssl.c	Fri Jul 15 16:27:03 2016 +0300
@@ -99,11 +99,13 @@
 };
 
 struct dcrypt_public_key {
-	void *ctx;
+	EVP_PKEY *key;
+	unsigned int ref;
 };
 
 struct dcrypt_private_key {
-	void *ctx;
+	EVP_PKEY *key;
+	unsigned int ref;
 };
 
 static
@@ -117,9 +119,9 @@
 static
 void dcrypt_openssl_private_to_public_key(struct dcrypt_private_key *priv_key, struct dcrypt_public_key **pub_key_r);
 static
-void dcrypt_openssl_free_private_key(struct dcrypt_private_key **key);
+void dcrypt_openssl_unref_private_key(struct dcrypt_private_key **key);
 static
-void dcrypt_openssl_free_public_key(struct dcrypt_public_key **key);
+void dcrypt_openssl_unref_public_key(struct dcrypt_public_key **key);
 static
 bool dcrypt_openssl_rsa_decrypt(struct dcrypt_private_key *key, const unsigned char *data, size_t data_len, buffer_t *result, const char **error_r);
 
@@ -570,7 +572,7 @@
 bool dcrypt_openssl_ecdh_derive_secret_local(struct dcrypt_private_key *local_key, buffer_t *R, buffer_t *S, const char **error_r)
 {
 	i_assert(local_key != NULL);
-	EVP_PKEY *local = (EVP_PKEY*)local_key;
+	EVP_PKEY *local = local_key->key;
 	BN_CTX *bn_ctx = BN_CTX_new();
 	if (bn_ctx == NULL)
 		return dcrypt_openssl_error(error_r);
@@ -645,7 +647,7 @@
 {
 	/* ensure peer_key is EC key */
 	EVP_PKEY *local = NULL;
-	EVP_PKEY *peer = (EVP_PKEY*)peer_key;
+	EVP_PKEY *peer = peer_key->key;
 	if (EVP_PKEY_base_id(peer) != EVP_PKEY_EC) {
 		if (error_r != NULL)
 			*error_r = "Only ECC key can be used";
@@ -725,7 +727,10 @@
 	EVP_PKEY *pkey = NULL;
 	if (kind == DCRYPT_KEY_RSA) {
 		if (dcrypt_openssl_generate_rsa_key(bits, &pkey, error_r)) {
-			pair_r->priv = (struct dcrypt_private_key*)pkey;
+			pair_r->priv = i_new(struct dcrypt_private_key, 1);
+			pair_r->priv->key = pkey;
+			pair_r->priv->ref++;
+			pair_r->pub = NULL;
 			dcrypt_openssl_private_to_public_key(pair_r->priv, &(pair_r->pub));
 			return TRUE;
 		} else return dcrypt_openssl_error(error_r);
@@ -737,7 +742,10 @@
 			return FALSE;
 		}
 		if (dcrypt_openssl_generate_ec_key(nid, &pkey, error_r)) {
-			pair_r->priv = (struct dcrypt_private_key*)pkey;
+			pair_r->priv = i_new(struct dcrypt_private_key, 1);
+			pair_r->priv->key = pkey;
+			pair_r->priv->ref++;
+			pair_r->pub = NULL;
 			dcrypt_openssl_private_to_public_key(pair_r->priv, &(pair_r->pub));
 			return TRUE;
 		} else return dcrypt_openssl_error(error_r);
@@ -938,7 +946,9 @@
 		}
 		EVP_PKEY_set1_EC_KEY(key, eckey);
 		EC_KEY_free(eckey);
-		*key_r = (struct dcrypt_private_key *)key;
+		*key_r = i_new(struct dcrypt_private_key, 1);
+		(*key_r)->key = key;
+		(*key_r)->ref++;
 		return TRUE;
 	}
 
@@ -1050,14 +1060,14 @@
 		/* check that we have correct decryption key */
 		dcrypt_openssl_private_to_public_key(dec_key, &pubkey);
 		if (!dcrypt_openssl_public_key_id(pubkey, "sha256", data, error_r)) {
-			if (pubkey != NULL) dcrypt_openssl_free_public_key(&pubkey);
+			if (pubkey != NULL) dcrypt_openssl_unref_public_key(&pubkey);
 			return FALSE;
 		}
 
-		dcrypt_openssl_free_public_key(&pubkey);
+		dcrypt_openssl_unref_public_key(&pubkey);
 
 		if (strcmp(binary_to_hex(data->data, data->used), input[9]) != 0) {
-			dcrypt_openssl_free_public_key(&pubkey);
+			dcrypt_openssl_unref_public_key(&pubkey);
 			if (error_r != NULL)
 				*error_r = "No private key available";
 			return FALSE;
@@ -1133,7 +1143,9 @@
 		}
 		EVP_PKEY_set1_RSA(pkey, rsa);
 		RSA_free(rsa);
-		*key_r = (struct dcrypt_private_key *)pkey;
+		*key_r = i_new(struct dcrypt_private_key, 1);
+		(*key_r)->key = pkey;
+		(*key_r)->ref++;
 	} else {
 		int ec;
 		BIGNUM *point = BN_new();
@@ -1173,7 +1185,9 @@
 		if (ec == 1 && key != NULL && EC_KEY_check_key(eckey) == 1) {
 			EVP_PKEY_set1_EC_KEY(key, eckey);
 			EC_KEY_free(eckey);
-			*key_r = (struct dcrypt_private_key *)key;
+			*key_r = i_new(struct dcrypt_private_key, 1);
+			(*key_r)->key = key;
+			(*key_r)->ref++;
 		} else {
 			EVP_PKEY_free(key);
 			EC_KEY_free(eckey);
@@ -1185,7 +1199,7 @@
 	dcrypt_openssl_private_key_id(*key_r, "sha256", key_data, NULL);
 
 	if (strcmp(binary_to_hex(key_data->data, key_data->used), input[len-1]) != 0) {
-		dcrypt_openssl_free_private_key(key_r);
+		dcrypt_openssl_unref_private_key(key_r);
 		if (error_r != NULL)
 			*error_r = "Key id mismatch after load";
 		return FALSE;
@@ -1268,14 +1282,17 @@
 		EC_KEY_free(eckey);
 		/* make sure digest matches */
 		buffer_t *dgst = buffer_create_dynamic(pool_datastack_create(), 32);
-		dcrypt_openssl_public_key_id_old((struct dcrypt_public_key *)key, dgst, NULL);
+		struct dcrypt_public_key tmp = { key, 0 };
+		dcrypt_openssl_public_key_id_old(&tmp, dgst, NULL);
 		if (strcmp(binary_to_hex(dgst->data, dgst->used), input[len-1]) != 0) {
 			if (error_r != NULL)
 				*error_r = "Key id mismatch after load";
 			EVP_PKEY_free(key);
 			return -1;
 		}
-		*key_r = (struct dcrypt_public_key *)key;
+		*key_r = i_new(struct dcrypt_public_key, 1);
+		(*key_r)->key = key;
+		(*key_r)->ref++;
 		return 0;
 	}
 
@@ -1309,7 +1326,8 @@
 
 	/* make sure digest matches */
 	buffer_t *dgst = buffer_create_dynamic(pool_datastack_create(), 32);
-	dcrypt_openssl_public_key_id((struct dcrypt_public_key *)pkey, "sha256", dgst, NULL);
+	struct dcrypt_public_key tmpkey = {pkey, 0};
+	dcrypt_openssl_public_key_id(&tmpkey, "sha256", dgst, NULL);
 	if (strcmp(binary_to_hex(dgst->data, dgst->used), input[len-1]) != 0) {
 		if (error_r != NULL)
 			*error_r = "Key id mismatch after load";
@@ -1317,7 +1335,9 @@
 		return -1;
 	}
 
-	*key_r = (struct dcrypt_public_key *)pkey;
+	*key_r = i_new(struct dcrypt_public_key, 1);
+	(*key_r)->key = pkey;
+	(*key_r)->ref++;
 	return 0;
 }
 
@@ -1371,7 +1391,7 @@
 		DCRYPT_DOVECOT_KEY_ENCRYPT_ROUNDS));
 
 	if (enctype == DCRYPT_DOVECOT_KEY_ENCRYPT_PK) {
-		if (EVP_PKEY_base_id((EVP_PKEY*)enc_key) == EVP_PKEY_RSA) {
+		if (EVP_PKEY_base_id(enc_key->key) == EVP_PKEY_RSA) {
 			size_t used = buffer_get_used_size(secret);
 			/* peer key, in this case, is encrypted secret, which is 16 bytes of data */
 			ptr = buffer_append_space_unsafe(secret, 16);
@@ -1380,7 +1400,7 @@
 			if (!dcrypt_rsa_encrypt(enc_key, secret->data, secret->used, peer_key, error_r)) {
 				return FALSE;
 			}
-		} else if (EVP_PKEY_base_id((EVP_PKEY*)enc_key) == EVP_PKEY_EC) {
+		} else if (EVP_PKEY_base_id(enc_key->key) == EVP_PKEY_EC) {
 			/* generate secret by ECDHE */
 			if (!dcrypt_openssl_ecdh_derive_secret_peer(enc_key, peer_key, secret, error_r)) {
 				return FALSE;
@@ -1423,7 +1443,7 @@
 {
 	size_t dest_used = buffer_get_used_size(destination);
 	const char *cipher2 = NULL;
-	EVP_PKEY *pkey = (EVP_PKEY*)key;
+	EVP_PKEY *pkey = key->key;
 	char objtxt[80]; /* openssl manual says this is OK */
 	ASN1_OBJECT *obj;
 	if (EVP_PKEY_base_id(pkey) == EVP_PKEY_EC) {
@@ -1511,7 +1531,7 @@
 static
 bool dcrypt_openssl_store_public_key_dovecot(struct dcrypt_public_key *key, buffer_t *destination, const char **error_r)
 {
-	EVP_PKEY *pubkey = (EVP_PKEY*)key;
+	EVP_PKEY *pubkey = key->key;
 	unsigned char *tmp = NULL;
 	size_t dest_used = buffer_get_used_size(destination);
 
@@ -1567,7 +1587,9 @@
 		EC_KEY_set_conv_form(EVP_PKEY_get0_EC_KEY(key), POINT_CONVERSION_COMPRESSED);
 	}
 
-	*key_r = (struct dcrypt_private_key *)key;
+	*key_r = i_new(struct dcrypt_private_key, 1);
+	(*key_r)->key = key;
+	(*key_r)->ref++;
 
 	return TRUE;
 }
@@ -1619,7 +1641,9 @@
 	if (key == NULL)
 		return dcrypt_openssl_error(error_r);
 
-	*key_r = (struct dcrypt_public_key *)key;
+	*key_r = i_new(struct dcrypt_public_key, 1);
+	(*key_r)->key = key;
+	(*key_r)->ref++;
 
 	return TRUE;
 }
@@ -1636,7 +1660,7 @@
 		return ret;
 	}
 
-	EVP_PKEY *pkey = (EVP_PKEY*)key;
+	EVP_PKEY *pkey = key->key;
 	BIO *key_out = BIO_new(BIO_s_mem());
 	if (key_out == NULL)
 		return dcrypt_openssl_error(error_r);
@@ -1677,7 +1701,7 @@
 	if (format == DCRYPT_FORMAT_DOVECOT)
 		return dcrypt_openssl_store_public_key_dovecot(key, destination, error_r);
 
-	EVP_PKEY *pkey = (EVP_PKEY*)key;
+	EVP_PKEY *pkey = key->key;
 	BIO *key_out = BIO_new(BIO_s_mem());
 	if (key_out == NULL)
 		return dcrypt_openssl_error(error_r);
@@ -1716,14 +1740,11 @@
 static
 void dcrypt_openssl_private_to_public_key(struct dcrypt_private_key *priv_key, struct dcrypt_public_key **pub_key_r)
 {
-	EVP_PKEY *pkey = (EVP_PKEY*)priv_key;
+	EVP_PKEY *pkey = priv_key->key;
 	EVP_PKEY *pk;
 
-	if (*pub_key_r == NULL) {
-		pk = EVP_PKEY_new();
-		i_assert(pk != NULL); /* we shouldn't get malloc() failures */
-	} else
-		pk = (EVP_PKEY*)*pub_key_r;
+	pk = EVP_PKEY_new();
+	i_assert(pk != NULL); /* we shouldn't get malloc() failures */
 
 	if (EVP_PKEY_base_id(pkey) == EVP_PKEY_RSA)
 	{
@@ -1740,7 +1761,9 @@
 		i_unreached();
 	}
 
-	*pub_key_r = (struct dcrypt_public_key*)pk;
+	*pub_key_r = i_new(struct dcrypt_public_key, 1);
+	(*pub_key_r)->key = pk;
+	(*pub_key_r)->ref++;
 }
 
 static
@@ -1862,22 +1885,38 @@
 }
 
 static
-void dcrypt_openssl_free_public_key(struct dcrypt_public_key **key)
+void dcrypt_openssl_ref_public_key(struct dcrypt_public_key *key)
 {
-	EVP_PKEY_free((EVP_PKEY*)*key);
-	*key = NULL;
+	key->ref++;
+}
+static
+void dcrypt_openssl_ref_private_key(struct dcrypt_private_key *key)
+{
+	key->ref++;
 }
 static
-void dcrypt_openssl_free_private_key(struct dcrypt_private_key **key)
+void dcrypt_openssl_unref_public_key(struct dcrypt_public_key **key)
 {
-	EVP_PKEY_free((EVP_PKEY*)*key);
+	struct dcrypt_public_key *_key = *key;
 	*key = NULL;
+	if (--_key->ref > 0) return;
+	EVP_PKEY_free(_key->key);
+	i_free(_key);
 }
 static
-void dcrypt_openssl_free_keypair(struct dcrypt_keypair *keypair)
+void dcrypt_openssl_unref_private_key(struct dcrypt_private_key **key)
 {
-	dcrypt_openssl_free_public_key(&(keypair->pub));
-	dcrypt_openssl_free_private_key(&(keypair->priv));
+	struct dcrypt_private_key *_key = *key;
+	*key = NULL;
+	if (--_key->ref > 0) return;
+	EVP_PKEY_free(_key->key);
+	i_free(_key);
+}
+static
+void dcrypt_openssl_unref_keypair(struct dcrypt_keypair *keypair)
+{
+	dcrypt_openssl_unref_public_key(&(keypair->pub));
+	dcrypt_openssl_unref_private_key(&(keypair->priv));
 }
 
 static
@@ -1885,8 +1924,8 @@
 {
 	int ec;
 
-	EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new((EVP_PKEY*)key, NULL);
-	size_t outl = EVP_PKEY_size((EVP_PKEY*)key);
+	EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(key->key, NULL);
+	size_t outl = EVP_PKEY_size(key->key);
 	unsigned char buf[outl];
 
 	if (ctx == NULL ||
@@ -1909,8 +1948,8 @@
 {
 	int ec;
 
-	EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new((EVP_PKEY*)key, NULL);
-	size_t outl = EVP_PKEY_size((EVP_PKEY*)key);
+	EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new(key->key, NULL);
+	size_t outl = EVP_PKEY_size(key->key);
 	unsigned char buf[outl];
 
 	if (ctx == NULL ||
@@ -1966,7 +2005,7 @@
 static
 enum dcrypt_key_type dcrypt_openssl_private_key_type(struct dcrypt_private_key *key)
 {
-	EVP_PKEY *priv = (EVP_PKEY*)key;
+	EVP_PKEY *priv = key->key;
 	i_assert(priv != NULL);
 	if (EVP_PKEY_base_id(priv) == EVP_PKEY_RSA) return DCRYPT_KEY_RSA;
 	else if (EVP_PKEY_base_id(priv) == EVP_PKEY_EC) return DCRYPT_KEY_EC;
@@ -1976,7 +2015,7 @@
 static
 enum dcrypt_key_type dcrypt_openssl_public_key_type(struct dcrypt_public_key *key)
 {
-	EVP_PKEY *pub = (EVP_PKEY*)key;
+	EVP_PKEY *pub = key->key;
 	i_assert(pub != NULL);
 	if (EVP_PKEY_base_id(pub) == EVP_PKEY_RSA) return DCRYPT_KEY_RSA;
 	else if (EVP_PKEY_base_id(pub) == EVP_PKEY_EC) return DCRYPT_KEY_EC;
@@ -1988,7 +2027,7 @@
 bool dcrypt_openssl_public_key_id_old(struct dcrypt_public_key *key, buffer_t *result, const char **error_r)
 {
 	unsigned char buf[SHA256_DIGEST_LENGTH];
-	EVP_PKEY *pub = (EVP_PKEY*)key;
+	EVP_PKEY *pub = key->key;
 
 	i_assert(pub != NULL);
 	if (EVP_PKEY_base_id(pub) != EVP_PKEY_EC) {
@@ -2011,7 +2050,7 @@
 bool dcrypt_openssl_private_key_id_old(struct dcrypt_private_key *key, buffer_t *result, const char **error_r)
 {
 	unsigned char buf[SHA256_DIGEST_LENGTH];
-	EVP_PKEY *priv = (EVP_PKEY*)key;
+	EVP_PKEY *priv = key->key;
 
 	i_assert(priv != NULL);
 	if (EVP_PKEY_base_id(priv) != EVP_PKEY_EC) {
@@ -2076,7 +2115,7 @@
 bool dcrypt_openssl_public_key_id(struct dcrypt_public_key *key, const char *algorithm, buffer_t *result, const char **error_r)
 {
 	const EVP_MD *md = EVP_get_digestbyname(algorithm);
-	EVP_PKEY *pub = (EVP_PKEY*)key;
+	EVP_PKEY *pub = key->key;
 
 	i_assert(pub != NULL);
 	if (md == NULL) {
@@ -2092,7 +2131,7 @@
 bool dcrypt_openssl_private_key_id(struct dcrypt_private_key *key, const char *algorithm, buffer_t *result, const char **error_r)
 {
 	const EVP_MD *md = EVP_get_digestbyname(algorithm);
-	EVP_PKEY *priv = (EVP_PKEY*)key;
+	EVP_PKEY *priv = key->key;
 
 	i_assert(priv != NULL);
 	if (md == NULL) {
@@ -2144,9 +2183,11 @@
 	.store_public_key = dcrypt_openssl_store_public_key,
 	.private_to_public_key = dcrypt_openssl_private_to_public_key,
 	.key_string_get_info = dcrypt_openssl_key_string_get_info,
-	.free_keypair = dcrypt_openssl_free_keypair,
-	.free_public_key = dcrypt_openssl_free_public_key,
-	.free_private_key = dcrypt_openssl_free_private_key,
+	.unref_keypair = dcrypt_openssl_unref_keypair,
+	.unref_public_key = dcrypt_openssl_unref_public_key,
+	.unref_private_key = dcrypt_openssl_unref_private_key,
+	.ref_public_key = dcrypt_openssl_ref_public_key,
+	.ref_private_key = dcrypt_openssl_ref_private_key,
 	.rsa_encrypt = dcrypt_openssl_rsa_encrypt,
 	.rsa_decrypt = dcrypt_openssl_rsa_decrypt,
 	.oid2name = dcrypt_openssl_oid2name,
--- a/src/lib-dcrypt/dcrypt-private.h	Thu Aug 04 08:55:12 2016 +0300
+++ b/src/lib-dcrypt/dcrypt-private.h	Fri Jul 15 16:27:03 2016 +0300
@@ -78,9 +78,11 @@
 		enum dcrypt_key_kind *kind_r, enum dcrypt_key_encryption_type *encryption_type_r, const char **encryption_key_hash_r,
 		const char **key_hash_r, const char **error_r);
 
-	void (*free_keypair)(struct dcrypt_keypair *keypair);
-	void (*free_public_key)(struct dcrypt_public_key **key);
-	void (*free_private_key)(struct dcrypt_private_key **key);
+	void (*unref_keypair)(struct dcrypt_keypair *keypair);
+	void (*unref_public_key)(struct dcrypt_public_key **key);
+	void (*unref_private_key)(struct dcrypt_private_key **key);
+        void (*ref_public_key)(struct dcrypt_public_key *key);
+        void (*ref_private_key)(struct dcrypt_private_key *key);
 
 	bool (*rsa_encrypt)(struct dcrypt_public_key *key, const unsigned char *data, size_t data_len,
 		buffer_t *result, const char **error_r);
--- a/src/lib-dcrypt/dcrypt.c	Thu Aug 04 08:55:12 2016 +0300
+++ b/src/lib-dcrypt/dcrypt.c	Fri Jul 15 16:27:03 2016 +0300
@@ -255,17 +255,25 @@
 {
 	return dcrypt_vfs->private_key_id_old(key, result, error_r);
 }
-void dcrypt_keypair_free(struct dcrypt_keypair *keypair)
+void dcrypt_keypair_unref(struct dcrypt_keypair *keypair)
 {
-	dcrypt_vfs->free_keypair(keypair);
+	dcrypt_vfs->unref_keypair(keypair);
 }
-void dcrypt_key_free_public(struct dcrypt_public_key **key)
+void dcrypt_key_ref_public(struct dcrypt_public_key *key)
+{
+	dcrypt_vfs->ref_public_key(key);
+}
+void dcrypt_key_ref_private(struct dcrypt_private_key *key)
 {
-	dcrypt_vfs->free_public_key(key);
+	dcrypt_vfs->ref_private_key(key);
 }
-void dcrypt_key_free_private(struct dcrypt_private_key **key)
+void dcrypt_key_unref_public(struct dcrypt_public_key **key)
 {
-	dcrypt_vfs->free_private_key(key);
+	dcrypt_vfs->unref_public_key(key);
+}
+void dcrypt_key_unref_private(struct dcrypt_private_key **key)
+{
+	dcrypt_vfs->unref_private_key(key);
 }
 
 bool dcrypt_rsa_encrypt(struct dcrypt_public_key *key, const unsigned char *data, size_t data_len, buffer_t *result, const char **error_r)
--- a/src/lib-dcrypt/dcrypt.h	Thu Aug 04 08:55:12 2016 +0300
+++ b/src/lib-dcrypt/dcrypt.h	Fri Jul 15 16:27:03 2016 +0300
@@ -184,10 +184,11 @@
 
 void dcrypt_key_convert_private_to_public(struct dcrypt_private_key *priv_key, struct dcrypt_public_key **pub_key_r);
 
-void dcrypt_keypair_free(struct dcrypt_keypair *keypair);
-
-void dcrypt_key_free_public(struct dcrypt_public_key **key);
-void dcrypt_key_free_private(struct dcrypt_private_key **key);
+void dcrypt_keypair_unref(struct dcrypt_keypair *keypair);
+void dcrypt_key_ref_public(struct dcrypt_public_key *key);
+void dcrypt_key_ref_private(struct dcrypt_private_key *key);
+void dcrypt_key_unref_public(struct dcrypt_public_key **key);
+void dcrypt_key_unref_private(struct dcrypt_private_key **key);
 
 enum dcrypt_key_type dcrypt_key_type_private(struct dcrypt_private_key *key);
 enum dcrypt_key_type dcrypt_key_type_public(struct dcrypt_public_key *key);