Skip to content

Use cgo noescape/nocallback instead of C wrappers #249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ jobs:
strategy:
fail-fast: false
matrix:
go-version: [1.22.x, 1.23.x]
go-version: [1.23.x, 1.24.x]
openssl-version: [1.1.0, 1.1.1, 3.0.1, 3.0.13, 3.1.5, 3.2.1, 3.3.0, 3.3.1]
runs-on: ubuntu-20.04
steps:
Expand Down
7 changes: 7 additions & 0 deletions asan_disabled_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !asan

package openssl_test

func Asan() bool {
return false
}
7 changes: 7 additions & 0 deletions asan_enabled_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build asan

package openssl_test

func Asan() bool {
return true
}
20 changes: 18 additions & 2 deletions cgo_go124.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ package openssl
// observed to benefit from these directives, not every function that is merely
// expected to meet the noescape/nocallback criteria.

// #cgo noescape go_openssl_RAND_bytes
// #cgo nocallback go_openssl_RAND_bytes
/*
#cgo noescape go_openssl_RAND_bytes
#cgo nocallback go_openssl_RAND_bytes
#cgo noescape go_openssl_EVP_EncryptUpdate
#cgo nocallback go_openssl_EVP_EncryptUpdate
#cgo noescape go_openssl_EVP_DecryptUpdate
#cgo nocallback go_openssl_EVP_DecryptUpdate
#cgo noescape go_openssl_EVP_CipherUpdate
#cgo nocallback go_openssl_EVP_CipherUpdate
#cgo noescape go_openssl_EVP_PKEY_derive
#cgo nocallback go_openssl_EVP_PKEY_derive
#cgo noescape go_openssl_EVP_PKEY_get_raw_public_key
#cgo nocallback go_openssl_EVP_PKEY_get_raw_public_key
#cgo noescape go_openssl_EVP_PKEY_get_raw_private_key
#cgo nocallback go_openssl_EVP_PKEY_get_raw_private_key
#cgo noescape go_openssl_EVP_DigestSign
#cgo nocallback go_openssl_EVP_DigestSign
*/
import "C"
12 changes: 8 additions & 4 deletions cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ func (c *evpCipher) encrypt(dst, src []byte) error {
}
defer C.go_openssl_EVP_CIPHER_CTX_free(enc_ctx)

if C.go_openssl_EVP_EncryptUpdate_wrapper(enc_ctx, base(dst), base(src), C.int(c.blockSize)) != 1 {
var outl C.int
if C.go_openssl_EVP_EncryptUpdate(enc_ctx, base(dst), &outl, base(src), C.int(c.blockSize)) != 1 {
return errors.New("EncryptUpdate failed")
}
runtime.KeepAlive(c)
Expand Down Expand Up @@ -208,7 +209,8 @@ func (c *evpCipher) decrypt(dst, src []byte) error {
return errors.New("could not disable cipher padding")
}

C.go_openssl_EVP_DecryptUpdate_wrapper(dec_ctx, base(dst), base(src), C.int(c.blockSize))
var outl C.int
C.go_openssl_EVP_DecryptUpdate(dec_ctx, base(dst), &outl, base(src), C.int(c.blockSize))
runtime.KeepAlive(c)
return nil
}
Expand All @@ -235,7 +237,8 @@ func (x *cipherCBC) CryptBlocks(dst, src []byte) {
panic("crypto/cipher: output smaller than input")
}
if len(src) > 0 {
if C.go_openssl_EVP_CipherUpdate_wrapper(x.ctx, base(dst), base(src), C.int(len(src))) != 1 {
var outl C.int
if C.go_openssl_EVP_CipherUpdate(x.ctx, base(dst), &outl, base(src), C.int(len(src))) != 1 {
panic("crypto/cipher: CipherUpdate failed")
}
runtime.KeepAlive(x)
Expand Down Expand Up @@ -278,7 +281,8 @@ func (x *cipherCTR) XORKeyStream(dst, src []byte) {
if len(src) == 0 {
return
}
if C.go_openssl_EVP_EncryptUpdate_wrapper(x.ctx, base(dst), base(src), C.int(len(src))) != 1 {
var outl C.int
if C.go_openssl_EVP_EncryptUpdate(x.ctx, base(dst), &outl, base(src), C.int(len(src))) != 1 {
panic("crypto/cipher: EncryptUpdate failed")
}
runtime.KeepAlive(x)
Expand Down
12 changes: 6 additions & 6 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ func ECDH(priv *PrivateKeyECDH, pub *PublicKeyECDH) ([]byte, error) {
if C.go_openssl_EVP_PKEY_derive_set_peer(ctx, pub._pkey) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_set_peer")
}
r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0)
if r.result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
var keylen C.size_t
if C.go_openssl_EVP_PKEY_derive(ctx, nil, &keylen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
out := make([]byte, keylen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &keylen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out, nil
}
Expand Down
24 changes: 12 additions & 12 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ func NewPrivateKeyEd25519FromSeed(seed []byte) (*PrivateKeyEd25519, error) {
}

func extractPKEYPubEd25519(pkey C.GO_EVP_PKEY_PTR, pub []byte) error {
r := C.go_openssl_EVP_PKEY_get_raw_public_key_wrapper(pkey, base(pub), C.size_t(publicKeySizeEd25519))
if r.result != 1 {
keylen := C.size_t(publicKeySizeEd25519)
if C.go_openssl_EVP_PKEY_get_raw_public_key(pkey, base(pub), &keylen) != 1 {
return newOpenSSLError("EVP_PKEY_get_raw_public_key")
}
if r.len != publicKeySizeEd25519 {
return errors.New("ed25519: bad public key length: " + strconv.Itoa(int(r.len)))
if int(keylen) != publicKeySizeEd25519 {
return errors.New("ed25519: bad public key length: " + strconv.Itoa(int(keylen)))
}
return nil
}
Expand All @@ -169,12 +169,12 @@ func extractPKEYPrivEd25519(pkey C.GO_EVP_PKEY_PTR, priv []byte) error {
if err := extractPKEYPubEd25519(pkey, priv[seedSizeEd25519:]); err != nil {
return err
}
r := C.go_openssl_EVP_PKEY_get_raw_private_key_wrapper(pkey, base(priv), C.size_t(seedSizeEd25519))
if r.result != 1 {
keylen := C.size_t(seedSizeEd25519)
if C.go_openssl_EVP_PKEY_get_raw_private_key(pkey, base(priv), &keylen) != 1 {
return newOpenSSLError("EVP_PKEY_get_raw_private_key")
}
if r.len != seedSizeEd25519 {
return errors.New("ed25519: bad private key length: " + strconv.Itoa(int(r.len)))
if int(keylen) != seedSizeEd25519 {
return errors.New("ed25519: bad private key length: " + strconv.Itoa(int(keylen)))
}
return nil
}
Expand All @@ -200,12 +200,12 @@ func signEd25519(priv *PrivateKeyEd25519, sig, message []byte) error {
if C.go_openssl_EVP_DigestSignInit(ctx, nil, nil, nil, priv._pkey) != 1 {
return newOpenSSLError("EVP_DigestSignInit")
}
r := C.go_openssl_EVP_DigestSign_wrapper(ctx, base(sig), C.size_t(signatureSizeEd25519), base(message), C.size_t(len(message)))
if r.result != 1 {
siglen := C.size_t(signatureSizeEd25519)
if C.go_openssl_EVP_DigestSign(ctx, base(sig), &siglen, base(message), C.size_t(len(message))) != 1 {
return newOpenSSLError("EVP_DigestSign")
}
if r.siglen != signatureSizeEd25519 {
return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(r.siglen)))
if int(siglen) != signatureSizeEd25519 {
return errors.New("ed25519: bad signature length: " + strconv.Itoa(int(siglen)))
}
return nil
}
Expand Down
76 changes: 0 additions & 76 deletions goopenssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,82 +76,6 @@ go_hash_sum(GO_EVP_MD_CTX_PTR ctx, GO_EVP_MD_CTX_PTR ctx2, unsigned char *out)
return go_openssl_EVP_DigestFinal(ctx2, out, NULL);
}

// These wrappers allocate out_len on the C stack to avoid having to pass a pointer from Go, which would escape to the heap.
// Use them only in situations where the output length can be safely discarded.
static inline int
go_openssl_EVP_EncryptUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *out, const unsigned char *in, int in_len)
{
int len;
return go_openssl_EVP_EncryptUpdate(ctx, out, &len, in, in_len);
}

static inline int
go_openssl_EVP_DecryptUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *out, const unsigned char *in, int in_len)
{
int len;
return go_openssl_EVP_DecryptUpdate(ctx, out, &len, in, in_len);
}

static inline int
go_openssl_EVP_CipherUpdate_wrapper(GO_EVP_CIPHER_CTX_PTR ctx, unsigned char *out, const unsigned char *in, int in_len)
{
int len;
return go_openssl_EVP_CipherUpdate(ctx, out, &len, in, in_len);
}

// These wrappers also allocate length variables on the C stack to avoid escape to the heap, but do return the result.
// A struct is returned that contains multiple return values instead of OpenSSL's approach of using pointers.

typedef struct
{
int result;
size_t keylen;
} go_openssl_EVP_PKEY_derive_wrapper_out;

static inline go_openssl_EVP_PKEY_derive_wrapper_out
go_openssl_EVP_PKEY_derive_wrapper(GO_EVP_PKEY_CTX_PTR ctx, unsigned char *key, size_t keylen)
{
go_openssl_EVP_PKEY_derive_wrapper_out r = {0, keylen};
r.result = go_openssl_EVP_PKEY_derive(ctx, key, &r.keylen);
return r;
}

typedef struct
{
int result;
size_t len;
} go_openssl_EVP_PKEY_get_raw_key_out;

static inline go_openssl_EVP_PKEY_get_raw_key_out
go_openssl_EVP_PKEY_get_raw_public_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *pub, size_t len)
{
go_openssl_EVP_PKEY_get_raw_key_out r = {0, len};
r.result = go_openssl_EVP_PKEY_get_raw_public_key(pkey, pub, &r.len);
return r;
}

static inline go_openssl_EVP_PKEY_get_raw_key_out
go_openssl_EVP_PKEY_get_raw_private_key_wrapper(const GO_EVP_PKEY_PTR pkey, unsigned char *priv, size_t len)
{
go_openssl_EVP_PKEY_get_raw_key_out r = {0, len};
r.result = go_openssl_EVP_PKEY_get_raw_private_key(pkey, priv, &r.len);
return r;
}

typedef struct
{
int result;
size_t siglen;
} go_openssl_EVP_DigestSign_wrapper_out;

static inline go_openssl_EVP_DigestSign_wrapper_out
go_openssl_EVP_DigestSign_wrapper(GO_EVP_MD_CTX_PTR ctx, unsigned char *sigret, size_t siglen, const unsigned char *tbs, size_t tbslen)
{
go_openssl_EVP_DigestSign_wrapper_out r = {0, siglen};
r.result = go_openssl_EVP_DigestSign(ctx, sigret, &r.siglen, tbs, tbslen);
return r;
}

// These wrappers allocate out_len on the C stack, and check that it matches the expected
// value, to avoid having to pass a pointer from Go, which would escape to the heap.

Expand Down
15 changes: 8 additions & 7 deletions hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (c *hkdf1) Read(p []byte) (int, error) {
}
c.buf = append(c.buf, make([]byte, needLen)...)
outLen := C.size_t(prevLen + needLen)
if C.go_openssl_EVP_PKEY_derive_wrapper(c.ctx, base(c.buf), outLen).result != 1 {
if C.go_openssl_EVP_PKEY_derive(c.ctx, base(c.buf), &outLen) != 1 {
return 0, newOpenSSLError("EVP_PKEY_derive")
}
n := copy(p, c.buf[prevLen:outLen])
Expand All @@ -126,15 +126,15 @@ func ExtractHKDF(h func() hash.Hash, secret, salt []byte) ([]byte, error) {
return nil, err
}
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
r := C.go_openssl_EVP_PKEY_derive_wrapper(ctx, nil, 0)
if r.result != 1 {
var keylen C.size_t
if C.go_openssl_EVP_PKEY_derive(ctx, nil, &keylen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive_init")
}
out := make([]byte, r.keylen)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), r.keylen).result != 1 {
out := make([]byte, keylen)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &keylen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
return out[:r.keylen], nil
return out[:keylen], nil
case 3:
ctx, err := newHKDFCtx3(md, C.GO_EVP_KDF_HKDF_MODE_EXTRACT_ONLY, secret, salt, nil, nil)
if err != nil {
Expand Down Expand Up @@ -170,7 +170,8 @@ func ExpandHKDFOneShot(h func() hash.Hash, pseudorandomKey, info []byte, keyLeng
return nil, err
}
defer C.go_openssl_EVP_PKEY_CTX_free(ctx)
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(out), C.size_t(keyLength)).result != 1 {
keylen := C.size_t(keyLength)
if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &keylen) != 1 {
return nil, newOpenSSLError("EVP_PKEY_derive")
}
case 3:
Expand Down
3 changes: 3 additions & 0 deletions rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ func TestRand(t *testing.T) {
}

func TestAllocations(t *testing.T) {
if Asan() {
t.Skip("skipping allocations test with sanitizers")
}
n := int(testing.AllocsPerRun(10, func() {
buf := make([]byte, 32)
openssl.RandReader.Read(buf)
Expand Down
2 changes: 1 addition & 1 deletion tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func tls1PRF1(result, secret, label, seed []byte, md C.GO_EVP_MD_PTR) error {
return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed")
}
outLen := C.size_t(len(result))
if C.go_openssl_EVP_PKEY_derive_wrapper(ctx, base(result), outLen).result != 1 {
if C.go_openssl_EVP_PKEY_derive(ctx, base(result), &outLen) != 1 {
return newOpenSSLError("EVP_PKEY_derive")
}
// The Go standard library expects TLS1PRF to return the requested number of bytes,
Expand Down
Loading