github.com/chenbh/concourse/v6@v6.4.2/skymarshal/token/access_token_test.go (about)

     1  package token_test
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"time"
     8  
     9  	"code.cloudfoundry.org/lager/lagertest"
    10  	"github.com/chenbh/concourse/v6/atc/db"
    11  	"github.com/chenbh/concourse/v6/atc/db/dbfakes"
    12  	"github.com/chenbh/concourse/v6/skymarshal/token"
    13  	"github.com/chenbh/concourse/v6/skymarshal/token/tokenfakes"
    14  	. "github.com/onsi/ginkgo"
    15  	. "github.com/onsi/gomega"
    16  	"gopkg.in/square/go-jose.v2/jwt"
    17  )
    18  
    19  var _ = Describe("Access Tokens", func() {
    20  
    21  	Describe("StoreAccessToken", func() {
    22  		var (
    23  			generator          *tokenfakes.FakeGenerator
    24  			claimsParser       *tokenfakes.FakeClaimsParser
    25  			accessTokenFactory *dbfakes.FakeAccessTokenFactory
    26  			userFactory        *dbfakes.FakeUserFactory
    27  
    28  			dummyLogger *lagertest.TestLogger
    29  		)
    30  
    31  		BeforeEach(func() {
    32  			generator = new(tokenfakes.FakeGenerator)
    33  			claimsParser = new(tokenfakes.FakeClaimsParser)
    34  			accessTokenFactory = new(dbfakes.FakeAccessTokenFactory)
    35  			userFactory = new(dbfakes.FakeUserFactory)
    36  
    37  			dummyLogger = lagertest.NewTestLogger("whatever")
    38  		})
    39  
    40  		type testCase struct {
    41  			it string
    42  
    43  			path       string
    44  			statusCode int
    45  			body       string
    46  
    47  			parseClaimsErrors   bool
    48  			generateTokenErrors bool
    49  			storeTokenErrors    bool
    50  			storeUserErrors     bool
    51  
    52  			expectStatusCode int
    53  			expectBody       string
    54  		}
    55  
    56  		for _, t := range []testCase{
    57  			{
    58  				it: "forwards non-token requests",
    59  
    60  				path:       "/sky/issuer/callback",
    61  				statusCode: 200,
    62  				body:       "some payload",
    63  
    64  				expectStatusCode: 200,
    65  				expectBody:       "some payload",
    66  			},
    67  			{
    68  				it: "modifies the access token",
    69  
    70  				path:       "/sky/issuer/token",
    71  				statusCode: 200,
    72  				body:       `{"access_token":"123","token_type":"bearer","expires_in":1234,"id_token":"a.b.c"}`,
    73  
    74  				expectStatusCode: 200,
    75  				expectBody:       `{"access_token":"123abc","token_type":"bearer","expires_in":1234,"id_token":"a.b.c"}`,
    76  			},
    77  			{
    78  				it: "forwards failure response",
    79  
    80  				path:       "/sky/issuer/token",
    81  				statusCode: 418,
    82  				body:       "i've made a huge mistake",
    83  
    84  				expectStatusCode: 418,
    85  				expectBody:       "i've made a huge mistake",
    86  			},
    87  			{
    88  				it: "errors if parsing claims fails",
    89  
    90  				path:       "/sky/issuer/token",
    91  				statusCode: 200,
    92  				body:       `{"access_token":"123","token_type":"bearer","expires_in":1234,"id_token":"invalid"}`,
    93  
    94  				parseClaimsErrors: true,
    95  
    96  				expectStatusCode: 500,
    97  			},
    98  			{
    99  				it: "errors if generating token fails",
   100  
   101  				path:       "/sky/issuer/token",
   102  				statusCode: 200,
   103  				body:       `{"access_token":"123","token_type":"bearer","expires_in":1234,"id_token":"a.b.c"}`,
   104  
   105  				generateTokenErrors: true,
   106  
   107  				expectStatusCode: 500,
   108  			},
   109  			{
   110  				it: "errors if storing token fails",
   111  
   112  				path:       "/sky/issuer/token",
   113  				statusCode: 200,
   114  				body:       `{"access_token":"123","token_type":"bearer","expires_in":1234,"id_token":"a.b.c"}`,
   115  
   116  				storeTokenErrors: true,
   117  
   118  				expectStatusCode: 500,
   119  			},
   120  			{
   121  				it: "errors if storing user fails",
   122  
   123  				path:       "/sky/issuer/token",
   124  				statusCode: 200,
   125  				body:       `{"access_token":"123","token_type":"bearer","expires_in":1234,"id_token":"a.b.c"}`,
   126  
   127  				storeUserErrors: true,
   128  
   129  				expectStatusCode: 500,
   130  			},
   131  		} {
   132  			t := t
   133  
   134  			It(t.it, func() {
   135  				baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   136  					w.WriteHeader(t.statusCode)
   137  					w.Write([]byte(t.body))
   138  				})
   139  				handler := token.StoreAccessToken(dummyLogger, baseHandler, generator, claimsParser, accessTokenFactory, userFactory)
   140  				r, _ := http.NewRequest("GET", t.path, nil)
   141  				rec := httptest.NewRecorder()
   142  
   143  				if t.parseClaimsErrors {
   144  					claimsParser.ParseClaimsReturns(db.Claims{}, errors.New("claims parse error"))
   145  				}
   146  
   147  				if t.generateTokenErrors {
   148  					generator.GenerateAccessTokenReturns("", errors.New("generate error"))
   149  				} else {
   150  					generator.GenerateAccessTokenReturns("123abc", nil)
   151  				}
   152  
   153  				if t.storeTokenErrors {
   154  					accessTokenFactory.CreateAccessTokenReturns(errors.New("store token error"))
   155  				}
   156  
   157  				if t.storeUserErrors {
   158  					userFactory.CreateOrUpdateUserReturns(errors.New("upsert user error"))
   159  				}
   160  
   161  				handler.ServeHTTP(rec, r)
   162  
   163  				result := rec.Result()
   164  				Expect(result.StatusCode).To(Equal(t.expectStatusCode))
   165  				Expect(rec.Body.String()).To(Equal(t.expectBody))
   166  			})
   167  		}
   168  	})
   169  
   170  	Describe("Token Generation", func() {
   171  		It("generates a token with the unix timestamp", func() {
   172  			factory := token.Factory{}
   173  			expectExpiry := jwt.NewNumericDate(time.Now())
   174  			rawToken, err := factory.GenerateAccessToken(db.Claims{
   175  				Claims: jwt.Claims{Expiry: expectExpiry},
   176  			})
   177  			Expect(err).ToNot(HaveOccurred())
   178  			expiry, err := factory.ParseExpiry(rawToken)
   179  			Expect(err).ToNot(HaveOccurred())
   180  
   181  			Expect(expiry).To(Equal(expectExpiry.Time()))
   182  		})
   183  	})
   184  })