github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/skymarshal/token/access_token.go (about)

     1  package token
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/base64"
     7  	"encoding/binary"
     8  	"encoding/json"
     9  	"errors"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"time"
    14  
    15  	"code.cloudfoundry.org/lager"
    16  	"github.com/pf-qiu/concourse/v6/atc/db"
    17  	"gopkg.in/square/go-jose.v2/jwt"
    18  )
    19  
    20  //go:generate counterfeiter . Generator
    21  
// Generator creates access tokens. StoreAccessToken uses it to replace the
// access_token field of the wrapped handler's token response with a token
// derived from the ID-token claims.
type Generator interface {
	GenerateAccessToken(claims db.Claims) (string, error)
}
    25  
    26  //go:generate counterfeiter . Parser
    27  
// Parser recovers the expiry time embedded in a raw access token, such as one
// produced by Factory.GenerateAccessToken.
type Parser interface {
	ParseExpiry(raw string) (time.Time, error)
}
    31  
    32  //go:generate counterfeiter . ClaimsParser
    33  
// ClaimsParser decodes the claims carried by an ID token. StoreAccessToken
// applies it to the id_token field of the wrapped handler's token response.
type ClaimsParser interface {
	ParseClaims(idToken string) (db.Claims, error)
}
    37  
    38  func StoreAccessToken(
    39  	logger lager.Logger,
    40  	handler http.Handler,
    41  	generator Generator,
    42  	claimsParser ClaimsParser,
    43  	accessTokenFactory db.AccessTokenFactory,
    44  	userFactory db.UserFactory,
    45  ) http.Handler {
    46  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    47  		if r.URL.Path != "/sky/issuer/token" {
    48  			handler.ServeHTTP(w, r)
    49  			return
    50  		}
    51  		logger := logger.Session("token-request")
    52  		logger.Debug("start")
    53  		defer logger.Debug("end")
    54  
    55  		rec := httptest.NewRecorder()
    56  		handler.ServeHTTP(rec, r)
    57  
    58  		var body io.Reader
    59  		defer func() {
    60  			copyResponseHeaders(w, rec.Result())
    61  			if body != nil {
    62  				io.Copy(w, body)
    63  			}
    64  		}()
    65  		if rec.Code < 200 || rec.Code > 299 {
    66  			body = rec.Body
    67  			return
    68  		}
    69  		var resp struct {
    70  			AccessToken  string `json:"access_token"`
    71  			TokenType    string `json:"token_type"`
    72  			ExpiresIn    int    `json:"expires_in"`
    73  			RefreshToken string `json:"refresh_token,omitempty"`
    74  			IDToken      string `json:"id_token"`
    75  		}
    76  		err := json.Unmarshal(rec.Body.Bytes(), &resp)
    77  		if err != nil {
    78  			logger.Error("unmarshal-response-from-dex", err)
    79  			w.WriteHeader(http.StatusInternalServerError)
    80  			return
    81  		}
    82  		claims, err := claimsParser.ParseClaims(resp.IDToken)
    83  		if err != nil {
    84  			logger.Error("parse-id-token", err)
    85  			w.WriteHeader(http.StatusInternalServerError)
    86  			return
    87  		}
    88  		resp.AccessToken, err = generator.GenerateAccessToken(claims)
    89  		if err != nil {
    90  			logger.Error("generate-access-token", err)
    91  			w.WriteHeader(http.StatusInternalServerError)
    92  			return
    93  		}
    94  		err = accessTokenFactory.CreateAccessToken(resp.AccessToken, claims)
    95  		if err != nil {
    96  			logger.Error("create-access-token-in-db", err)
    97  			w.WriteHeader(http.StatusInternalServerError)
    98  			return
    99  		}
   100  		username := claims.Username
   101  		if claims.PreferredUsername != "" {
   102  			username = claims.PreferredUsername
   103  		}
   104  		err = userFactory.CreateOrUpdateUser(username, claims.Connector, claims.Subject)
   105  		if err != nil {
   106  			logger.Error("create-or-update-user", err)
   107  			w.WriteHeader(http.StatusInternalServerError)
   108  			return
   109  		}
   110  		newResp, err := json.Marshal(resp)
   111  		if err != nil {
   112  			logger.Error("marshal-new-response", err)
   113  			w.WriteHeader(http.StatusInternalServerError)
   114  			return
   115  		}
   116  		body = bytes.NewReader(newResp)
   117  	})
   118  }
   119  
   120  func copyResponseHeaders(w http.ResponseWriter, res *http.Response) {
   121  	for k, v := range res.Header {
   122  		k = http.CanonicalHeaderKey(k)
   123  		if k != "Content-Length" {
   124  			w.Header()[k] = v
   125  		}
   126  	}
   127  	w.WriteHeader(res.StatusCode)
   128  }
   129  
   130  func NewClaimsParser() ClaimsParser {
   131  	return claimsParserNoVerify{}
   132  }
   133  
   134  type claimsParserNoVerify struct {
   135  }
   136  
   137  func (claimsParserNoVerify) ParseClaims(idToken string) (db.Claims, error) {
   138  	token, err := jwt.ParseSigned(idToken)
   139  	if err != nil {
   140  		return db.Claims{}, err
   141  	}
   142  
   143  	var claims db.Claims
   144  	err = token.UnsafeClaimsWithoutVerification(&claims)
   145  	if err != nil {
   146  		return db.Claims{}, err
   147  	}
   148  	return claims, nil
   149  }
   150  
   151  type Factory struct {
   152  }
   153  
   154  // GenerateAccessToken generates a token with 20 bytes of entropy with the
   155  // unix timestamp appended.
   156  func (Factory) GenerateAccessToken(claims db.Claims) (string, error) {
   157  	b := [28]byte{}
   158  	_, err := rand.Read(b[:20])
   159  	if err != nil {
   160  		return "", err
   161  	}
   162  	if claims.Expiry == nil {
   163  		return "", errors.New("missing 'exp' claim")
   164  	}
   165  	binary.LittleEndian.PutUint64(b[20:], uint64(*claims.Expiry))
   166  	return base64.RawStdEncoding.EncodeToString(b[:]), nil
   167  }
   168  
   169  func (Factory) ParseExpiry(accessToken string) (time.Time, error) {
   170  	raw, err := base64.RawStdEncoding.DecodeString(accessToken)
   171  	if err != nil {
   172  		return time.Time{}, err
   173  	}
   174  	if len(raw) != 28 {
   175  		return time.Time{}, errors.New("invalid access token length")
   176  	}
   177  	expiry := jwt.NumericDate(binary.LittleEndian.Uint64(raw[20:]))
   178  	return expiry.Time(), nil
   179  }