github.com/lestrrat-go/jwx/v2@v2.0.21/jws/message.go (about)

     1  package jws
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  
     8  	"github.com/lestrrat-go/jwx/v2/internal/base64"
     9  	"github.com/lestrrat-go/jwx/v2/internal/json"
    10  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    11  	"github.com/lestrrat-go/jwx/v2/jwk"
    12  )
    13  
    14  func NewSignature() *Signature {
    15  	return &Signature{}
    16  }
    17  
    18  func (s *Signature) DecodeCtx() DecodeCtx {
    19  	return s.dc
    20  }
    21  
    22  func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
    23  	s.dc = dc
    24  }
    25  
    26  func (s Signature) PublicHeaders() Headers {
    27  	return s.headers
    28  }
    29  
    30  func (s *Signature) SetPublicHeaders(v Headers) *Signature {
    31  	s.headers = v
    32  	return s
    33  }
    34  
    35  func (s Signature) ProtectedHeaders() Headers {
    36  	return s.protected
    37  }
    38  
    39  func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
    40  	s.protected = v
    41  	return s
    42  }
    43  
    44  func (s Signature) Signature() []byte {
    45  	return s.signature
    46  }
    47  
    48  func (s *Signature) SetSignature(v []byte) *Signature {
    49  	s.signature = v
    50  	return s
    51  }
    52  
    53  type signatureUnmarshalProbe struct {
    54  	Header    Headers `json:"header,omitempty"`
    55  	Protected *string `json:"protected,omitempty"`
    56  	Signature *string `json:"signature,omitempty"`
    57  }
    58  
    59  func (s *Signature) UnmarshalJSON(data []byte) error {
    60  	var sup signatureUnmarshalProbe
    61  	sup.Header = NewHeaders()
    62  	if err := json.Unmarshal(data, &sup); err != nil {
    63  		return fmt.Errorf(`failed to unmarshal signature into temporary struct: %w`, err)
    64  	}
    65  
    66  	s.headers = sup.Header
    67  	if buf := sup.Protected; buf != nil {
    68  		src := []byte(*buf)
    69  		if !bytes.HasPrefix(src, []byte{'{'}) {
    70  			decoded, err := base64.Decode(src)
    71  			if err != nil {
    72  				return fmt.Errorf(`failed to base64 decode protected headers: %w`, err)
    73  			}
    74  			src = decoded
    75  		}
    76  
    77  		prt := NewHeaders()
    78  		//nolint:forcetypeassert
    79  		prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
    80  		if err := json.Unmarshal(src, prt); err != nil {
    81  			return fmt.Errorf(`failed to unmarshal protected headers: %w`, err)
    82  		}
    83  		//nolint:forcetypeassert
    84  		prt.(*stdHeaders).SetDecodeCtx(nil)
    85  		s.protected = prt
    86  	}
    87  
    88  	if sup.Signature != nil {
    89  		decoded, err := base64.DecodeString(*sup.Signature)
    90  		if err != nil {
    91  			return fmt.Errorf(`failed to base decode signature: %w`, err)
    92  		}
    93  		s.signature = decoded
    94  	}
    95  	return nil
    96  }
    97  
    98  // Sign populates the signature field, with a signature generated by
    99  // given the signer object and payload.
   100  //
   101  // The first return value is the raw signature in binary format.
   102  // The second return value s the full three-segment signature
   103  // (e.g. "eyXXXX.XXXXX.XXXX")
   104  func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
   105  	ctx, cancel := context.WithCancel(context.Background())
   106  	defer cancel()
   107  
   108  	hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
   109  	if err != nil {
   110  		return nil, nil, fmt.Errorf(`failed to merge headers: %w`, err)
   111  	}
   112  
   113  	if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
   114  		return nil, nil, fmt.Errorf(`failed to set "alg": %w`, err)
   115  	}
   116  
   117  	// If the key is a jwk.Key instance, obtain the raw key
   118  	if jwkKey, ok := key.(jwk.Key); ok {
   119  		// If we have a key ID specified by this jwk.Key, use that in the header
   120  		if kid := jwkKey.KeyID(); kid != "" {
   121  			if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
   122  				return nil, nil, fmt.Errorf(`set key ID from jwk.Key: %w`, err)
   123  			}
   124  		}
   125  	}
   126  	hdrbuf, err := json.Marshal(hdrs)
   127  	if err != nil {
   128  		return nil, nil, fmt.Errorf(`failed to marshal headers: %w`, err)
   129  	}
   130  
   131  	buf := pool.GetBytesBuffer()
   132  	defer pool.ReleaseBytesBuffer(buf)
   133  
   134  	buf.WriteString(base64.EncodeToString(hdrbuf))
   135  	buf.WriteByte('.')
   136  
   137  	var plen int
   138  	b64 := getB64Value(hdrs)
   139  	if b64 {
   140  		encoded := base64.EncodeToString(payload)
   141  		plen = len(encoded)
   142  		buf.WriteString(encoded)
   143  	} else {
   144  		if !s.detached {
   145  			if bytes.Contains(payload, []byte{'.'}) {
   146  				return nil, nil, fmt.Errorf(`payload must not contain a "."`)
   147  			}
   148  		}
   149  		plen = len(payload)
   150  		buf.Write(payload)
   151  	}
   152  
   153  	signature, err := signer.Sign(buf.Bytes(), key)
   154  	if err != nil {
   155  		return nil, nil, fmt.Errorf(`failed to sign payload: %w`, err)
   156  	}
   157  	s.signature = signature
   158  
   159  	// Detached payload, this should be removed from the end result
   160  	if s.detached {
   161  		buf.Truncate(buf.Len() - plen)
   162  	}
   163  
   164  	buf.WriteByte('.')
   165  	buf.WriteString(base64.EncodeToString(signature))
   166  	ret := make([]byte, buf.Len())
   167  	copy(ret, buf.Bytes())
   168  
   169  	return signature, ret, nil
   170  }
   171  
   172  func NewMessage() *Message {
   173  	return &Message{}
   174  }
   175  
   176  // Clears the internal raw buffer that was accumulated during
   177  // the verify phase
   178  func (m *Message) clearRaw() {
   179  	for _, sig := range m.signatures {
   180  		if protected := sig.protected; protected != nil {
   181  			if cr, ok := protected.(*stdHeaders); ok {
   182  				cr.raw = nil
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  func (m *Message) SetDecodeCtx(dc DecodeCtx) {
   189  	m.dc = dc
   190  }
   191  
   192  func (m *Message) DecodeCtx() DecodeCtx {
   193  	return m.dc
   194  }
   195  
   196  // Payload returns the decoded payload
   197  func (m Message) Payload() []byte {
   198  	return m.payload
   199  }
   200  
   201  func (m *Message) SetPayload(v []byte) *Message {
   202  	m.payload = v
   203  	return m
   204  }
   205  
   206  func (m Message) Signatures() []*Signature {
   207  	return m.signatures
   208  }
   209  
   210  func (m *Message) AppendSignature(v *Signature) *Message {
   211  	m.signatures = append(m.signatures, v)
   212  	return m
   213  }
   214  
   215  func (m *Message) ClearSignatures() *Message {
   216  	m.signatures = nil
   217  	return m
   218  }
   219  
   220  // LookupSignature looks up a particular signature entry using
   221  // the `kid` value
   222  func (m Message) LookupSignature(kid string) []*Signature {
   223  	var sigs []*Signature
   224  	for _, sig := range m.signatures {
   225  		if hdr := sig.PublicHeaders(); hdr != nil {
   226  			hdrKeyID := hdr.KeyID()
   227  			if hdrKeyID == kid {
   228  				sigs = append(sigs, sig)
   229  				continue
   230  			}
   231  		}
   232  
   233  		if hdr := sig.ProtectedHeaders(); hdr != nil {
   234  			hdrKeyID := hdr.KeyID()
   235  			if hdrKeyID == kid {
   236  				sigs = append(sigs, sig)
   237  				continue
   238  			}
   239  		}
   240  	}
   241  	return sigs
   242  }
   243  
   244  // This struct is used to first probe for the structure of the
   245  // incoming JSON object. We then decide how to parse it
   246  // from the fields that are populated.
   247  type messageUnmarshalProbe struct {
   248  	Payload    *string           `json:"payload"`
   249  	Signatures []json.RawMessage `json:"signatures,omitempty"`
   250  	Header     Headers           `json:"header,omitempty"`
   251  	Protected  *string           `json:"protected,omitempty"`
   252  	Signature  *string           `json:"signature,omitempty"`
   253  }
   254  
   255  func (m *Message) UnmarshalJSON(buf []byte) error {
   256  	m.payload = nil
   257  	m.signatures = nil
   258  	m.b64 = true
   259  
   260  	var mup messageUnmarshalProbe
   261  	mup.Header = NewHeaders()
   262  	if err := json.Unmarshal(buf, &mup); err != nil {
   263  		return fmt.Errorf(`failed to unmarshal into temporary structure: %w`, err)
   264  	}
   265  
   266  	b64 := true
   267  	if mup.Signature == nil { // flattened signature is NOT present
   268  		if len(mup.Signatures) == 0 {
   269  			return fmt.Errorf(`required field "signatures" not present`)
   270  		}
   271  
   272  		m.signatures = make([]*Signature, 0, len(mup.Signatures))
   273  		for i, rawsig := range mup.Signatures {
   274  			var sig Signature
   275  			sig.SetDecodeCtx(m.DecodeCtx())
   276  			if err := json.Unmarshal(rawsig, &sig); err != nil {
   277  				return fmt.Errorf(`failed to unmarshal signature #%d: %w`, i+1, err)
   278  			}
   279  			sig.SetDecodeCtx(nil)
   280  
   281  			if sig.protected == nil {
   282  				// Instead of barfing on a nil protected header, use an empty header
   283  				sig.protected = NewHeaders()
   284  			}
   285  
   286  			if i == 0 {
   287  				if !getB64Value(sig.protected) {
   288  					b64 = false
   289  				}
   290  			} else {
   291  				if b64 != getB64Value(sig.protected) {
   292  					return fmt.Errorf(`b64 value must be the same for all signatures`)
   293  				}
   294  			}
   295  
   296  			m.signatures = append(m.signatures, &sig)
   297  		}
   298  	} else { // .signature is present, it's a flattened structure
   299  		if len(mup.Signatures) != 0 {
   300  			return fmt.Errorf(`invalid format ("signatures" and "signature" keys cannot both be present)`)
   301  		}
   302  
   303  		var sig Signature
   304  		sig.headers = mup.Header
   305  		if src := mup.Protected; src != nil {
   306  			decoded, err := base64.DecodeString(*src)
   307  			if err != nil {
   308  				return fmt.Errorf(`failed to base64 decode flattened protected headers: %w`, err)
   309  			}
   310  			prt := NewHeaders()
   311  			//nolint:forcetypeassert
   312  			prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
   313  			if err := json.Unmarshal(decoded, prt); err != nil {
   314  				return fmt.Errorf(`failed to unmarshal flattened protected headers: %w`, err)
   315  			}
   316  			//nolint:forcetypeassert
   317  			prt.(*stdHeaders).SetDecodeCtx(nil)
   318  			sig.protected = prt
   319  		}
   320  
   321  		if sig.protected == nil {
   322  			// Instead of barfing on a nil protected header, use an empty header
   323  			sig.protected = NewHeaders()
   324  		}
   325  
   326  		decoded, err := base64.DecodeString(*mup.Signature)
   327  		if err != nil {
   328  			return fmt.Errorf(`failed to base64 decode flattened signature: %w`, err)
   329  		}
   330  		sig.signature = decoded
   331  
   332  		m.signatures = []*Signature{&sig}
   333  		b64 = getB64Value(sig.protected)
   334  	}
   335  
   336  	if mup.Payload != nil {
   337  		if !b64 { // NOT base64 encoded
   338  			m.payload = []byte(*mup.Payload)
   339  		} else {
   340  			decoded, err := base64.DecodeString(*mup.Payload)
   341  			if err != nil {
   342  				return fmt.Errorf(`failed to base64 decode payload: %w`, err)
   343  			}
   344  			m.payload = decoded
   345  		}
   346  	}
   347  	m.b64 = b64
   348  	return nil
   349  }
   350  
   351  func (m Message) MarshalJSON() ([]byte, error) {
   352  	if len(m.signatures) == 1 {
   353  		return m.marshalFlattened()
   354  	}
   355  	return m.marshalFull()
   356  }
   357  
   358  func (m Message) marshalFlattened() ([]byte, error) {
   359  	buf := pool.GetBytesBuffer()
   360  	defer pool.ReleaseBytesBuffer(buf)
   361  
   362  	sig := m.signatures[0]
   363  
   364  	buf.WriteRune('{')
   365  	var wrote bool
   366  
   367  	if hdr := sig.headers; hdr != nil {
   368  		hdrjs, err := hdr.MarshalJSON()
   369  		if err != nil {
   370  			return nil, fmt.Errorf(`failed to marshal "header" (flattened format): %w`, err)
   371  		}
   372  		buf.WriteString(`"header":`)
   373  		buf.Write(hdrjs)
   374  		wrote = true
   375  	}
   376  
   377  	if wrote {
   378  		buf.WriteRune(',')
   379  	}
   380  	buf.WriteString(`"payload":"`)
   381  	buf.WriteString(base64.EncodeToString(m.payload))
   382  	buf.WriteRune('"')
   383  
   384  	if protected := sig.protected; protected != nil {
   385  		protectedbuf, err := protected.MarshalJSON()
   386  		if err != nil {
   387  			return nil, fmt.Errorf(`failed to marshal "protected" (flattened format): %w`, err)
   388  		}
   389  		buf.WriteString(`,"protected":"`)
   390  		buf.WriteString(base64.EncodeToString(protectedbuf))
   391  		buf.WriteRune('"')
   392  	}
   393  
   394  	buf.WriteString(`,"signature":"`)
   395  	buf.WriteString(base64.EncodeToString(sig.signature))
   396  	buf.WriteRune('"')
   397  	buf.WriteRune('}')
   398  
   399  	ret := make([]byte, buf.Len())
   400  	copy(ret, buf.Bytes())
   401  	return ret, nil
   402  }
   403  
   404  func (m Message) marshalFull() ([]byte, error) {
   405  	buf := pool.GetBytesBuffer()
   406  	defer pool.ReleaseBytesBuffer(buf)
   407  
   408  	buf.WriteString(`{"payload":"`)
   409  	buf.WriteString(base64.EncodeToString(m.payload))
   410  	buf.WriteString(`","signatures":[`)
   411  	for i, sig := range m.signatures {
   412  		if i > 0 {
   413  			buf.WriteRune(',')
   414  		}
   415  
   416  		buf.WriteRune('{')
   417  		var wrote bool
   418  		if hdr := sig.headers; hdr != nil {
   419  			hdrbuf, err := hdr.MarshalJSON()
   420  			if err != nil {
   421  				return nil, fmt.Errorf(`failed to marshal "header" for signature #%d: %w`, i+1, err)
   422  			}
   423  			buf.WriteString(`"header":`)
   424  			buf.Write(hdrbuf)
   425  			wrote = true
   426  		}
   427  
   428  		if protected := sig.protected; protected != nil {
   429  			protectedbuf, err := protected.MarshalJSON()
   430  			if err != nil {
   431  				return nil, fmt.Errorf(`failed to marshal "protected" for signature #%d: %w`, i+1, err)
   432  			}
   433  			if wrote {
   434  				buf.WriteRune(',')
   435  			}
   436  			buf.WriteString(`"protected":"`)
   437  			buf.WriteString(base64.EncodeToString(protectedbuf))
   438  			buf.WriteRune('"')
   439  			wrote = true
   440  		}
   441  
   442  		if len(sig.signature) > 0 {
   443  			// If InsecureNoSignature is enabled, signature may not exist
   444  			if wrote {
   445  				buf.WriteRune(',')
   446  			}
   447  			buf.WriteString(`"signature":"`)
   448  			buf.WriteString(base64.EncodeToString(sig.signature))
   449  			buf.WriteString(`"`)
   450  		}
   451  		buf.WriteString(`}`)
   452  	}
   453  	buf.WriteString(`]}`)
   454  
   455  	ret := make([]byte, buf.Len())
   456  	copy(ret, buf.Bytes())
   457  	return ret, nil
   458  }
   459  
   460  // Compact generates a JWS message in compact serialization format from
   461  // `*jws.Message` object. The object contain exactly one signature, or
   462  // an error is returned.
   463  //
   464  // If using a detached payload, the payload must already be stored in
   465  // the `*jws.Message` object, and the `jws.WithDetached()` option
   466  // must be passed to the function.
   467  func Compact(msg *Message, options ...CompactOption) ([]byte, error) {
   468  	if l := len(msg.signatures); l != 1 {
   469  		return nil, fmt.Errorf(`jws.Compact: cannot serialize message with %d signatures (must be one)`, l)
   470  	}
   471  
   472  	var detached bool
   473  	for _, option := range options {
   474  		//nolint:forcetypeassert
   475  		switch option.Ident() {
   476  		case identDetached{}:
   477  			detached = option.Value().(bool)
   478  		}
   479  	}
   480  
   481  	s := msg.signatures[0]
   482  	// XXX check if this is correct
   483  	hdrs := s.ProtectedHeaders()
   484  
   485  	hdrbuf, err := json.Marshal(hdrs)
   486  	if err != nil {
   487  		return nil, fmt.Errorf(`jws.Compress: failed to marshal headers: %w`, err)
   488  	}
   489  
   490  	buf := pool.GetBytesBuffer()
   491  	defer pool.ReleaseBytesBuffer(buf)
   492  
   493  	buf.WriteString(base64.EncodeToString(hdrbuf))
   494  	buf.WriteByte('.')
   495  
   496  	if !detached {
   497  		if getB64Value(hdrs) {
   498  			encoded := base64.EncodeToString(msg.payload)
   499  			buf.WriteString(encoded)
   500  		} else {
   501  			if bytes.Contains(msg.payload, []byte{'.'}) {
   502  				return nil, fmt.Errorf(`jws.Compress: payload must not contain a "."`)
   503  			}
   504  			buf.Write(msg.payload)
   505  		}
   506  	}
   507  
   508  	buf.WriteByte('.')
   509  	buf.WriteString(base64.EncodeToString(s.signature))
   510  	ret := make([]byte, buf.Len())
   511  	copy(ret, buf.Bytes())
   512  	return ret, nil
   513  }