github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/serialize.go (about)

     1  package jwt
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/lestrrat-go/jwx/v2/internal/json"
     7  	"github.com/lestrrat-go/jwx/v2/jwe"
     8  	"github.com/lestrrat-go/jwx/v2/jws"
     9  )
    10  
// SerializeCtx describes where in the serialization pipeline a
// SerializeStep is being invoked.
type SerializeCtx interface {
	// Step returns the zero-based position of the step being executed.
	// When driven by (*Serializer).Serialize, position 0 is the implicit
	// JSON marshaling step, so the first user-registered step sees 1.
	Step() int
	// Nested reports whether the serialization is a nested one. When
	// driven by (*Serializer).Serialize this is true if more than one
	// step was registered on the Serializer.
	Nested() bool
}
    15  
    16  type serializeCtx struct {
    17  	step   int
    18  	nested bool
    19  }
    20  
    21  func (ctx *serializeCtx) Step() int {
    22  	return ctx.step
    23  }
    24  
    25  func (ctx *serializeCtx) Nested() bool {
    26  	return ctx.nested
    27  }
    28  
// SerializeStep is a single stage of the serialization pipeline.
type SerializeStep interface {
	// Serialize receives the output of the previous step as its second
	// argument and returns the value handed to the next step. A non-nil
	// error aborts the whole serialization. The last step in a pipeline
	// must return a []byte for (*Serializer).Serialize to succeed.
	Serialize(SerializeCtx, interface{}) (interface{}, error)
}
    32  
    33  // errStep is always an error. used to indicate that a method like
    34  // serializer.Sign or Encrypt already errored out on configuration
    35  type errStep struct {
    36  	err error
    37  }
    38  
    39  func (e errStep) Serialize(_ SerializeCtx, _ interface{}) (interface{}, error) {
    40  	return nil, e.err
    41  }
    42  
// Serializer is a generic serializer for JWTs. Whereas other convenience
// functions can only do one thing (such as generate a JWS signed JWT),
// using this construct you can serialize the token however you want.
//
// By default the serializer only marshals the token into a JSON payload.
// You must set up the rest of the steps that should be taken by the
// serializer.
//
// For example, to marshal the token into JSON, then apply JWS and JWE
// in that order, you would do:
//
//	serialized, err := jwt.NewSerializer().
//	   Sign(jwa.RS256, key).
//	   Encrypt(jwa.RSA_OAEP, key.PublicKey).
//	   Serialize(token)
//
// The `jwt.Sign()` function is equivalent to
//
//	serialized, err := jwt.NewSerializer().
//	   Sign(...args...).
//	   Serialize(token)
type Serializer struct {
	// steps holds the user-registered steps, in the order they are run.
	steps []SerializeStep
}
    67  
    68  // NewSerializer creates a new empty serializer.
    69  func NewSerializer() *Serializer {
    70  	return &Serializer{}
    71  }
    72  
// Reset removes every registered step and returns the Serializer itself
// so that calls can be chained. After Reset, Serialize performs only the
// implicit JSON marshaling until new steps are added.
func (s *Serializer) Reset() *Serializer {
	s.steps = nil
	return s
}
    78  
// Step appends a new step to the serialization process and returns the
// Serializer itself so that calls can be chained. Steps are executed in
// the order they were added, after the implicit JSON marshaling step.
func (s *Serializer) Step(step SerializeStep) *Serializer {
	s.steps = append(s.steps, step)
	return s
}
    84  
    85  type jsonSerializer struct{}
    86  
    87  func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
    88  	token, ok := v.(Token)
    89  	if !ok {
    90  		return nil, fmt.Errorf(`invalid input: expected jwt.Token`)
    91  	}
    92  
    93  	buf, err := json.Marshal(token)
    94  	if err != nil {
    95  		return nil, fmt.Errorf(`failed to serialize as JSON`)
    96  	}
    97  	return buf, nil
    98  }
    99  
// genericHeader is the subset of header operations that setTypeOrCty
// needs. In this file both JWS and JWE header values are passed to
// setTypeOrCty through this interface.
type genericHeader interface {
	// Get returns the value stored under the given key; the boolean is
	// treated by setTypeOrCty as "the key is already present".
	Get(string) (interface{}, bool)
	// Set stores a value under the given key.
	Set(string, interface{}) error
}
   104  
   105  func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
   106  	// cty and typ are common between JWE/JWS, so we don't use
   107  	// the constants in jws/jwe package here
   108  	const typKey = `typ`
   109  	const ctyKey = `cty`
   110  
   111  	if ctx.Step() == 1 {
   112  		// We are executed immediately after json marshaling
   113  		if _, ok := hdrs.Get(typKey); !ok {
   114  			if err := hdrs.Set(typKey, `JWT`); err != nil {
   115  				return fmt.Errorf(`failed to set %s key to "JWT": %w`, typKey, err)
   116  			}
   117  		}
   118  	} else {
   119  		if ctx.Nested() {
   120  			// If this is part of a nested sequence, we should set cty = 'JWT'
   121  			// https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
   122  			if err := hdrs.Set(ctyKey, `JWT`); err != nil {
   123  				return fmt.Errorf(`failed to set %s key to "JWT": %w`, ctyKey, err)
   124  			}
   125  		}
   126  	}
   127  	return nil
   128  }
   129  
   130  type jwsSerializer struct {
   131  	options []jws.SignOption
   132  }
   133  
   134  func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
   135  	payload, ok := v.([]byte)
   136  	if !ok {
   137  		return nil, fmt.Errorf(`expected []byte as input`)
   138  	}
   139  
   140  	for _, option := range s.options {
   141  		pc, ok := option.Value().(interface{ Protected(jws.Headers) jws.Headers })
   142  		if !ok {
   143  			continue
   144  		}
   145  		hdrs := pc.Protected(jws.NewHeaders())
   146  		if err := setTypeOrCty(ctx, hdrs); err != nil {
   147  			return nil, err // this is already wrapped
   148  		}
   149  
   150  		// JWTs MUST NOT use b64 = false
   151  		// https://datatracker.ietf.org/doc/html/rfc7797#section-7
   152  		if v, ok := hdrs.Get("b64"); ok {
   153  			if bval, bok := v.(bool); bok {
   154  				if !bval { // b64 = false
   155  					return nil, fmt.Errorf(`b64 cannot be false for JWTs`)
   156  				}
   157  			}
   158  		}
   159  	}
   160  	return jws.Sign(payload, s.options...)
   161  }
   162  
   163  func (s *Serializer) Sign(options ...SignOption) *Serializer {
   164  	var soptions []jws.SignOption
   165  	if l := len(options); l > 0 {
   166  		// we need to from SignOption to Option because ... reasons
   167  		// (todo: when go1.18 prevails, use type parameters
   168  		rawoptions := make([]Option, l)
   169  		for i, option := range options {
   170  			rawoptions[i] = option
   171  		}
   172  
   173  		converted, err := toSignOptions(rawoptions...)
   174  		if err != nil {
   175  			return s.Step(errStep{fmt.Errorf(`(jwt.Serializer).Sign: failed to convert options into jws.SignOption: %w`, err)})
   176  		}
   177  		soptions = converted
   178  	}
   179  	return s.sign(soptions...)
   180  }
   181  
   182  func (s *Serializer) sign(options ...jws.SignOption) *Serializer {
   183  	return s.Step(&jwsSerializer{
   184  		options: options,
   185  	})
   186  }
   187  
   188  type jweSerializer struct {
   189  	options []jwe.EncryptOption
   190  }
   191  
   192  func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
   193  	payload, ok := v.([]byte)
   194  	if !ok {
   195  		return nil, fmt.Errorf(`expected []byte as input`)
   196  	}
   197  
   198  	hdrs := jwe.NewHeaders()
   199  	if err := setTypeOrCty(ctx, hdrs); err != nil {
   200  		return nil, err // this is already wrapped
   201  	}
   202  
   203  	options := append([]jwe.EncryptOption{jwe.WithMergeProtectedHeaders(true), jwe.WithProtectedHeaders(hdrs)}, s.options...)
   204  	return jwe.Encrypt(payload, options...)
   205  }
   206  
   207  // Encrypt specifies the JWT to be serialized as an encrypted payload.
   208  //
   209  // One notable difference between this method and `jwe.Encrypt()` is that
   210  // while `jwe.Encrypt()` OVERWRITES the previous headers when `jwe.WithProtectedHeaders()`
   211  // is provided, this method MERGES them. This is due to the fact that we
   212  // MUST add some extra headers to construct a proper JWE message.
   213  // Be careful when you pass multiple `jwe.EncryptOption`s.
   214  func (s *Serializer) Encrypt(options ...EncryptOption) *Serializer {
   215  	var eoptions []jwe.EncryptOption
   216  	if l := len(options); l > 0 {
   217  		// we need to from SignOption to Option because ... reasons
   218  		// (todo: when go1.18 prevails, use type parameters
   219  		rawoptions := make([]Option, l)
   220  		for i, option := range options {
   221  			rawoptions[i] = option
   222  		}
   223  
   224  		converted, err := toEncryptOptions(rawoptions...)
   225  		if err != nil {
   226  			return s.Step(errStep{fmt.Errorf(`(jwt.Serializer).Encrypt: failed to convert options into jwe.EncryptOption: %w`, err)})
   227  		}
   228  		eoptions = converted
   229  	}
   230  	return s.encrypt(eoptions...)
   231  }
   232  
   233  func (s *Serializer) encrypt(options ...jwe.EncryptOption) *Serializer {
   234  	return s.Step(&jweSerializer{
   235  		options: options,
   236  	})
   237  }
   238  
   239  func (s *Serializer) Serialize(t Token) ([]byte, error) {
   240  	steps := make([]SerializeStep, len(s.steps)+1)
   241  	steps[0] = jsonSerializer{}
   242  	for i, step := range s.steps {
   243  		steps[i+1] = step
   244  	}
   245  
   246  	var ctx serializeCtx
   247  	ctx.nested = len(s.steps) > 1
   248  	var payload interface{} = t
   249  	for i, step := range steps {
   250  		ctx.step = i
   251  		v, err := step.Serialize(&ctx, payload)
   252  		if err != nil {
   253  			return nil, fmt.Errorf(`failed to serialize token at step #%d: %w`, i+1, err)
   254  		}
   255  		payload = v
   256  	}
   257  
   258  	res, ok := payload.([]byte)
   259  	if !ok {
   260  		return nil, fmt.Errorf(`invalid serialization produced`)
   261  	}
   262  
   263  	return res, nil
   264  }