github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/util/didsignjwt/signjwt.go (about)

     1  /*
     2  Copyright Avast Software. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package didsignjwt
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"strings"
    13  
    14  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    15  	"github.com/hyperledger/aries-framework-go/pkg/doc/jose"
    16  	"github.com/hyperledger/aries-framework-go/pkg/doc/jwt"
    17  	"github.com/hyperledger/aries-framework-go/pkg/doc/util/jwkkid"
    18  	"github.com/hyperledger/aries-framework-go/pkg/doc/util/vmparse"
    19  	"github.com/hyperledger/aries-framework-go/pkg/doc/verifiable"
    20  	"github.com/hyperledger/aries-framework-go/pkg/framework/aries/api/vdr"
    21  	"github.com/hyperledger/aries-framework-go/pkg/internal/kmssigner"
    22  )
    23  
const (
	// ed25519VerificationKey2018 holds the literal "Ed25519VerificationKey2018";
	// presumably a verification method type name.
	// NOTE(review): not referenced in the visible part of this file — confirm it is used elsewhere in the package.
	ed25519VerificationKey2018 = "Ed25519VerificationKey2018"

	// vmSectionCount is the maximum number of "#"-separated sections accepted in a
	// key ID: the DID itself plus an optional verification method fragment.
	vmSectionCount = 2
)
    30  
// keyReader is the key lookup capability this package needs from a key manager.
type keyReader interface {
	// Get returns the key handle for the given keyID.
	//
	// Returns:
	//  - handle instance (to private key)
	//  - error if failure
	Get(keyID string) (interface{}, error)
}
    38  
// didResolver resolves a DID to its document resolution result.
type didResolver interface {
	// Resolve looks up the given DID, applying any DID method options supplied,
	// and returns the resolution result or an error.
	Resolve(did string, opts ...vdr.DIDMethodOption) (*did.DocResolution, error)
}
    42  
// cryptoSigner is the signing capability this package needs from a crypto provider.
type cryptoSigner interface {
	// Sign signs msg using a matching signature primitive in the kh key handle of a private key.
	//
	// Returns:
	//  - the signature as []byte
	//  - an error if signing fails
	Sign(msg []byte, kh interface{}) ([]byte, error)
}
    50  
// A Signer is capable of signing data. Unlike cryptoSigner, it takes no key
// handle argument: the message is the only input to Sign.
type Signer interface {
	// Sign provides a signature for msg, or an error if signing fails.
	Sign(msg []byte) ([]byte, error)
}
    56  
// defaultSigner is the Signer produced by UseDefaultSigner. It pairs a key
// handle with the cryptoSigner that signs using that handle.
type defaultSigner struct {
	// keyHandle is the handle returned by keyReader.Get for the signing key.
	keyHandle interface{}
	// signer performs the actual signing with keyHandle.
	signer cryptoSigner
}
    61  
// SignerGetter creates a signer that signs with the private key corresponding to the given public key.
// The public key is supplied as a DID verification method; an error is returned
// when no Signer can be produced for it.
type SignerGetter func(vm *did.VerificationMethod) (Signer, error)
    64  
    65  // UseDefaultSigner provides SignJWT with a signer that uses the given KMS and Crypto instances.
    66  func UseDefaultSigner(r keyReader, s cryptoSigner) SignerGetter {
    67  	return func(vm *did.VerificationMethod) (Signer, error) {
    68  		pubKey, keyType, _, err := vmparse.VMToBytesTypeCrv(vm)
    69  		if err != nil {
    70  			return nil, fmt.Errorf("parsing verification method: %w", err)
    71  		}
    72  
    73  		kmsKID, err := jwkkid.CreateKID(pubKey, keyType)
    74  		if err != nil {
    75  			return nil, fmt.Errorf("determining the internal ID of the signing key: %w", err)
    76  		}
    77  
    78  		keyHandle, err := r.Get(kmsKID)
    79  		if err != nil {
    80  			return nil, fmt.Errorf("fetching the signing key from the key manager: %w", err)
    81  		}
    82  
    83  		return &defaultSigner{
    84  			keyHandle: keyHandle,
    85  			signer:    s,
    86  		}, nil
    87  	}
    88  }
    89  
// Sign signs the given message using the key this signer holds a reference to.
// It delegates to the wrapped cryptoSigner, passing the stored key handle, and
// returns that call's signature and error unchanged.
func (s *defaultSigner) Sign(msg []byte) ([]byte, error) {
	return s.signer.Sign(msg, s.keyHandle)
}
    94  
    95  // SignJWT signs a JWT using a key in the given KMS, identified by an owned DID.
    96  //
    97  //	Args:
    98  //		- Headers to include in the created JWT.
    99  //		- Claims for the created JWT.
   100  //		- The ID of the key to use for signing, as a DID, either with a fragment identifier to specify a verification
   101  //		  method, or without, in which case the first Authentication or Assertion verification method is used.
   102  //		- A SignerGetter that can provide a signer when given the key ID for the signing key.
   103  //		- A VDR that can resolve the provided DID.
   104  func SignJWT( // nolint: funlen,gocyclo
   105  	headers,
   106  	claims map[string]interface{},
   107  	kid string,
   108  	signerProvider SignerGetter,
   109  	didResolver didResolver,
   110  ) (string, error) {
   111  	vm, vmID, err := ResolveSigningVM(kid, didResolver)
   112  	if err != nil {
   113  		return "", err
   114  	}
   115  
   116  	keyType, crv, err := vmparse.VMToTypeCrv(vm)
   117  	if err != nil {
   118  		return "", fmt.Errorf("parsing verification method: %w", err)
   119  	}
   120  
   121  	ss, err := signerProvider(vm)
   122  	if err != nil {
   123  		return "", err
   124  	}
   125  
   126  	if headers == nil {
   127  		headers = map[string]interface{}{}
   128  	}
   129  
   130  	if claims == nil {
   131  		claims = map[string]interface{}{}
   132  	}
   133  
   134  	headers[jose.HeaderType] = "JWT"
   135  	headers[jose.HeaderAlgorithm] = kmssigner.KeyTypeToJWA(keyType)
   136  	headers["crv"] = crv
   137  	headers[jose.HeaderKeyID] = vmID
   138  
   139  	tok, err := jwt.NewSigned(claims, headers, getJWTSigner(ss, kmssigner.KeyTypeToJWA(keyType)))
   140  	if err != nil {
   141  		return "", fmt.Errorf("signing JWT: %w", err)
   142  	}
   143  
   144  	compact, err := tok.Serialize(false)
   145  	if err != nil {
   146  		return "", fmt.Errorf("serializing JWT: %w", err)
   147  	}
   148  
   149  	return compact, nil
   150  }
   151  
   152  // VerifyJWT verifies a JWT that was signed with a DID.
   153  //
   154  // Args:
   155  //   - JWT to verify.
   156  //   - A VDR that can resolve the JWT's signing DID.
   157  func VerifyJWT(compactJWT string,
   158  	didResolver didResolver) error {
   159  	_, _, err := jwt.Parse(compactJWT, jwt.WithSignatureVerifier(jwt.NewVerifier(
   160  		jwt.KeyResolverFunc(verifiable.NewVDRKeyResolver(didResolver).PublicKeyFetcher())),
   161  	))
   162  	if err != nil {
   163  		return fmt.Errorf("jwt verification failed: %w", err)
   164  	}
   165  
   166  	return nil
   167  }
   168  
   169  // ResolveSigningVM resolves a DID KeyID using the given did resolver, and returns either:
   170  //
   171  //   - the Verification Method identified by the given key ID, or
   172  //   - the first Assertion Method in the DID doc, if the DID provided has no fragment component.
   173  //
   174  // Returns:
   175  //   - a verification method suitable for signing.
   176  //   - the full DID#KID identifier of the returned verification method.
   177  func ResolveSigningVM(kid string, didResolver didResolver) (*did.VerificationMethod, string, error) {
   178  	vmSplit := strings.Split(kid, "#")
   179  
   180  	if len(vmSplit) > vmSectionCount {
   181  		return nil, "", errors.New("invalid verification method format")
   182  	}
   183  
   184  	signingDID := vmSplit[0]
   185  
   186  	docRes, err := didResolver.Resolve(signingDID)
   187  	if err != nil {
   188  		return nil, "", fmt.Errorf("failed to resolve signing DID: %w", err)
   189  	}
   190  
   191  	if len(vmSplit) == 1 {
   192  		// look for assertionmethod
   193  		verificationMethods := docRes.DIDDocument.VerificationMethods(did.AssertionMethod)
   194  
   195  		if len(verificationMethods[did.AssertionMethod]) > 0 {
   196  			vm := verificationMethods[did.AssertionMethod][0].VerificationMethod
   197  
   198  			return &vm, fullVMID(signingDID, vm.ID), nil
   199  		}
   200  
   201  		return nil, "", fmt.Errorf("DID provided has no assertion method to use as a default signing key")
   202  	}
   203  
   204  	vmID := vmSplit[vmSectionCount-1]
   205  
   206  	for _, verifications := range docRes.DIDDocument.VerificationMethods() {
   207  		for _, verification := range verifications {
   208  			if isSigningKey(verification.Relationship) && vmIDFragmentOnly(verification.VerificationMethod.ID) == vmID {
   209  				vm := verification.VerificationMethod
   210  				return &vm, kid, nil
   211  			}
   212  		}
   213  	}
   214  
   215  	return nil, "", fmt.Errorf("did document has no verification method with given ID")
   216  }
   217  
   218  func fullVMID(did, vmID string) string {
   219  	vmIDSplit := strings.Split(vmID, "#")
   220  
   221  	if len(vmIDSplit) == 1 {
   222  		return did + "#" + vmIDSplit[0]
   223  	} else if len(vmIDSplit[0]) == 0 {
   224  		return did + "#" + vmIDSplit[1]
   225  	}
   226  
   227  	return vmID
   228  }
   229  
   230  func vmIDFragmentOnly(vmID string) string {
   231  	vmSplit := strings.Split(vmID, "#")
   232  	if len(vmSplit) == 1 {
   233  		return vmSplit[0]
   234  	}
   235  
   236  	return vmSplit[1]
   237  }
   238  
   239  func isSigningKey(vr did.VerificationRelationship) bool {
   240  	switch vr {
   241  	case did.AssertionMethod, did.Authentication, did.VerificationRelationshipGeneral:
   242  		return true
   243  	}
   244  
   245  	return false
   246  }
   247  
// sign is the minimal signing capability wrapped by jwtSigner.
type sign interface {
	// Sign returns a signature over data, or an error if signing fails.
	Sign(data []byte) ([]byte, error)
}
   251  
// jwtSigner implements the jose.Signer interface by pairing a signer with the
// headers it reports.
type jwtSigner struct {
	// signer produces the signatures.
	signer sign
	// headers is the header set returned by Headers; getJWTSigner fills in the algorithm entry.
	headers map[string]interface{}
}
   257  
   258  func getJWTSigner(signer sign, algorithm string) *jwtSigner {
   259  	headers := map[string]interface{}{
   260  		jose.HeaderAlgorithm: algorithm,
   261  	}
   262  
   263  	return &jwtSigner{signer: signer, headers: headers}
   264  }
   265  
// Sign signs data by delegating to the wrapped signer and returns its
// signature and error unchanged.
func (s jwtSigner) Sign(data []byte) ([]byte, error) {
	return s.signer.Sign(data)
}
   269  
// Headers returns the headers stored on this signer as jose.Headers.
// The stored map itself is returned, not a copy.
func (s jwtSigner) Headers() jose.Headers {
	return s.headers
}