github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/token_gen.go (about)

     1  // Code generated by tools/cmd/genjwt/main.go. DO NOT EDIT.
     2  
     3  package jwt
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"fmt"
     9  	"sort"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/lestrrat-go/iter/mapiter"
    14  	"github.com/lestrrat-go/jwx/v2/internal/base64"
    15  	"github.com/lestrrat-go/jwx/v2/internal/iter"
    16  	"github.com/lestrrat-go/jwx/v2/internal/json"
    17  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    18  	"github.com/lestrrat-go/jwx/v2/jwt/internal/types"
    19  )
    20  
    21  const (
    22  	AudienceKey   = "aud"
    23  	ExpirationKey = "exp"
    24  	IssuedAtKey   = "iat"
    25  	IssuerKey     = "iss"
    26  	JwtIDKey      = "jti"
    27  	NotBeforeKey  = "nbf"
    28  	SubjectKey    = "sub"
    29  )
    30  
    31  // Token represents a generic JWT token.
    32  // which are type-aware (to an extent). Other claims may be accessed via the `Get`/`Set`
    33  // methods but their types are not taken into consideration at all. If you have non-standard
    34  // claims that you must frequently access, consider creating accessors functions
    35  // like the following
    36  //
    37  // func SetFoo(tok jwt.Token) error
    38  // func GetFoo(tok jwt.Token) (*Customtyp, error)
    39  //
    40  // Embedding jwt.Token into another struct is not recommended, because
    41  // jwt.Token needs to handle private claims, and this really does not
    42  // work well when it is embedded in other structure
    43  type Token interface {
    44  
    45  	// Audience returns the value for "aud" field of the token
    46  	Audience() []string
    47  
    48  	// Expiration returns the value for "exp" field of the token
    49  	Expiration() time.Time
    50  
    51  	// IssuedAt returns the value for "iat" field of the token
    52  	IssuedAt() time.Time
    53  
    54  	// Issuer returns the value for "iss" field of the token
    55  	Issuer() string
    56  
    57  	// JwtID returns the value for "jti" field of the token
    58  	JwtID() string
    59  
    60  	// NotBefore returns the value for "nbf" field of the token
    61  	NotBefore() time.Time
    62  
    63  	// Subject returns the value for "sub" field of the token
    64  	Subject() string
    65  
    66  	// PrivateClaims return the entire set of fields (claims) in the token
    67  	// *other* than the pre-defined fields such as `iss`, `nbf`, `iat`, etc.
    68  	PrivateClaims() map[string]interface{}
    69  
    70  	// Get returns the value of the corresponding field in the token, such as
    71  	// `nbf`, `exp`, `iat`, and other user-defined fields. If the field does not
    72  	// exist in the token, the second return value will be `false`
    73  	//
    74  	// If you need to access fields like `alg`, `kid`, `jku`, etc, you need
    75  	// to access the corresponding fields in the JWS/JWE message. For this,
    76  	// you will need to access them by directly parsing the payload using
    77  	// `jws.Parse` and `jwe.Parse`
    78  	Get(string) (interface{}, bool)
    79  
    80  	// Set assigns a value to the corresponding field in the token. Some
    81  	// pre-defined fields such as `nbf`, `iat`, `iss` need their values to
    82  	// be of a specific type. See the other getter methods in this interface
    83  	// for the types of each of these fields
    84  	Set(string, interface{}) error
    85  	Remove(string) error
    86  
    87  	// Options returns the per-token options associated with this token.
    88  	// The options set value will be copied when the token is cloned via `Clone()`
    89  	// but it will not survive when the token goes through marshaling/unmarshaling
    90  	// such as `json.Marshal` and `json.Unmarshal`
    91  	Options() *TokenOptionSet
    92  	Clone() (Token, error)
    93  	Iterate(context.Context) Iterator
    94  	Walk(context.Context, Visitor) error
    95  	AsMap(context.Context) (map[string]interface{}, error)
    96  }
    97  type stdToken struct {
    98  	mu            *sync.RWMutex
    99  	dc            DecodeCtx          // per-object context for decoding
   100  	options       TokenOptionSet     // per-object option
   101  	audience      types.StringList   // https://tools.ietf.org/html/rfc7519#section-4.1.3
   102  	expiration    *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.4
   103  	issuedAt      *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.6
   104  	issuer        *string            // https://tools.ietf.org/html/rfc7519#section-4.1.1
   105  	jwtID         *string            // https://tools.ietf.org/html/rfc7519#section-4.1.7
   106  	notBefore     *types.NumericDate // https://tools.ietf.org/html/rfc7519#section-4.1.5
   107  	subject       *string            // https://tools.ietf.org/html/rfc7519#section-4.1.2
   108  	privateClaims map[string]interface{}
   109  }
   110  
   111  // New creates a standard token, with minimal knowledge of
   112  // possible claims. Standard claims include"aud", "exp", "iat", "iss", "jti", "nbf" and "sub".
   113  // Convenience accessors are provided for these standard claims
   114  func New() Token {
   115  	return &stdToken{
   116  		mu:            &sync.RWMutex{},
   117  		privateClaims: make(map[string]interface{}),
   118  		options:       DefaultOptionSet(),
   119  	}
   120  }
   121  
   122  func (t *stdToken) Options() *TokenOptionSet {
   123  	return &t.options
   124  }
   125  
   126  func (t *stdToken) Get(name string) (interface{}, bool) {
   127  	t.mu.RLock()
   128  	defer t.mu.RUnlock()
   129  	switch name {
   130  	case AudienceKey:
   131  		if t.audience == nil {
   132  			return nil, false
   133  		}
   134  		v := t.audience.Get()
   135  		return v, true
   136  	case ExpirationKey:
   137  		if t.expiration == nil {
   138  			return nil, false
   139  		}
   140  		v := t.expiration.Get()
   141  		return v, true
   142  	case IssuedAtKey:
   143  		if t.issuedAt == nil {
   144  			return nil, false
   145  		}
   146  		v := t.issuedAt.Get()
   147  		return v, true
   148  	case IssuerKey:
   149  		if t.issuer == nil {
   150  			return nil, false
   151  		}
   152  		v := *(t.issuer)
   153  		return v, true
   154  	case JwtIDKey:
   155  		if t.jwtID == nil {
   156  			return nil, false
   157  		}
   158  		v := *(t.jwtID)
   159  		return v, true
   160  	case NotBeforeKey:
   161  		if t.notBefore == nil {
   162  			return nil, false
   163  		}
   164  		v := t.notBefore.Get()
   165  		return v, true
   166  	case SubjectKey:
   167  		if t.subject == nil {
   168  			return nil, false
   169  		}
   170  		v := *(t.subject)
   171  		return v, true
   172  	default:
   173  		v, ok := t.privateClaims[name]
   174  		return v, ok
   175  	}
   176  }
   177  
   178  func (t *stdToken) Remove(key string) error {
   179  	t.mu.Lock()
   180  	defer t.mu.Unlock()
   181  	switch key {
   182  	case AudienceKey:
   183  		t.audience = nil
   184  	case ExpirationKey:
   185  		t.expiration = nil
   186  	case IssuedAtKey:
   187  		t.issuedAt = nil
   188  	case IssuerKey:
   189  		t.issuer = nil
   190  	case JwtIDKey:
   191  		t.jwtID = nil
   192  	case NotBeforeKey:
   193  		t.notBefore = nil
   194  	case SubjectKey:
   195  		t.subject = nil
   196  	default:
   197  		delete(t.privateClaims, key)
   198  	}
   199  	return nil
   200  }
   201  
   202  func (t *stdToken) Set(name string, value interface{}) error {
   203  	t.mu.Lock()
   204  	defer t.mu.Unlock()
   205  	return t.setNoLock(name, value)
   206  }
   207  
   208  func (t *stdToken) DecodeCtx() DecodeCtx {
   209  	t.mu.RLock()
   210  	defer t.mu.RUnlock()
   211  	return t.dc
   212  }
   213  
   214  func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
   215  	t.mu.Lock()
   216  	defer t.mu.Unlock()
   217  	t.dc = v
   218  }
   219  
   220  func (t *stdToken) setNoLock(name string, value interface{}) error {
   221  	switch name {
   222  	case AudienceKey:
   223  		var acceptor types.StringList
   224  		if err := acceptor.Accept(value); err != nil {
   225  			return fmt.Errorf(`invalid value for %s key: %w`, AudienceKey, err)
   226  		}
   227  		t.audience = acceptor
   228  		return nil
   229  	case ExpirationKey:
   230  		var acceptor types.NumericDate
   231  		if err := acceptor.Accept(value); err != nil {
   232  			return fmt.Errorf(`invalid value for %s key: %w`, ExpirationKey, err)
   233  		}
   234  		t.expiration = &acceptor
   235  		return nil
   236  	case IssuedAtKey:
   237  		var acceptor types.NumericDate
   238  		if err := acceptor.Accept(value); err != nil {
   239  			return fmt.Errorf(`invalid value for %s key: %w`, IssuedAtKey, err)
   240  		}
   241  		t.issuedAt = &acceptor
   242  		return nil
   243  	case IssuerKey:
   244  		if v, ok := value.(string); ok {
   245  			t.issuer = &v
   246  			return nil
   247  		}
   248  		return fmt.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
   249  	case JwtIDKey:
   250  		if v, ok := value.(string); ok {
   251  			t.jwtID = &v
   252  			return nil
   253  		}
   254  		return fmt.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
   255  	case NotBeforeKey:
   256  		var acceptor types.NumericDate
   257  		if err := acceptor.Accept(value); err != nil {
   258  			return fmt.Errorf(`invalid value for %s key: %w`, NotBeforeKey, err)
   259  		}
   260  		t.notBefore = &acceptor
   261  		return nil
   262  	case SubjectKey:
   263  		if v, ok := value.(string); ok {
   264  			t.subject = &v
   265  			return nil
   266  		}
   267  		return fmt.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
   268  	default:
   269  		if t.privateClaims == nil {
   270  			t.privateClaims = map[string]interface{}{}
   271  		}
   272  		t.privateClaims[name] = value
   273  	}
   274  	return nil
   275  }
   276  
   277  func (t *stdToken) Audience() []string {
   278  	t.mu.RLock()
   279  	defer t.mu.RUnlock()
   280  	if t.audience != nil {
   281  		return t.audience.Get()
   282  	}
   283  	return nil
   284  }
   285  
   286  func (t *stdToken) Expiration() time.Time {
   287  	t.mu.RLock()
   288  	defer t.mu.RUnlock()
   289  	if t.expiration != nil {
   290  		return t.expiration.Get()
   291  	}
   292  	return time.Time{}
   293  }
   294  
   295  func (t *stdToken) IssuedAt() time.Time {
   296  	t.mu.RLock()
   297  	defer t.mu.RUnlock()
   298  	if t.issuedAt != nil {
   299  		return t.issuedAt.Get()
   300  	}
   301  	return time.Time{}
   302  }
   303  
   304  func (t *stdToken) Issuer() string {
   305  	t.mu.RLock()
   306  	defer t.mu.RUnlock()
   307  	if t.issuer != nil {
   308  		return *(t.issuer)
   309  	}
   310  	return ""
   311  }
   312  
   313  func (t *stdToken) JwtID() string {
   314  	t.mu.RLock()
   315  	defer t.mu.RUnlock()
   316  	if t.jwtID != nil {
   317  		return *(t.jwtID)
   318  	}
   319  	return ""
   320  }
   321  
   322  func (t *stdToken) NotBefore() time.Time {
   323  	t.mu.RLock()
   324  	defer t.mu.RUnlock()
   325  	if t.notBefore != nil {
   326  		return t.notBefore.Get()
   327  	}
   328  	return time.Time{}
   329  }
   330  
   331  func (t *stdToken) Subject() string {
   332  	t.mu.RLock()
   333  	defer t.mu.RUnlock()
   334  	if t.subject != nil {
   335  		return *(t.subject)
   336  	}
   337  	return ""
   338  }
   339  
   340  func (t *stdToken) PrivateClaims() map[string]interface{} {
   341  	t.mu.RLock()
   342  	defer t.mu.RUnlock()
   343  	return t.privateClaims
   344  }
   345  
   346  func (t *stdToken) makePairs() []*ClaimPair {
   347  	t.mu.RLock()
   348  	defer t.mu.RUnlock()
   349  
   350  	pairs := make([]*ClaimPair, 0, 7)
   351  	if t.audience != nil {
   352  		v := t.audience.Get()
   353  		pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
   354  	}
   355  	if t.expiration != nil {
   356  		v := t.expiration.Get()
   357  		pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
   358  	}
   359  	if t.issuedAt != nil {
   360  		v := t.issuedAt.Get()
   361  		pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
   362  	}
   363  	if t.issuer != nil {
   364  		v := *(t.issuer)
   365  		pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
   366  	}
   367  	if t.jwtID != nil {
   368  		v := *(t.jwtID)
   369  		pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
   370  	}
   371  	if t.notBefore != nil {
   372  		v := t.notBefore.Get()
   373  		pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
   374  	}
   375  	if t.subject != nil {
   376  		v := *(t.subject)
   377  		pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
   378  	}
   379  	for k, v := range t.privateClaims {
   380  		pairs = append(pairs, &ClaimPair{Key: k, Value: v})
   381  	}
   382  	sort.Slice(pairs, func(i, j int) bool {
   383  		return pairs[i].Key.(string) < pairs[j].Key.(string)
   384  	})
   385  	return pairs
   386  }
   387  
   388  func (t *stdToken) UnmarshalJSON(buf []byte) error {
   389  	t.mu.Lock()
   390  	defer t.mu.Unlock()
   391  	t.audience = nil
   392  	t.expiration = nil
   393  	t.issuedAt = nil
   394  	t.issuer = nil
   395  	t.jwtID = nil
   396  	t.notBefore = nil
   397  	t.subject = nil
   398  	dec := json.NewDecoder(bytes.NewReader(buf))
   399  LOOP:
   400  	for {
   401  		tok, err := dec.Token()
   402  		if err != nil {
   403  			return fmt.Errorf(`error reading token: %w`, err)
   404  		}
   405  		switch tok := tok.(type) {
   406  		case json.Delim:
   407  			// Assuming we're doing everything correctly, we should ONLY
   408  			// get either '{' or '}' here.
   409  			if tok == '}' { // End of object
   410  				break LOOP
   411  			} else if tok != '{' {
   412  				return fmt.Errorf(`expected '{', but got '%c'`, tok)
   413  			}
   414  		case string: // Objects can only have string keys
   415  			switch tok {
   416  			case AudienceKey:
   417  				var decoded types.StringList
   418  				if err := dec.Decode(&decoded); err != nil {
   419  					return fmt.Errorf(`failed to decode value for key %s: %w`, AudienceKey, err)
   420  				}
   421  				t.audience = decoded
   422  			case ExpirationKey:
   423  				var decoded types.NumericDate
   424  				if err := dec.Decode(&decoded); err != nil {
   425  					return fmt.Errorf(`failed to decode value for key %s: %w`, ExpirationKey, err)
   426  				}
   427  				t.expiration = &decoded
   428  			case IssuedAtKey:
   429  				var decoded types.NumericDate
   430  				if err := dec.Decode(&decoded); err != nil {
   431  					return fmt.Errorf(`failed to decode value for key %s: %w`, IssuedAtKey, err)
   432  				}
   433  				t.issuedAt = &decoded
   434  			case IssuerKey:
   435  				if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
   436  					return fmt.Errorf(`failed to decode value for key %s: %w`, IssuerKey, err)
   437  				}
   438  			case JwtIDKey:
   439  				if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
   440  					return fmt.Errorf(`failed to decode value for key %s: %w`, JwtIDKey, err)
   441  				}
   442  			case NotBeforeKey:
   443  				var decoded types.NumericDate
   444  				if err := dec.Decode(&decoded); err != nil {
   445  					return fmt.Errorf(`failed to decode value for key %s: %w`, NotBeforeKey, err)
   446  				}
   447  				t.notBefore = &decoded
   448  			case SubjectKey:
   449  				if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
   450  					return fmt.Errorf(`failed to decode value for key %s: %w`, SubjectKey, err)
   451  				}
   452  			default:
   453  				if dc := t.dc; dc != nil {
   454  					if localReg := dc.Registry(); localReg != nil {
   455  						decoded, err := localReg.Decode(dec, tok)
   456  						if err == nil {
   457  							t.setNoLock(tok, decoded)
   458  							continue
   459  						}
   460  					}
   461  				}
   462  				decoded, err := registry.Decode(dec, tok)
   463  				if err == nil {
   464  					t.setNoLock(tok, decoded)
   465  					continue
   466  				}
   467  				return fmt.Errorf(`could not decode field %s: %w`, tok, err)
   468  			}
   469  		default:
   470  			return fmt.Errorf(`invalid token %T`, tok)
   471  		}
   472  	}
   473  	return nil
   474  }
   475  
   476  func (t stdToken) MarshalJSON() ([]byte, error) {
   477  	buf := pool.GetBytesBuffer()
   478  	defer pool.ReleaseBytesBuffer(buf)
   479  	buf.WriteByte('{')
   480  	enc := json.NewEncoder(buf)
   481  	for i, pair := range t.makePairs() {
   482  		f := pair.Key.(string)
   483  		if i > 0 {
   484  			buf.WriteByte(',')
   485  		}
   486  		buf.WriteRune('"')
   487  		buf.WriteString(f)
   488  		buf.WriteString(`":`)
   489  		switch f {
   490  		case AudienceKey:
   491  			if err := json.EncodeAudience(enc, pair.Value.([]string), t.options.IsEnabled(FlattenAudience)); err != nil {
   492  				return nil, fmt.Errorf(`failed to encode "aud": %w`, err)
   493  			}
   494  			continue
   495  		case ExpirationKey, IssuedAtKey, NotBeforeKey:
   496  			enc.Encode(pair.Value.(time.Time).Unix())
   497  			continue
   498  		}
   499  		switch v := pair.Value.(type) {
   500  		case []byte:
   501  			buf.WriteRune('"')
   502  			buf.WriteString(base64.EncodeToString(v))
   503  			buf.WriteRune('"')
   504  		default:
   505  			if err := enc.Encode(v); err != nil {
   506  				return nil, fmt.Errorf(`failed to marshal field %s: %w`, f, err)
   507  			}
   508  			buf.Truncate(buf.Len() - 1)
   509  		}
   510  	}
   511  	buf.WriteByte('}')
   512  	ret := make([]byte, buf.Len())
   513  	copy(ret, buf.Bytes())
   514  	return ret, nil
   515  }
   516  
   517  func (t *stdToken) Iterate(ctx context.Context) Iterator {
   518  	pairs := t.makePairs()
   519  	ch := make(chan *ClaimPair, len(pairs))
   520  	go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
   521  		defer close(ch)
   522  		for _, pair := range pairs {
   523  			select {
   524  			case <-ctx.Done():
   525  				return
   526  			case ch <- pair:
   527  			}
   528  		}
   529  	}(ctx, ch, pairs)
   530  	return mapiter.New(ch)
   531  }
   532  
   533  func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
   534  	return iter.WalkMap(ctx, t, visitor)
   535  }
   536  
   537  func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
   538  	return iter.AsMap(ctx, t)
   539  }