github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/with_test.go (about)

     1  package ydb //nolint:testpackage
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/pem"
    11  	"math/big"
    12  	"os"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/config"
    20  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/certificates"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
    22  )
    23  
    24  func TestWithCertificatesCached(t *testing.T) { //nolint:funlen
    25  	ca := &x509.Certificate{
    26  		SerialNumber: big.NewInt(2019),
    27  		Subject: pkix.Name{
    28  			Organization:  []string{"Company, INC."},
    29  			Country:       []string{"US"},
    30  			Province:      []string{""},
    31  			Locality:      []string{"San Francisco"},
    32  			StreetAddress: []string{"Golden Gate Bridge"},
    33  			PostalCode:    []string{"94016"},
    34  		},
    35  		NotBefore:             time.Now(),
    36  		NotAfter:              time.Now().AddDate(10, 0, 0),
    37  		IsCA:                  true,
    38  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
    39  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
    40  		BasicConstraintsValid: true,
    41  	}
    42  	caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
    43  	require.NoError(t, err)
    44  	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
    45  	require.NoError(t, err)
    46  	caPEM := new(bytes.Buffer)
    47  	err = pem.Encode(caPEM, &pem.Block{
    48  		Type:  "CERTIFICATE",
    49  		Bytes: caBytes,
    50  	})
    51  	require.NoError(t, err)
    52  	f, err := os.CreateTemp(os.TempDir(), "ca.pem")
    53  	require.NoError(t, err)
    54  	_, err = f.Write(caPEM.Bytes())
    55  	require.NoError(t, err)
    56  	defer os.Remove(f.Name())
    57  
    58  	var (
    59  		n           = 100
    60  		hitCounter  uint64
    61  		missCounter uint64
    62  		ctx         = context.TODO()
    63  	)
    64  	for _, test := range []struct {
    65  		name    string
    66  		options []Option
    67  		expMiss uint64
    68  		expHit  uint64
    69  	}{
    70  		{
    71  			"no cache",
    72  			[]Option{},
    73  			0,
    74  			0,
    75  		},
    76  		{
    77  			"file cache",
    78  			[]Option{
    79  				WithCertificatesFromFile(f.Name(),
    80  					certificates.FromFileOnHit(func() {
    81  						atomic.AddUint64(&hitCounter, 1)
    82  					}),
    83  					certificates.FromFileOnMiss(func() {
    84  						atomic.AddUint64(&missCounter, 1)
    85  					}),
    86  				),
    87  			},
    88  			0,
    89  			uint64(n),
    90  		},
    91  		{
    92  			"pem cache",
    93  			[]Option{
    94  				WithCertificatesFromPem(caPEM.Bytes(),
    95  					certificates.FromPemOnHit(func() {
    96  						atomic.AddUint64(&hitCounter, 1)
    97  					}),
    98  					certificates.FromPemMiss(func() {
    99  						atomic.AddUint64(&missCounter, 1)
   100  					}),
   101  				),
   102  			},
   103  			0,
   104  			uint64(n),
   105  		},
   106  		{
   107  			"pem&file cache",
   108  			[]Option{
   109  				WithCertificatesFromFile(f.Name(),
   110  					certificates.FromFileOnHit(func() {
   111  						atomic.AddUint64(&hitCounter, 1)
   112  					}),
   113  					certificates.FromFileOnMiss(func() {
   114  						atomic.AddUint64(&missCounter, 1)
   115  					}),
   116  				),
   117  				WithCertificatesFromPem(caPEM.Bytes(),
   118  					certificates.FromPemOnHit(func() {
   119  						atomic.AddUint64(&hitCounter, 1)
   120  					}),
   121  					certificates.FromPemMiss(func() {
   122  						atomic.AddUint64(&missCounter, 1)
   123  					}),
   124  				),
   125  			},
   126  			0,
   127  			uint64(n * 2),
   128  		},
   129  	} {
   130  		t.Run(test.name, func(t *testing.T) {
   131  			db, err := newConnectionFromOptions(ctx,
   132  				append(
   133  					test.options,
   134  					withConnPool(conn.NewPool(context.Background(), config.New())), //nolint:contextcheck
   135  				)...,
   136  			)
   137  			require.NoError(t, err)
   138  
   139  			hitCounter, missCounter = 0, 0
   140  
   141  			for i := 0; i < n; i++ {
   142  				_, _, err := db.with(ctx,
   143  					func(ctx context.Context, c *Driver) error {
   144  						return nil // nothing to do
   145  					},
   146  				)
   147  				require.NoError(t, err)
   148  			}
   149  			require.Equal(t, test.expHit, hitCounter)
   150  			require.Equal(t, test.expMiss, missCounter)
   151  		})
   152  	}
   153  }