github.com/lestrrat-go/jwx/v2@v2.0.21/internal/jwxtest/jwxtest.go (about)

     1  package jwxtest
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"crypto/ed25519"
     8  	"crypto/elliptic"
     9  	"crypto/rand"
    10  	"crypto/rsa"
    11  	"encoding/json"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/lestrrat-go/jwx/v2/internal/ecutil"
    19  	"github.com/lestrrat-go/jwx/v2/jwa"
    20  	"github.com/lestrrat-go/jwx/v2/jwe"
    21  	"github.com/lestrrat-go/jwx/v2/jwk"
    22  	"github.com/lestrrat-go/jwx/v2/jws"
    23  	"github.com/lestrrat-go/jwx/v2/x25519"
    24  	"github.com/stretchr/testify/assert"
    25  )
    26  
    27  func GenerateRsaKey() (*rsa.PrivateKey, error) {
    28  	return rsa.GenerateKey(rand.Reader, 2048)
    29  }
    30  
    31  func GenerateRsaJwk() (jwk.Key, error) {
    32  	key, err := GenerateRsaKey()
    33  	if err != nil {
    34  		return nil, fmt.Errorf(`failed to generate RSA private key: %w`, err)
    35  	}
    36  
    37  	k, err := jwk.FromRaw(key)
    38  	if err != nil {
    39  		return nil, fmt.Errorf(`failed to generate jwk.RSAPrivateKey: %w`, err)
    40  	}
    41  
    42  	return k, nil
    43  }
    44  
    45  func GenerateRsaPublicJwk() (jwk.Key, error) {
    46  	key, err := GenerateRsaJwk()
    47  	if err != nil {
    48  		return nil, fmt.Errorf(`failed to generate jwk.RSAPrivateKey: %w`, err)
    49  	}
    50  
    51  	return jwk.PublicKeyOf(key)
    52  }
    53  
    54  func GenerateEcdsaKey(alg jwa.EllipticCurveAlgorithm) (*ecdsa.PrivateKey, error) {
    55  	var crv elliptic.Curve
    56  	if tmp, ok := ecutil.CurveForAlgorithm(alg); ok {
    57  		crv = tmp
    58  	} else {
    59  		return nil, fmt.Errorf(`invalid curve algorithm %s`, alg)
    60  	}
    61  
    62  	return ecdsa.GenerateKey(crv, rand.Reader)
    63  }
    64  
    65  func GenerateEcdsaJwk() (jwk.Key, error) {
    66  	key, err := GenerateEcdsaKey(jwa.P521)
    67  	if err != nil {
    68  		return nil, fmt.Errorf(`failed to generate ECDSA private key: %w`, err)
    69  	}
    70  
    71  	k, err := jwk.FromRaw(key)
    72  	if err != nil {
    73  		return nil, fmt.Errorf(`failed to generate jwk.ECDSAPrivateKey: %w`, err)
    74  	}
    75  
    76  	return k, nil
    77  }
    78  
    79  func GenerateEcdsaPublicJwk() (jwk.Key, error) {
    80  	key, err := GenerateEcdsaJwk()
    81  	if err != nil {
    82  		return nil, fmt.Errorf(`failed to generate jwk.ECDSAPrivateKey: %w`, err)
    83  	}
    84  
    85  	return jwk.PublicKeyOf(key)
    86  }
    87  
    88  func GenerateSymmetricKey() []byte {
    89  	sharedKey := make([]byte, 64)
    90  	rand.Read(sharedKey)
    91  	return sharedKey
    92  }
    93  
    94  func GenerateSymmetricJwk() (jwk.Key, error) {
    95  	key, err := jwk.FromRaw(GenerateSymmetricKey())
    96  	if err != nil {
    97  		return nil, fmt.Errorf(`failed to generate jwk.SymmetricKey: %w`, err)
    98  	}
    99  
   100  	return key, nil
   101  }
   102  
   103  func GenerateEd25519Key() (ed25519.PrivateKey, error) {
   104  	_, priv, err := ed25519.GenerateKey(rand.Reader)
   105  	return priv, err
   106  }
   107  
   108  func GenerateEd25519Jwk() (jwk.Key, error) {
   109  	key, err := GenerateEd25519Key()
   110  	if err != nil {
   111  		return nil, fmt.Errorf(`failed to generate Ed25519 private key: %w`, err)
   112  	}
   113  
   114  	k, err := jwk.FromRaw(key)
   115  	if err != nil {
   116  		return nil, fmt.Errorf(`failed to generate jwk.OKPPrivateKey: %w`, err)
   117  	}
   118  
   119  	return k, nil
   120  }
   121  
   122  func GenerateX25519Key() (x25519.PrivateKey, error) {
   123  	_, priv, err := x25519.GenerateKey(rand.Reader)
   124  	return priv, err
   125  }
   126  
   127  func GenerateX25519Jwk() (jwk.Key, error) {
   128  	key, err := GenerateX25519Key()
   129  	if err != nil {
   130  		return nil, fmt.Errorf(`failed to generate X25519 private key: %w`, err)
   131  	}
   132  
   133  	k, err := jwk.FromRaw(key)
   134  	if err != nil {
   135  		return nil, fmt.Errorf(`failed to generate jwk.OKPPrivateKey: %w`, err)
   136  	}
   137  
   138  	return k, nil
   139  }
   140  
   141  func WriteFile(template string, src io.Reader) (string, func(), error) {
   142  	file, cleanup, err := CreateTempFile(template)
   143  	if err != nil {
   144  		return "", nil, fmt.Errorf(`failed to create temporary file: %w`, err)
   145  	}
   146  
   147  	if _, err := io.Copy(file, src); err != nil {
   148  		defer cleanup()
   149  		return "", nil, fmt.Errorf(`failed to copy content to temporary file: %w`, err)
   150  	}
   151  
   152  	if err := file.Sync(); err != nil {
   153  		defer cleanup()
   154  		return "", nil, fmt.Errorf(`failed to sync file: %w`, err)
   155  	}
   156  	return file.Name(), cleanup, nil
   157  }
   158  
   159  func WriteJSONFile(template string, v interface{}) (string, func(), error) {
   160  	var buf bytes.Buffer
   161  
   162  	enc := json.NewEncoder(&buf)
   163  	if err := enc.Encode(v); err != nil {
   164  		return "", nil, fmt.Errorf(`failed to encode object to JSON: %w`, err)
   165  	}
   166  	return WriteFile(template, &buf)
   167  }
   168  
   169  func DumpFile(t *testing.T, file string) {
   170  	buf, err := os.ReadFile(file)
   171  	if !assert.NoError(t, err, `failed to read file %s for debugging`, file) {
   172  		return
   173  	}
   174  
   175  	if isHash, isArray := bytes.ContainsRune(buf, '{'), bytes.ContainsRune(buf, '['); isHash || isArray {
   176  		// Looks like a JSON-like thing. Dump that in a formatted manner, and
   177  		// be done with it
   178  
   179  		var v interface{}
   180  		if isHash {
   181  			v = map[string]interface{}{}
   182  		} else {
   183  			v = []interface{}{}
   184  		}
   185  
   186  		if !assert.NoError(t, json.Unmarshal(buf, &v), `failed to parse contents as JSON`) {
   187  			return
   188  		}
   189  
   190  		buf, _ = json.MarshalIndent(v, "", "  ")
   191  		t.Logf("=== BEGIN %s (formatted JSON) ===", file)
   192  		t.Logf("%s", buf)
   193  		t.Logf("=== END   %s (formatted JSON) ===", file)
   194  		return
   195  	}
   196  
   197  	// If the contents do not look like JSON, then we attempt to parse each content
   198  	// based on heuristics (from its file name) and do our best
   199  	t.Logf("=== BEGIN %s (raw) ===", file)
   200  	t.Logf("%s", buf)
   201  	t.Logf("=== END   %s (raw) ===", file)
   202  
   203  	if strings.HasSuffix(file, ".jwe") {
   204  		// cross our fingers our jwe implementation works
   205  		m, err := jwe.Parse(buf)
   206  		if !assert.NoError(t, err, `failed to parse JWE encrypted message`) {
   207  			return
   208  		}
   209  
   210  		buf, _ = json.MarshalIndent(m, "", "  ")
   211  	}
   212  
   213  	t.Logf("=== BEGIN %s (formatted JSON) ===", file)
   214  	t.Logf("%s", buf)
   215  	t.Logf("=== END   %s (formatted JSON) ===", file)
   216  }
   217  
   218  func CreateTempFile(template string) (*os.File, func(), error) {
   219  	file, err := os.CreateTemp("", template)
   220  	if err != nil {
   221  		return nil, nil, fmt.Errorf(`failed to create temporary file: %w`, err)
   222  	}
   223  
   224  	cleanup := func() {
   225  		file.Close()
   226  		os.Remove(file.Name())
   227  	}
   228  
   229  	return file, cleanup, nil
   230  }
   231  
   232  func ReadFile(file string) ([]byte, error) {
   233  	f, err := os.Open(file)
   234  	if err != nil {
   235  		return nil, fmt.Errorf(`failed to open file %s: %w`, file, err)
   236  	}
   237  	defer f.Close()
   238  
   239  	buf, err := io.ReadAll(f)
   240  	if err != nil {
   241  		return nil, fmt.Errorf(`failed to read from key file %s: %w`, file, err)
   242  	}
   243  
   244  	return buf, nil
   245  }
   246  
   247  func ParseJwkFile(_ context.Context, file string) (jwk.Key, error) {
   248  	buf, err := ReadFile(file)
   249  	if err != nil {
   250  		return nil, fmt.Errorf(`failed to read from key file %s: %w`, file, err)
   251  	}
   252  
   253  	key, err := jwk.ParseKey(buf)
   254  	if err != nil {
   255  		return nil, fmt.Errorf(`filed to parse JWK in key file %s: %w`, file, err)
   256  	}
   257  
   258  	return key, nil
   259  }
   260  
   261  func DecryptJweFile(ctx context.Context, file string, alg jwa.KeyEncryptionAlgorithm, jwkfile string) ([]byte, error) {
   262  	key, err := ParseJwkFile(ctx, jwkfile)
   263  	if err != nil {
   264  		return nil, fmt.Errorf(`failed to parse keyfile %s: %w`, file, err)
   265  	}
   266  
   267  	buf, err := ReadFile(file)
   268  	if err != nil {
   269  		return nil, fmt.Errorf(`failed to read from encrypted file %s: %w`, file, err)
   270  	}
   271  
   272  	var rawkey interface{}
   273  	if err := key.Raw(&rawkey); err != nil {
   274  		return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err)
   275  	}
   276  
   277  	return jwe.Decrypt(buf, jwe.WithKey(alg, rawkey))
   278  }
   279  
   280  func EncryptJweFile(ctx context.Context, payload []byte, keyalg jwa.KeyEncryptionAlgorithm, keyfile string, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm) (string, func(), error) {
   281  	key, err := ParseJwkFile(ctx, keyfile)
   282  	if err != nil {
   283  		return "", nil, fmt.Errorf(`failed to parse keyfile %s: %w`, keyfile, err)
   284  	}
   285  
   286  	var keyif interface{}
   287  
   288  	switch keyalg {
   289  	case jwa.RSA1_5, jwa.RSA_OAEP, jwa.RSA_OAEP_256:
   290  		var rawkey rsa.PrivateKey
   291  		if err := key.Raw(&rawkey); err != nil {
   292  			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err)
   293  		}
   294  		keyif = rawkey.PublicKey
   295  	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   296  		var rawkey ecdsa.PrivateKey
   297  		if err := key.Raw(&rawkey); err != nil {
   298  			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err)
   299  		}
   300  		keyif = rawkey.PublicKey
   301  	default:
   302  		var rawkey []byte
   303  		if err := key.Raw(&rawkey); err != nil {
   304  			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, err)
   305  		}
   306  		keyif = rawkey
   307  	}
   308  
   309  	buf, err := jwe.Encrypt(payload, jwe.WithKey(keyalg, keyif), jwe.WithContentEncryption(contentalg), jwe.WithCompress(compressalg))
   310  	if err != nil {
   311  		return "", nil, fmt.Errorf(`failed to encrypt payload: %w`, err)
   312  	}
   313  
   314  	return WriteFile("jwx-test-*.jwe", bytes.NewReader(buf))
   315  }
   316  
   317  func VerifyJwsFile(ctx context.Context, file string, alg jwa.SignatureAlgorithm, jwkfile string) ([]byte, error) {
   318  	key, err := ParseJwkFile(ctx, jwkfile)
   319  	if err != nil {
   320  		return nil, fmt.Errorf(`failed to parse keyfile %s: %w`, file, err)
   321  	}
   322  
   323  	buf, err := ReadFile(file)
   324  	if err != nil {
   325  		return nil, fmt.Errorf(`failed to read from encrypted file %s: %w`, file, err)
   326  	}
   327  
   328  	var rawkey, pubkey interface{}
   329  	if err := key.Raw(&rawkey); err != nil {
   330  		return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err)
   331  	}
   332  	pubkey = rawkey
   333  	switch tkey := rawkey.(type) {
   334  	case *ecdsa.PrivateKey:
   335  		pubkey = tkey.PublicKey
   336  	case *rsa.PrivateKey:
   337  		pubkey = tkey.PublicKey
   338  	case *ed25519.PrivateKey:
   339  		pubkey = tkey.Public()
   340  	}
   341  
   342  	return jws.Verify(buf, jws.WithKey(alg, pubkey))
   343  }
   344  
   345  func SignJwsFile(ctx context.Context, payload []byte, alg jwa.SignatureAlgorithm, keyfile string) (string, func(), error) {
   346  	key, err := ParseJwkFile(ctx, keyfile)
   347  	if err != nil {
   348  		return "", nil, fmt.Errorf(`failed to parse keyfile %s: %w`, keyfile, err)
   349  	}
   350  
   351  	buf, err := jws.Sign(payload, jws.WithKey(alg, key))
   352  	if err != nil {
   353  		return "", nil, fmt.Errorf(`failed to sign payload: %w`, err)
   354  	}
   355  
   356  	return WriteFile("jwx-test-*.jws", bytes.NewReader(buf))
   357  }