github.com/kubri/kubri@v0.5.1-0.20240317001612-bda2aaef967e/pkg/crypto/internal/cryptotest/cryptotest.go (about) 1 package cryptotest 2 3 import ( 4 "bytes" 5 "errors" 6 "os" 7 "os/exec" 8 "path/filepath" 9 "testing" 10 11 "github.com/google/go-cmp/cmp" 12 13 "github.com/kubri/kubri/pkg/crypto" 14 ) 15 16 type Implementation[PrivateKey any, PublicKey any] struct { 17 NewPrivateKey func() (PrivateKey, error) 18 MarshalPrivateKey func(key PrivateKey) ([]byte, error) 19 UnmarshalPrivateKey func(b []byte) (PrivateKey, error) 20 Public func(key PrivateKey) PublicKey 21 MarshalPublicKey func(key PublicKey) ([]byte, error) 22 UnmarshalPublicKey func(b []byte) (PublicKey, error) 23 Sign func(key PrivateKey, data []byte) ([]byte, error) 24 Verify func(key PublicKey, data, sig []byte) bool 25 } 26 27 type options struct { 28 cmp cmp.Options 29 opensslArgs []string 30 } 31 32 type Option func(*options) 33 34 func WithCmpOptions(opt ...cmp.Option) Option { 35 return func(o *options) { o.cmp = append(o.cmp, opt...) } 36 } 37 38 func WithOpenSSLTest(arg ...string) Option { 39 return func(o *options) { o.opensslArgs = arg } 40 } 41 42 //nolint:funlen,gocognit,maintidx 43 func Test[PrivateKey, PublicKey any](t *testing.T, i Implementation[PrivateKey, PublicKey], opts ...Option) { 44 var opt options 45 for _, o := range opts { 46 o(&opt) 47 } 48 49 priv, err := i.NewPrivateKey() 50 if err != nil { 51 t.Fatal(err) 52 } 53 54 privBytes, err := i.MarshalPrivateKey(priv) 55 if err != nil { 56 t.Fatal(err) 57 } 58 59 pub := i.Public(priv) 60 61 pubBytes, err := i.MarshalPublicKey(pub) 62 if err != nil { 63 t.Fatal(err) 64 } 65 66 data := []byte("foo\nbar\nbaz\n") 67 68 sig, err := i.Sign(priv, data) 69 if err != nil { 70 t.Fatal(err) 71 } 72 73 t.Run("MarshalPrivateKey", func(t *testing.T) { 74 tests := []struct { 75 name string 76 in PrivateKey 77 want []byte 78 err error 79 }{ 80 { 81 name: "valid key", 82 in: priv, 83 want: privBytes, 84 }, 85 { 86 name: "nil key", 87 err: crypto.ErrInvalidKey, 88 }, 89 } 90 91 for _, test := range tests { 92 got, err := i.MarshalPrivateKey(test.in) 93 if !errors.Is(err, test.err) { 94 t.Error(test.name, "should return error", test.err, "got", err) 95 } else if diff := cmp.Diff(string(test.want), string(got), opt.cmp); diff != "" { 96 t.Error(test.name, diff) 97 } 98 } 99 }) 100 101 t.Run("UnmarshalPrivateKey", func(t *testing.T) { 102 tests := []struct { 103 name string 104 in []byte 105 want PrivateKey 106 err error 107 }{ 108 { 109 name: "valid key", 110 in: privBytes, 111 want: priv, 112 }, 113 { 114 name: "nil bytes", 115 err: crypto.ErrInvalidKey, 116 }, 117 { 118 name: "non-key data", 119 in: data, 120 err: crypto.ErrInvalidKey, 121 }, 122 { 123 name: "public key", 124 in: pubBytes, 125 err: crypto.ErrInvalidKey, 126 }, 127 } 128 129 for _, test := range tests { 130 got, err := i.UnmarshalPrivateKey(test.in) 131 if !errors.Is(err, test.err) { 132 t.Errorf("%s should return error %q got %q", test.name, test.err, err) 133 } else if diff := cmp.Diff(test.want, got, opt.cmp); diff != "" { 134 t.Error(test.name, diff) 135 } 136 } 137 }) 138 139 t.Run("MarshalPublicKey", func(t *testing.T) { 140 tests := []struct { 141 name string 142 in PublicKey 143 want []byte 144 err error 145 }{ 146 { 147 name: "valid key", 148 in: pub, 149 want: pubBytes, 150 }, 151 { 152 name: "nil key", 153 err: crypto.ErrInvalidKey, 154 }, 155 } 156 157 for _, test := range tests { 158 got, err := i.MarshalPublicKey(test.in) 159 if !errors.Is(err, test.err) { 160 t.Errorf("%s should return error %q got %q", test.name, test.err, err) 161 } else if diff := cmp.Diff(string(test.want), string(got), opt.cmp); diff != "" { 162 t.Error(test.name, diff) 163 } 164 } 165 }) 166 167 t.Run("UnmarshalPublicKey", func(t *testing.T) { 168 tests := []struct { 169 name string 170 in []byte 171 want PublicKey 172 err error 173 }{ 174 { 175 name: "valid key", 176 in: pubBytes, 177 want: pub, 178 }, 179 { 180 name: "nil bytes", 181 err: crypto.ErrInvalidKey, 182 }, 183 { 184 name: "non-key data", 185 in: data, 186 err: crypto.ErrInvalidKey, 187 }, 188 { 189 name: "private key", 190 in: privBytes, 191 err: crypto.ErrInvalidKey, 192 }, 193 } 194 195 for _, test := range tests { 196 got, err := i.UnmarshalPublicKey(test.in) 197 if !errors.Is(err, test.err) { 198 t.Errorf("%s should return error %q got %q", test.name, test.err, err) 199 } else if diff := cmp.Diff(test.want, got, opt.cmp); diff != "" { 200 t.Errorf("%s\n%s", test.name, diff) 201 } 202 } 203 }) 204 205 t.Run("Sign", func(t *testing.T) { 206 tests := []struct { 207 name string 208 key PrivateKey 209 data []byte 210 err error 211 }{ 212 { 213 name: "nil key", 214 data: data, 215 }, 216 } 217 218 for _, test := range tests { 219 _, err := i.Sign(test.key, data) 220 if err == nil { 221 t.Error(test.name, "should error") 222 } 223 } 224 }) 225 226 t.Run("Verify", func(t *testing.T) { 227 wrongPriv, _ := i.NewPrivateKey() 228 wrongPub := i.Public(wrongPriv) 229 wrongSig, _ := i.Sign(priv, []byte("wrong data")) 230 231 tests := []struct { 232 name string 233 key PublicKey 234 data []byte 235 sig []byte 236 want bool 237 }{ 238 { 239 name: "valid key", 240 key: pub, 241 data: data, 242 sig: sig, 243 want: true, 244 }, 245 { 246 name: "nil key", 247 data: data, 248 sig: sig, 249 }, 250 { 251 name: "wrong key", 252 key: wrongPub, 253 data: data, 254 sig: sig, 255 }, 256 { 257 name: "nil data", 258 key: pub, 259 sig: sig, 260 }, 261 { 262 name: "nil signature", 263 key: pub, 264 data: data, 265 }, 266 { 267 name: "wrong signature", 268 key: pub, 269 data: data, 270 sig: wrongSig, 271 }, 272 } 273 274 for _, test := range tests { 275 ok := i.Verify(test.key, test.data, test.sig) 276 if ok != test.want { 277 t.Errorf("%s should return %t got %t", test.name, test.want, ok) 278 } 279 } 280 }) 281 282 if opt.opensslArgs != nil { 283 t.Run("OpenSSL", func(t *testing.T) { 284 if _, err := exec.LookPath("openssl"); err != nil { 285 t.Skip("openssl not in path") 286 } 287 288 dir := t.TempDir() 289 _ = os.WriteFile(filepath.Join(dir, "public.pem"), pubBytes, 0o600) 290 _ = os.WriteFile(filepath.Join(dir, "data.txt"), data, 0o600) 291 _ = os.WriteFile(filepath.Join(dir, "data.txt.sig"), sig, 0o600) 292 293 cmd := exec.Command("openssl", opt.opensslArgs...) 294 cmd.Dir = dir 295 296 out, err := cmd.CombinedOutput() 297 t.Log(string(bytes.TrimSpace(out))) 298 if err != nil { 299 t.Fatal(err) 300 } 301 }) 302 } 303 }