github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/message.go (about)

     1  package jwe
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/lestrrat-go/jwx/v2/internal/base64"
    10  	"github.com/lestrrat-go/jwx/v2/internal/json"
    11  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    12  )
    13  
    14  // NewRecipient creates a Recipient object
    15  func NewRecipient() Recipient {
    16  	return &stdRecipient{
    17  		headers: NewHeaders(),
    18  	}
    19  }
    20  
    21  func (r *stdRecipient) SetHeaders(h Headers) error {
    22  	r.headers = h
    23  	return nil
    24  }
    25  
    26  func (r *stdRecipient) SetEncryptedKey(v []byte) error {
    27  	r.encryptedKey = v
    28  	return nil
    29  }
    30  
    31  func (r *stdRecipient) Headers() Headers {
    32  	return r.headers
    33  }
    34  
    35  func (r *stdRecipient) EncryptedKey() []byte {
    36  	return r.encryptedKey
    37  }
    38  
    39  type recipientMarshalProxy struct {
    40  	Headers      Headers `json:"header"`
    41  	EncryptedKey string  `json:"encrypted_key"`
    42  }
    43  
    44  func (r *stdRecipient) UnmarshalJSON(buf []byte) error {
    45  	var proxy recipientMarshalProxy
    46  	proxy.Headers = NewHeaders()
    47  	if err := json.Unmarshal(buf, &proxy); err != nil {
    48  		return fmt.Errorf(`failed to unmarshal json into recipient: %w`, err)
    49  	}
    50  
    51  	r.headers = proxy.Headers
    52  	decoded, err := base64.DecodeString(proxy.EncryptedKey)
    53  	if err != nil {
    54  		return fmt.Errorf(`failed to decode "encrypted_key": %w`, err)
    55  	}
    56  	r.encryptedKey = decoded
    57  	return nil
    58  }
    59  
    60  func (r *stdRecipient) MarshalJSON() ([]byte, error) {
    61  	buf := pool.GetBytesBuffer()
    62  	defer pool.ReleaseBytesBuffer(buf)
    63  
    64  	buf.WriteString(`{"header":`)
    65  	hdrbuf, err := r.headers.MarshalJSON()
    66  	if err != nil {
    67  		return nil, fmt.Errorf(`failed to marshal recipient header: %w`, err)
    68  	}
    69  	buf.Write(hdrbuf)
    70  	buf.WriteString(`,"encrypted_key":"`)
    71  	buf.WriteString(base64.EncodeToString(r.encryptedKey))
    72  	buf.WriteString(`"}`)
    73  
    74  	ret := make([]byte, buf.Len())
    75  	copy(ret, buf.Bytes())
    76  	return ret, nil
    77  }
    78  
    79  // NewMessage creates a new message
    80  func NewMessage() *Message {
    81  	return &Message{}
    82  }
    83  
    84  func (m *Message) AuthenticatedData() []byte {
    85  	return m.authenticatedData
    86  }
    87  
    88  func (m *Message) CipherText() []byte {
    89  	return m.cipherText
    90  }
    91  
    92  func (m *Message) InitializationVector() []byte {
    93  	return m.initializationVector
    94  }
    95  
    96  func (m *Message) Tag() []byte {
    97  	return m.tag
    98  }
    99  
   100  func (m *Message) ProtectedHeaders() Headers {
   101  	return m.protectedHeaders
   102  }
   103  
   104  func (m *Message) Recipients() []Recipient {
   105  	return m.recipients
   106  }
   107  
   108  func (m *Message) UnprotectedHeaders() Headers {
   109  	return m.unprotectedHeaders
   110  }
   111  
   112  const (
   113  	AuthenticatedDataKey    = "aad"
   114  	CipherTextKey           = "ciphertext"
   115  	CountKey                = "p2c"
   116  	InitializationVectorKey = "iv"
   117  	ProtectedHeadersKey     = "protected"
   118  	RecipientsKey           = "recipients"
   119  	SaltKey                 = "p2s"
   120  	TagKey                  = "tag"
   121  	UnprotectedHeadersKey   = "unprotected"
   122  	HeadersKey              = "header"
   123  	EncryptedKeyKey         = "encrypted_key"
   124  )
   125  
   126  func (m *Message) Set(k string, v interface{}) error {
   127  	switch k {
   128  	case AuthenticatedDataKey:
   129  		buf, ok := v.([]byte)
   130  		if !ok {
   131  			return fmt.Errorf(`invalid value %T for %s key`, v, AuthenticatedDataKey)
   132  		}
   133  		m.authenticatedData = buf
   134  	case CipherTextKey:
   135  		buf, ok := v.([]byte)
   136  		if !ok {
   137  			return fmt.Errorf(`invalid value %T for %s key`, v, CipherTextKey)
   138  		}
   139  		m.cipherText = buf
   140  	case InitializationVectorKey:
   141  		buf, ok := v.([]byte)
   142  		if !ok {
   143  			return fmt.Errorf(`invalid value %T for %s key`, v, InitializationVectorKey)
   144  		}
   145  		m.initializationVector = buf
   146  	case ProtectedHeadersKey:
   147  		cv, ok := v.(Headers)
   148  		if !ok {
   149  			return fmt.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey)
   150  		}
   151  		m.protectedHeaders = cv
   152  	case RecipientsKey:
   153  		cv, ok := v.([]Recipient)
   154  		if !ok {
   155  			return fmt.Errorf(`invalid value %T for %s key`, v, RecipientsKey)
   156  		}
   157  		m.recipients = cv
   158  	case TagKey:
   159  		buf, ok := v.([]byte)
   160  		if !ok {
   161  			return fmt.Errorf(`invalid value %T for %s key`, v, TagKey)
   162  		}
   163  		m.tag = buf
   164  	case UnprotectedHeadersKey:
   165  		cv, ok := v.(Headers)
   166  		if !ok {
   167  			return fmt.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey)
   168  		}
   169  		m.unprotectedHeaders = cv
   170  	default:
   171  		if m.unprotectedHeaders == nil {
   172  			m.unprotectedHeaders = NewHeaders()
   173  		}
   174  		return m.unprotectedHeaders.Set(k, v)
   175  	}
   176  	return nil
   177  }
   178  
   179  type messageMarshalProxy struct {
   180  	AuthenticatedData    string            `json:"aad,omitempty"`
   181  	CipherText           string            `json:"ciphertext"`
   182  	InitializationVector string            `json:"iv,omitempty"`
   183  	ProtectedHeaders     json.RawMessage   `json:"protected"`
   184  	Recipients           []json.RawMessage `json:"recipients,omitempty"`
   185  	Tag                  string            `json:"tag,omitempty"`
   186  	UnprotectedHeaders   Headers           `json:"unprotected,omitempty"`
   187  
   188  	// For flattened structure. Headers is NOT a Headers type,
   189  	// so that we can detect its presence by checking proxy.Headers != nil
   190  	Headers      json.RawMessage `json:"header,omitempty"`
   191  	EncryptedKey string          `json:"encrypted_key,omitempty"`
   192  }
   193  
   194  type jsonKV struct {
   195  	Key   string
   196  	Value string
   197  }
   198  
   199  func (m *Message) MarshalJSON() ([]byte, error) {
   200  	// This is slightly convoluted, but we need to encode the
   201  	// protected headers, so we do it by hand
   202  	buf := pool.GetBytesBuffer()
   203  	defer pool.ReleaseBytesBuffer(buf)
   204  	enc := json.NewEncoder(buf)
   205  
   206  	var fields []jsonKV
   207  
   208  	if cipherText := m.CipherText(); len(cipherText) > 0 {
   209  		buf.Reset()
   210  		if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil {
   211  			return nil, fmt.Errorf(`failed to encode %s field: %w`, CipherTextKey, err)
   212  		}
   213  		fields = append(fields, jsonKV{
   214  			Key:   CipherTextKey,
   215  			Value: strings.TrimSpace(buf.String()),
   216  		})
   217  	}
   218  
   219  	if iv := m.InitializationVector(); len(iv) > 0 {
   220  		buf.Reset()
   221  		if err := enc.Encode(base64.EncodeToString(iv)); err != nil {
   222  			return nil, fmt.Errorf(`failed to encode %s field: %w`, InitializationVectorKey, err)
   223  		}
   224  		fields = append(fields, jsonKV{
   225  			Key:   InitializationVectorKey,
   226  			Value: strings.TrimSpace(buf.String()),
   227  		})
   228  	}
   229  
   230  	var encodedProtectedHeaders []byte
   231  	if h := m.ProtectedHeaders(); h != nil {
   232  		v, err := h.Encode()
   233  		if err != nil {
   234  			return nil, fmt.Errorf(`failed to encode protected headers: %w`, err)
   235  		}
   236  
   237  		encodedProtectedHeaders = v
   238  		if len(encodedProtectedHeaders) <= 2 { // '{}'
   239  			encodedProtectedHeaders = nil
   240  		} else {
   241  			fields = append(fields, jsonKV{
   242  				Key:   ProtectedHeadersKey,
   243  				Value: fmt.Sprintf("%q", encodedProtectedHeaders),
   244  			})
   245  		}
   246  	}
   247  
   248  	if aad := m.AuthenticatedData(); len(aad) > 0 {
   249  		aad = base64.Encode(aad)
   250  		if encodedProtectedHeaders != nil {
   251  			tmp := append(encodedProtectedHeaders, '.')
   252  			aad = append(tmp, aad...)
   253  		}
   254  
   255  		buf.Reset()
   256  		if err := enc.Encode(aad); err != nil {
   257  			return nil, fmt.Errorf(`failed to encode %s field: %w`, AuthenticatedDataKey, err)
   258  		}
   259  		fields = append(fields, jsonKV{
   260  			Key:   AuthenticatedDataKey,
   261  			Value: strings.TrimSpace(buf.String()),
   262  		})
   263  	}
   264  
   265  	if recipients := m.Recipients(); len(recipients) > 0 {
   266  		if len(recipients) == 1 { // Use flattened format
   267  			if hdrs := recipients[0].Headers(); hdrs != nil {
   268  				buf.Reset()
   269  				if err := enc.Encode(hdrs); err != nil {
   270  					return nil, fmt.Errorf(`failed to encode %s field: %w`, HeadersKey, err)
   271  				}
   272  				fields = append(fields, jsonKV{
   273  					Key:   HeadersKey,
   274  					Value: strings.TrimSpace(buf.String()),
   275  				})
   276  			}
   277  
   278  			if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
   279  				buf.Reset()
   280  				if err := enc.Encode(base64.EncodeToString(ek)); err != nil {
   281  					return nil, fmt.Errorf(`failed to encode %s field: %w`, EncryptedKeyKey, err)
   282  				}
   283  				fields = append(fields, jsonKV{
   284  					Key:   EncryptedKeyKey,
   285  					Value: strings.TrimSpace(buf.String()),
   286  				})
   287  			}
   288  		} else {
   289  			buf.Reset()
   290  			if err := enc.Encode(recipients); err != nil {
   291  				return nil, fmt.Errorf(`failed to encode %s field: %w`, RecipientsKey, err)
   292  			}
   293  			fields = append(fields, jsonKV{
   294  				Key:   RecipientsKey,
   295  				Value: strings.TrimSpace(buf.String()),
   296  			})
   297  		}
   298  	}
   299  
   300  	if tag := m.Tag(); len(tag) > 0 {
   301  		buf.Reset()
   302  		if err := enc.Encode(base64.EncodeToString(tag)); err != nil {
   303  			return nil, fmt.Errorf(`failed to encode %s field: %w`, TagKey, err)
   304  		}
   305  		fields = append(fields, jsonKV{
   306  			Key:   TagKey,
   307  			Value: strings.TrimSpace(buf.String()),
   308  		})
   309  	}
   310  
   311  	if h := m.UnprotectedHeaders(); h != nil {
   312  		unprotected, err := json.Marshal(h)
   313  		if err != nil {
   314  			return nil, fmt.Errorf(`failed to encode unprotected headers: %w`, err)
   315  		}
   316  
   317  		if len(unprotected) > 2 {
   318  			fields = append(fields, jsonKV{
   319  				Key:   UnprotectedHeadersKey,
   320  				Value: fmt.Sprintf("%q", unprotected),
   321  			})
   322  		}
   323  	}
   324  
   325  	sort.Slice(fields, func(i, j int) bool {
   326  		return fields[i].Key < fields[j].Key
   327  	})
   328  	buf.Reset()
   329  	fmt.Fprintf(buf, `{`)
   330  	for i, kv := range fields {
   331  		if i > 0 {
   332  			fmt.Fprintf(buf, `,`)
   333  		}
   334  		fmt.Fprintf(buf, `%q:%s`, kv.Key, kv.Value)
   335  	}
   336  	fmt.Fprintf(buf, `}`)
   337  
   338  	ret := make([]byte, buf.Len())
   339  	copy(ret, buf.Bytes())
   340  	return ret, nil
   341  }
   342  
   343  func (m *Message) UnmarshalJSON(buf []byte) error {
   344  	var proxy messageMarshalProxy
   345  	proxy.UnprotectedHeaders = NewHeaders()
   346  
   347  	if err := json.Unmarshal(buf, &proxy); err != nil {
   348  		return fmt.Errorf(`failed to unmashal JSON into message: %w`, err)
   349  	}
   350  
   351  	// Get the string value
   352  	var protectedHeadersStr string
   353  	if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil {
   354  		return fmt.Errorf(`failed to decode protected headers (1): %w`, err)
   355  	}
   356  
   357  	// It's now in _quoted_ base64 string. Decode it
   358  	protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr)
   359  	if err != nil {
   360  		return fmt.Errorf(`failed to base64 decoded protected headers buffer: %w`, err)
   361  	}
   362  
   363  	h := NewHeaders()
   364  	if err := json.Unmarshal(protectedHeadersRaw, h); err != nil {
   365  		return fmt.Errorf(`failed to decode protected headers (2): %w`, err)
   366  	}
   367  
   368  	// if this were a flattened message, we would see a "header" and "ciphertext"
   369  	// field. TODO: do both of these conditions need to meet, or just one?
   370  	if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
   371  		recipient := NewRecipient()
   372  		hdrs := NewHeaders()
   373  		if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
   374  			return fmt.Errorf(`failed to decode headers field: %w`, err)
   375  		}
   376  
   377  		if err := recipient.SetHeaders(hdrs); err != nil {
   378  			return fmt.Errorf(`failed to set new headers: %w`, err)
   379  		}
   380  
   381  		if v := proxy.EncryptedKey; len(v) > 0 {
   382  			buf, err := base64.DecodeString(v)
   383  			if err != nil {
   384  				return fmt.Errorf(`failed to decode encrypted key: %w`, err)
   385  			}
   386  			if err := recipient.SetEncryptedKey(buf); err != nil {
   387  				return fmt.Errorf(`failed to set encrypted key: %w`, err)
   388  			}
   389  		}
   390  
   391  		m.recipients = append(m.recipients, recipient)
   392  	} else {
   393  		for i, recipientbuf := range proxy.Recipients {
   394  			recipient := NewRecipient()
   395  			if err := json.Unmarshal(recipientbuf, recipient); err != nil {
   396  				return fmt.Errorf(`failed to decode recipient at index %d: %w`, i, err)
   397  			}
   398  
   399  			m.recipients = append(m.recipients, recipient)
   400  		}
   401  	}
   402  
   403  	if src := proxy.AuthenticatedData; len(src) > 0 {
   404  		v, err := base64.DecodeString(src)
   405  		if err != nil {
   406  			return fmt.Errorf(`failed to decode "aad": %w`, err)
   407  		}
   408  		m.authenticatedData = v
   409  	}
   410  
   411  	if src := proxy.CipherText; len(src) > 0 {
   412  		v, err := base64.DecodeString(src)
   413  		if err != nil {
   414  			return fmt.Errorf(`failed to decode "ciphertext": %w`, err)
   415  		}
   416  		m.cipherText = v
   417  	}
   418  
   419  	if src := proxy.InitializationVector; len(src) > 0 {
   420  		v, err := base64.DecodeString(src)
   421  		if err != nil {
   422  			return fmt.Errorf(`failed to decode "iv": %w`, err)
   423  		}
   424  		m.initializationVector = v
   425  	}
   426  
   427  	if src := proxy.Tag; len(src) > 0 {
   428  		v, err := base64.DecodeString(src)
   429  		if err != nil {
   430  			return fmt.Errorf(`failed to decode "tag": %w`, err)
   431  		}
   432  		m.tag = v
   433  	}
   434  
   435  	m.protectedHeaders = h
   436  	if m.storeProtectedHeaders {
   437  		// this is later used for decryption
   438  		m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw)
   439  	}
   440  
   441  	if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok {
   442  		if !iz.isZero() {
   443  			m.unprotectedHeaders = proxy.UnprotectedHeaders
   444  		}
   445  	}
   446  
   447  	if len(m.recipients) == 0 {
   448  		if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil {
   449  			return fmt.Errorf(`failed to setup recipient: %w`, err)
   450  		}
   451  	}
   452  
   453  	return nil
   454  }
   455  
   456  func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error {
   457  	// Recipients in this case should not contain the content encryption key,
   458  	// so move that out
   459  	hdrs, err := protected.Clone(context.TODO())
   460  	if err != nil {
   461  		return fmt.Errorf(`failed to clone headers: %w`, err)
   462  	}
   463  
   464  	if err := hdrs.Remove(ContentEncryptionKey); err != nil {
   465  		return fmt.Errorf(`failed to remove %#v from public header: %w`, ContentEncryptionKey, err)
   466  	}
   467  
   468  	enckey, err := base64.DecodeString(enckeybuf)
   469  	if err != nil {
   470  		return fmt.Errorf(`failed to decode encrypted key: %w`, err)
   471  	}
   472  
   473  	if err := m.Set(RecipientsKey, []Recipient{
   474  		&stdRecipient{
   475  			headers:      hdrs,
   476  			encryptedKey: enckey,
   477  		},
   478  	}); err != nil {
   479  		return fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err)
   480  	}
   481  	return nil
   482  }
   483  
   484  // Compact generates a JWE message in compact serialization format from a
   485  // `*jwe.Message` object. The object contain exactly one recipient, or
   486  // an error is returned.
   487  //
   488  // This function currently does not take any options, but the function
   489  // signature contains `options` for possible future expansion of the API
   490  func Compact(m *Message, _ ...CompactOption) ([]byte, error) {
   491  	if len(m.recipients) != 1 {
   492  		return nil, fmt.Errorf(`wrong number of recipients for compact serialization`)
   493  	}
   494  
   495  	recipient := m.recipients[0]
   496  
   497  	// The protected header must be a merge between the message-wide
   498  	// protected header AND the recipient header
   499  
   500  	// There's something wrong if m.protectedHeaders is nil, but
   501  	// it could happen
   502  	if m.protectedHeaders == nil {
   503  		return nil, fmt.Errorf(`invalid protected header`)
   504  	}
   505  
   506  	ctx := context.TODO()
   507  	hcopy, err := m.protectedHeaders.Clone(ctx)
   508  	if err != nil {
   509  		return nil, fmt.Errorf(`failed to copy protected header: %w`, err)
   510  	}
   511  	hcopy, err = hcopy.Merge(ctx, m.unprotectedHeaders)
   512  	if err != nil {
   513  		return nil, fmt.Errorf(`failed to merge unprotected header: %w`, err)
   514  	}
   515  	hcopy, err = hcopy.Merge(ctx, recipient.Headers())
   516  	if err != nil {
   517  		return nil, fmt.Errorf(`failed to merge recipient header: %w`, err)
   518  	}
   519  
   520  	protected, err := hcopy.Encode()
   521  	if err != nil {
   522  		return nil, fmt.Errorf(`failed to encode header: %w`, err)
   523  	}
   524  
   525  	encryptedKey := base64.Encode(recipient.EncryptedKey())
   526  	iv := base64.Encode(m.initializationVector)
   527  	cipher := base64.Encode(m.cipherText)
   528  	tag := base64.Encode(m.tag)
   529  
   530  	buf := pool.GetBytesBuffer()
   531  	defer pool.ReleaseBytesBuffer(buf)
   532  
   533  	buf.Grow(len(protected) + len(encryptedKey) + len(iv) + len(cipher) + len(tag) + 4)
   534  	buf.Write(protected)
   535  	buf.WriteByte('.')
   536  	buf.Write(encryptedKey)
   537  	buf.WriteByte('.')
   538  	buf.Write(iv)
   539  	buf.WriteByte('.')
   540  	buf.Write(cipher)
   541  	buf.WriteByte('.')
   542  	buf.Write(tag)
   543  
   544  	result := make([]byte, buf.Len())
   545  	copy(result, buf.Bytes())
   546  	return result, nil
   547  }