github.com/kubri/kubri@v0.5.1-0.20240317001612-bda2aaef967e/pkg/crypto/internal/cryptotest/cryptotest.go (about)

     1  package cryptotest
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"os"
     7  	"os/exec"
     8  	"path/filepath"
     9  	"testing"
    10  
    11  	"github.com/google/go-cmp/cmp"
    12  
    13  	"github.com/kubri/kubri/pkg/crypto"
    14  )
    15  
    16  type Implementation[PrivateKey any, PublicKey any] struct {
    17  	NewPrivateKey       func() (PrivateKey, error)
    18  	MarshalPrivateKey   func(key PrivateKey) ([]byte, error)
    19  	UnmarshalPrivateKey func(b []byte) (PrivateKey, error)
    20  	Public              func(key PrivateKey) PublicKey
    21  	MarshalPublicKey    func(key PublicKey) ([]byte, error)
    22  	UnmarshalPublicKey  func(b []byte) (PublicKey, error)
    23  	Sign                func(key PrivateKey, data []byte) ([]byte, error)
    24  	Verify              func(key PublicKey, data, sig []byte) bool
    25  }
    26  
    27  type options struct {
    28  	cmp         cmp.Options
    29  	opensslArgs []string
    30  }
    31  
    32  type Option func(*options)
    33  
    34  func WithCmpOptions(opt ...cmp.Option) Option {
    35  	return func(o *options) { o.cmp = append(o.cmp, opt...) }
    36  }
    37  
    38  func WithOpenSSLTest(arg ...string) Option {
    39  	return func(o *options) { o.opensslArgs = arg }
    40  }
    41  
    42  //nolint:funlen,gocognit,maintidx
    43  func Test[PrivateKey, PublicKey any](t *testing.T, i Implementation[PrivateKey, PublicKey], opts ...Option) {
    44  	var opt options
    45  	for _, o := range opts {
    46  		o(&opt)
    47  	}
    48  
    49  	priv, err := i.NewPrivateKey()
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  
    54  	privBytes, err := i.MarshalPrivateKey(priv)
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  
    59  	pub := i.Public(priv)
    60  
    61  	pubBytes, err := i.MarshalPublicKey(pub)
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  
    66  	data := []byte("foo\nbar\nbaz\n")
    67  
    68  	sig, err := i.Sign(priv, data)
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  
    73  	t.Run("MarshalPrivateKey", func(t *testing.T) {
    74  		tests := []struct {
    75  			name string
    76  			in   PrivateKey
    77  			want []byte
    78  			err  error
    79  		}{
    80  			{
    81  				name: "valid key",
    82  				in:   priv,
    83  				want: privBytes,
    84  			},
    85  			{
    86  				name: "nil key",
    87  				err:  crypto.ErrInvalidKey,
    88  			},
    89  		}
    90  
    91  		for _, test := range tests {
    92  			got, err := i.MarshalPrivateKey(test.in)
    93  			if !errors.Is(err, test.err) {
    94  				t.Error(test.name, "should return error", test.err, "got", err)
    95  			} else if diff := cmp.Diff(string(test.want), string(got), opt.cmp); diff != "" {
    96  				t.Error(test.name, diff)
    97  			}
    98  		}
    99  	})
   100  
   101  	t.Run("UnmarshalPrivateKey", func(t *testing.T) {
   102  		tests := []struct {
   103  			name string
   104  			in   []byte
   105  			want PrivateKey
   106  			err  error
   107  		}{
   108  			{
   109  				name: "valid key",
   110  				in:   privBytes,
   111  				want: priv,
   112  			},
   113  			{
   114  				name: "nil bytes",
   115  				err:  crypto.ErrInvalidKey,
   116  			},
   117  			{
   118  				name: "non-key data",
   119  				in:   data,
   120  				err:  crypto.ErrInvalidKey,
   121  			},
   122  			{
   123  				name: "public key",
   124  				in:   pubBytes,
   125  				err:  crypto.ErrInvalidKey,
   126  			},
   127  		}
   128  
   129  		for _, test := range tests {
   130  			got, err := i.UnmarshalPrivateKey(test.in)
   131  			if !errors.Is(err, test.err) {
   132  				t.Errorf("%s should return error %q got %q", test.name, test.err, err)
   133  			} else if diff := cmp.Diff(test.want, got, opt.cmp); diff != "" {
   134  				t.Error(test.name, diff)
   135  			}
   136  		}
   137  	})
   138  
   139  	t.Run("MarshalPublicKey", func(t *testing.T) {
   140  		tests := []struct {
   141  			name string
   142  			in   PublicKey
   143  			want []byte
   144  			err  error
   145  		}{
   146  			{
   147  				name: "valid key",
   148  				in:   pub,
   149  				want: pubBytes,
   150  			},
   151  			{
   152  				name: "nil key",
   153  				err:  crypto.ErrInvalidKey,
   154  			},
   155  		}
   156  
   157  		for _, test := range tests {
   158  			got, err := i.MarshalPublicKey(test.in)
   159  			if !errors.Is(err, test.err) {
   160  				t.Errorf("%s should return error %q got %q", test.name, test.err, err)
   161  			} else if diff := cmp.Diff(string(test.want), string(got), opt.cmp); diff != "" {
   162  				t.Error(test.name, diff)
   163  			}
   164  		}
   165  	})
   166  
   167  	t.Run("UnmarshalPublicKey", func(t *testing.T) {
   168  		tests := []struct {
   169  			name string
   170  			in   []byte
   171  			want PublicKey
   172  			err  error
   173  		}{
   174  			{
   175  				name: "valid key",
   176  				in:   pubBytes,
   177  				want: pub,
   178  			},
   179  			{
   180  				name: "nil bytes",
   181  				err:  crypto.ErrInvalidKey,
   182  			},
   183  			{
   184  				name: "non-key data",
   185  				in:   data,
   186  				err:  crypto.ErrInvalidKey,
   187  			},
   188  			{
   189  				name: "private key",
   190  				in:   privBytes,
   191  				err:  crypto.ErrInvalidKey,
   192  			},
   193  		}
   194  
   195  		for _, test := range tests {
   196  			got, err := i.UnmarshalPublicKey(test.in)
   197  			if !errors.Is(err, test.err) {
   198  				t.Errorf("%s should return error %q got %q", test.name, test.err, err)
   199  			} else if diff := cmp.Diff(test.want, got, opt.cmp); diff != "" {
   200  				t.Errorf("%s\n%s", test.name, diff)
   201  			}
   202  		}
   203  	})
   204  
   205  	t.Run("Sign", func(t *testing.T) {
   206  		tests := []struct {
   207  			name string
   208  			key  PrivateKey
   209  			data []byte
   210  			err  error
   211  		}{
   212  			{
   213  				name: "nil key",
   214  				data: data,
   215  			},
   216  		}
   217  
   218  		for _, test := range tests {
   219  			_, err := i.Sign(test.key, data)
   220  			if err == nil {
   221  				t.Error(test.name, "should error")
   222  			}
   223  		}
   224  	})
   225  
   226  	t.Run("Verify", func(t *testing.T) {
   227  		wrongPriv, _ := i.NewPrivateKey()
   228  		wrongPub := i.Public(wrongPriv)
   229  		wrongSig, _ := i.Sign(priv, []byte("wrong data"))
   230  
   231  		tests := []struct {
   232  			name string
   233  			key  PublicKey
   234  			data []byte
   235  			sig  []byte
   236  			want bool
   237  		}{
   238  			{
   239  				name: "valid key",
   240  				key:  pub,
   241  				data: data,
   242  				sig:  sig,
   243  				want: true,
   244  			},
   245  			{
   246  				name: "nil key",
   247  				data: data,
   248  				sig:  sig,
   249  			},
   250  			{
   251  				name: "wrong key",
   252  				key:  wrongPub,
   253  				data: data,
   254  				sig:  sig,
   255  			},
   256  			{
   257  				name: "nil data",
   258  				key:  pub,
   259  				sig:  sig,
   260  			},
   261  			{
   262  				name: "nil signature",
   263  				key:  pub,
   264  				data: data,
   265  			},
   266  			{
   267  				name: "wrong signature",
   268  				key:  pub,
   269  				data: data,
   270  				sig:  wrongSig,
   271  			},
   272  		}
   273  
   274  		for _, test := range tests {
   275  			ok := i.Verify(test.key, test.data, test.sig)
   276  			if ok != test.want {
   277  				t.Errorf("%s should return %t got %t", test.name, test.want, ok)
   278  			}
   279  		}
   280  	})
   281  
   282  	if opt.opensslArgs != nil {
   283  		t.Run("OpenSSL", func(t *testing.T) {
   284  			if _, err := exec.LookPath("openssl"); err != nil {
   285  				t.Skip("openssl not in path")
   286  			}
   287  
   288  			dir := t.TempDir()
   289  			_ = os.WriteFile(filepath.Join(dir, "public.pem"), pubBytes, 0o600)
   290  			_ = os.WriteFile(filepath.Join(dir, "data.txt"), data, 0o600)
   291  			_ = os.WriteFile(filepath.Join(dir, "data.txt.sig"), sig, 0o600)
   292  
   293  			cmd := exec.Command("openssl", opt.opensslArgs...)
   294  			cmd.Dir = dir
   295  
   296  			out, err := cmd.CombinedOutput()
   297  			t.Log(string(bytes.TrimSpace(out)))
   298  			if err != nil {
   299  				t.Fatal(err)
   300  			}
   301  		})
   302  	}
   303  }