github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/skymarshal/token/access_token.go (about) 1 package token 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "encoding/base64" 7 "encoding/binary" 8 "encoding/json" 9 "errors" 10 "io" 11 "net/http" 12 "net/http/httptest" 13 "time" 14 15 "code.cloudfoundry.org/lager" 16 "github.com/pf-qiu/concourse/v6/atc/db" 17 "gopkg.in/square/go-jose.v2/jwt" 18 ) 19 20 //go:generate counterfeiter . Generator 21 22 type Generator interface { 23 GenerateAccessToken(claims db.Claims) (string, error) 24 } 25 26 //go:generate counterfeiter . Parser 27 28 type Parser interface { 29 ParseExpiry(raw string) (time.Time, error) 30 } 31 32 //go:generate counterfeiter . ClaimsParser 33 34 type ClaimsParser interface { 35 ParseClaims(idToken string) (db.Claims, error) 36 } 37 38 func StoreAccessToken( 39 logger lager.Logger, 40 handler http.Handler, 41 generator Generator, 42 claimsParser ClaimsParser, 43 accessTokenFactory db.AccessTokenFactory, 44 userFactory db.UserFactory, 45 ) http.Handler { 46 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 47 if r.URL.Path != "/sky/issuer/token" { 48 handler.ServeHTTP(w, r) 49 return 50 } 51 logger := logger.Session("token-request") 52 logger.Debug("start") 53 defer logger.Debug("end") 54 55 rec := httptest.NewRecorder() 56 handler.ServeHTTP(rec, r) 57 58 var body io.Reader 59 defer func() { 60 copyResponseHeaders(w, rec.Result()) 61 if body != nil { 62 io.Copy(w, body) 63 } 64 }() 65 if rec.Code < 200 || rec.Code > 299 { 66 body = rec.Body 67 return 68 } 69 var resp struct { 70 AccessToken string `json:"access_token"` 71 TokenType string `json:"token_type"` 72 ExpiresIn int `json:"expires_in"` 73 RefreshToken string `json:"refresh_token,omitempty"` 74 IDToken string `json:"id_token"` 75 } 76 err := json.Unmarshal(rec.Body.Bytes(), &resp) 77 if err != nil { 78 logger.Error("unmarshal-response-from-dex", err) 79 w.WriteHeader(http.StatusInternalServerError) 80 return 81 } 82 claims, err := claimsParser.ParseClaims(resp.IDToken) 83 if err != nil { 84 logger.Error("parse-id-token", err) 85 w.WriteHeader(http.StatusInternalServerError) 86 return 87 } 88 resp.AccessToken, err = generator.GenerateAccessToken(claims) 89 if err != nil { 90 logger.Error("generate-access-token", err) 91 w.WriteHeader(http.StatusInternalServerError) 92 return 93 } 94 err = accessTokenFactory.CreateAccessToken(resp.AccessToken, claims) 95 if err != nil { 96 logger.Error("create-access-token-in-db", err) 97 w.WriteHeader(http.StatusInternalServerError) 98 return 99 } 100 username := claims.Username 101 if claims.PreferredUsername != "" { 102 username = claims.PreferredUsername 103 } 104 err = userFactory.CreateOrUpdateUser(username, claims.Connector, claims.Subject) 105 if err != nil { 106 logger.Error("create-or-update-user", err) 107 w.WriteHeader(http.StatusInternalServerError) 108 return 109 } 110 newResp, err := json.Marshal(resp) 111 if err != nil { 112 logger.Error("marshal-new-response", err) 113 w.WriteHeader(http.StatusInternalServerError) 114 return 115 } 116 body = bytes.NewReader(newResp) 117 }) 118 } 119 120 func copyResponseHeaders(w http.ResponseWriter, res *http.Response) { 121 for k, v := range res.Header { 122 k = http.CanonicalHeaderKey(k) 123 if k != "Content-Length" { 124 w.Header()[k] = v 125 } 126 } 127 w.WriteHeader(res.StatusCode) 128 } 129 130 func NewClaimsParser() ClaimsParser { 131 return claimsParserNoVerify{} 132 } 133 134 type claimsParserNoVerify struct { 135 } 136 137 func (claimsParserNoVerify) ParseClaims(idToken string) (db.Claims, error) { 138 token, err := jwt.ParseSigned(idToken) 139 if err != nil { 140 return db.Claims{}, err 141 } 142 143 var claims db.Claims 144 err = token.UnsafeClaimsWithoutVerification(&claims) 145 if err != nil { 146 return db.Claims{}, err 147 } 148 return claims, nil 149 } 150 151 type Factory struct { 152 } 153 154 // GenerateAccessToken generates a token with 20 bytes of entropy with the 155 // unix timestamp appended. 156 func (Factory) GenerateAccessToken(claims db.Claims) (string, error) { 157 b := [28]byte{} 158 _, err := rand.Read(b[:20]) 159 if err != nil { 160 return "", err 161 } 162 if claims.Expiry == nil { 163 return "", errors.New("missing 'exp' claim") 164 } 165 binary.LittleEndian.PutUint64(b[20:], uint64(*claims.Expiry)) 166 return base64.RawStdEncoding.EncodeToString(b[:]), nil 167 } 168 169 func (Factory) ParseExpiry(accessToken string) (time.Time, error) { 170 raw, err := base64.RawStdEncoding.DecodeString(accessToken) 171 if err != nil { 172 return time.Time{}, err 173 } 174 if len(raw) != 28 { 175 return time.Time{}, errors.New("invalid access token length") 176 } 177 expiry := jwt.NumericDate(binary.LittleEndian.Uint64(raw[20:])) 178 return expiry.Time(), nil 179 }