go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/server/tokens/tokens.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tokens provides means to generate and validate base64 encoded tokens
    16  // compatible with luci-py's components.auth implementation.
    17  package tokens
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/hmac"
    23  	"crypto/sha256"
    24  	"encoding/base64"
    25  	"encoding/json"
    26  	"fmt"
    27  	"hash"
    28  	"strconv"
    29  	"strings"
    30  	"time"
    31  
    32  	"go.chromium.org/luci/common/clock"
    33  
    34  	"go.chromium.org/luci/server/secrets"
    35  )
    36  
    37  // allowedClockDrift is clock drift between machines we can tolerate.
    38  const allowedClockDrift = 30 * time.Second
    39  
    40  // TokenAlgo identifies how token is authenticated.
    41  type TokenAlgo string
    42  
    43  const (
    44  	// TokenAlgoHmacSHA256 algorithm stores public portion of the token as plain
    45  	// text and uses HMAC SHA256 to authenticate its integrity.
    46  	TokenAlgoHmacSHA256 = "HMAC-SHA256"
    47  )
    48  
    49  // hash returns hash.Hash that computes the digest or nil if algo is unknown.
    50  func (a TokenAlgo) hash(secret []byte) hash.Hash {
    51  	switch a {
    52  	case TokenAlgoHmacSHA256:
    53  		return hmac.New(sha256.New, secret)
    54  	}
    55  	return nil
    56  }
    57  
    58  // digestLen returns length of digest generated by an algo or 0 if unknown.
    59  func (a TokenAlgo) digestLen() int {
    60  	switch a {
    61  	case TokenAlgoHmacSHA256:
    62  		return sha256.Size
    63  	}
    64  	return 0
    65  }
    66  
// TokenKind is a configuration of particular type of a token. It can be
// defined statically in a module and then its Generate() and Validate() methods
// can be used to produce and verify tokens.
type TokenKind struct {
	// Algo is the algorithm used to compute and check the token's MAC tag.
	Algo TokenAlgo
	// Expiration is how long a generated token lives. Validate uses it when the
	// token does not carry its own expiration (i.e. Generate was called with a
	// zero 'exp').
	Expiration time.Duration
	// SecretKey is the name of the secret key in secrets.Store; it is passed
	// to secrets.RandomSecret.
	SecretKey string
	// Version is stored as the token's first byte; tokens with another version
	// will be rejected by Validate.
	Version byte
}
    76  
    77  // Generate produces an urlsafe base64 encoded string that contains 'embedded'
    78  // and MAC tag for 'state' + 'embedded' (but not the 'state' itself). The exact
    79  // same 'state' then must be used in Validate to successfully verify the token.
    80  //
    81  // 'embedded' is an optional map with additional data to add to the token. It is
    82  // embedded directly into the token and can be easily extracted from it by
    83  // anyone who has the token. Should be used only for publicly visible data. It
    84  // is tagged by token's MAC, so 'Validate' function can detect any modifications
    85  // (and reject tokens tampered with).
    86  //
    87  // The context is used to grab secrets.Store and the current time.
    88  func (k *TokenKind) Generate(c context.Context, state []byte, embedded map[string]string, exp time.Duration) (string, error) {
    89  	extended := make(map[string]string, len(embedded))
    90  	for k, v := range embedded {
    91  		if len(k) == 0 {
    92  			return "", fmt.Errorf("tokens: empty key in embedded map")
    93  		}
    94  		if k[0] == '_' {
    95  			return "", fmt.Errorf("token: bad key %q in embedded map", k)
    96  		}
    97  		extended[k] = v
    98  	}
    99  
   100  	// Append 'issued' timestamp (in milliseconds) and expiration time (if not
   101  	// default).
   102  	extended["_i"] = strconv.FormatInt(clock.Now(c).UnixNano()/1e6, 10)
   103  	if exp != 0 {
   104  		if exp < 0 {
   105  			return "", fmt.Errorf("tokens: expiration can't be negative")
   106  		}
   107  		extended["_x"] = strconv.FormatInt(exp.Nanoseconds()/1e6, 10)
   108  	}
   109  
   110  	// 'public' will be added to the token as is.
   111  	public, err := json.Marshal(extended)
   112  	if err != nil {
   113  		return "", err
   114  	}
   115  
   116  	// Build HMAC tag.
   117  	secret, err := secrets.RandomSecret(c, k.SecretKey)
   118  	if err != nil {
   119  		return "", err
   120  	}
   121  	mac, err := computeMAC(k.Algo, secret.Active, dataToAuth(k.Version, public, state))
   122  	if err != nil {
   123  		return "", err
   124  	}
   125  
   126  	encoded := base64.RawURLEncoding.EncodeToString(bytes.Join([][]byte{
   127  		{k.Version},
   128  		public,
   129  		mac,
   130  	}, nil))
   131  	return strings.TrimRight(encoded, "="), nil
   132  }
   133  
   134  // Validate checks token MAC and expiration, decodes data embedded into it.
   135  //
   136  // 'state' must be exactly the same as passed to Generate when creating a token.
   137  // If it's different, the token is considered invalid. It usually contains some
   138  // implicitly passed state that should be the same when token is generated and
   139  // validated. For example, it may be an account ID of a current caller. Then if
   140  // such token is used by another account, it is considered invalid.
   141  //
   142  // The context is used to grab secrets.Store and the current time.
   143  func (k *TokenKind) Validate(c context.Context, token string, state []byte) (map[string]string, error) {
   144  	digestLen := k.Algo.digestLen()
   145  	if digestLen == 0 {
   146  		return nil, fmt.Errorf("tokens: unknown algo %q", k.Algo)
   147  	}
   148  	blob, err := base64.RawURLEncoding.DecodeString(token)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	// One byte for version, at least one byte for public embedded dict portion,
   154  	// the rest is MAC digest.
   155  	if len(blob) < digestLen+2 {
   156  		return nil, fmt.Errorf("tokens: the token is too small")
   157  	}
   158  
   159  	// Data inside the token.
   160  	version := blob[0]
   161  	public := blob[1 : len(blob)-digestLen]
   162  	tokenMac := blob[len(blob)-digestLen:]
   163  
   164  	// Data that should have been used to generate HMAC.
   165  	toAuth := dataToAuth(version, public, state)
   166  
   167  	// Token could have been generated by previous value of the secret, so check
   168  	// them too.
   169  	secret, err := secrets.RandomSecret(c, k.SecretKey)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	goodToken := false
   174  	for _, blob := range secret.Blobs() {
   175  		goodMac, err := computeMAC(k.Algo, blob, toAuth)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  		if hmac.Equal(tokenMac, goodMac) {
   180  			goodToken = true
   181  			break
   182  		}
   183  	}
   184  	if !goodToken {
   185  		return nil, fmt.Errorf("tokens: bad token MAC")
   186  	}
   187  
   188  	// Token is authenticated, now check the rest.
   189  	if version != k.Version {
   190  		return nil, fmt.Errorf("tokens: bad version %q, expecting %q", version, k.Version)
   191  	}
   192  	embedded := map[string]string{}
   193  	if err := json.Unmarshal(public, &embedded); err != nil {
   194  		return nil, err
   195  	}
   196  
   197  	// Grab issued time, reject token from the future.
   198  	now := clock.Now(c)
   199  	issuedMs, err := popInt(embedded, "_i")
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	issuedTs := time.Unix(0, issuedMs*1e6)
   204  	if issuedTs.After(now.Add(allowedClockDrift)) {
   205  		return nil, fmt.Errorf("tokens: issued timestamp is in the future")
   206  	}
   207  
   208  	// Grab expiration time embedded into the token, if any.
   209  	expiration := k.Expiration
   210  	if _, ok := embedded["_x"]; ok {
   211  		expirationMs, err := popInt(embedded, "_x")
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  		expiration = time.Duration(expirationMs) * time.Millisecond
   216  	}
   217  	if expiration < 0 {
   218  		return nil, fmt.Errorf("tokens: bad token, expiration can't be negative")
   219  	}
   220  
   221  	// Check token expiration.
   222  	expired := now.Sub(issuedTs.Add(expiration))
   223  	if expired > 0 {
   224  		return nil, fmt.Errorf("tokens: token expired %s ago", expired)
   225  	}
   226  
   227  	return embedded, nil
   228  }
   229  
   230  // extractInt pops integer value from the map.
   231  func popInt(m map[string]string, key string) (int64, error) {
   232  	str, ok := m[key]
   233  	if !ok {
   234  		return 0, fmt.Errorf("tokens: bad token, missing %q key", key)
   235  	}
   236  	asInt, err := strconv.ParseInt(str, 10, 64)
   237  	if err != nil {
   238  		return 0, fmt.Errorf("tokens: bad token, %q is not a number", str)
   239  	}
   240  	delete(m, key)
   241  	return asInt, nil
   242  }
   243  
   244  // dataToAuth generates list of byte blobs authenticated by MAC.
   245  func dataToAuth(version byte, public []byte, state []byte) [][]byte {
   246  	out := [][]byte{
   247  		{version},
   248  		public,
   249  	}
   250  	if len(state) != 0 {
   251  		out = append(out, state)
   252  	}
   253  	return out
   254  }
   255  
   256  // computeMAC packs dataToAuth into single blob and computes its MAC.
   257  func computeMAC(algo TokenAlgo, secret []byte, dataToAuth [][]byte) ([]byte, error) {
   258  	hash := algo.hash(secret)
   259  	if hash == nil {
   260  		return nil, fmt.Errorf("tokens: unknown algo %q", algo)
   261  	}
   262  	for _, chunk := range dataToAuth {
   263  		// Separator between length header and the body is needed because length
   264  		// encoding is variable-length (decimal string).
   265  		if _, err := fmt.Fprintf(hash, "%d\n", len(chunk)); err != nil {
   266  			return nil, err
   267  		}
   268  		if _, err := hash.Write(chunk); err != nil {
   269  			return nil, err
   270  		}
   271  	}
   272  	return hash.Sum(nil), nil
   273  }