github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/accessor/verifier_test.go (about)

     1  package accessor_test
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"time"
     7  
     8  	"github.com/pf-qiu/concourse/v6/atc/api/accessor/accessorfakes"
     9  	"github.com/pf-qiu/concourse/v6/atc/db"
    10  	. "github.com/onsi/ginkgo"
    11  	. "github.com/onsi/gomega"
    12  	"gopkg.in/square/go-jose.v2/jwt"
    13  
    14  	"github.com/pf-qiu/concourse/v6/atc/api/accessor"
    15  )
    16  
    17  var _ = Describe("Verifier", func() {
    18  	var (
    19  		accessTokenFetcher *accessorfakes.FakeAccessTokenFetcher
    20  		accessToken        db.AccessToken
    21  
    22  		req *http.Request
    23  
    24  		verifier accessor.TokenVerifier
    25  
    26  		err error
    27  	)
    28  
    29  	BeforeEach(func() {
    30  		accessTokenFetcher = new(accessorfakes.FakeAccessTokenFetcher)
    31  		accessTokenFetcher.GetAccessTokenCalls(func(string) (db.AccessToken, bool, error) {
    32  			return accessToken, true, nil
    33  		})
    34  
    35  		req, _ = http.NewRequest("GET", "localhost:8080", nil)
    36  		req.Header.Set("Authorization", "bearer 1234567890")
    37  
    38  		verifier = accessor.NewVerifier(accessTokenFetcher, []string{"some-aud"})
    39  	})
    40  
    41  	Describe("Verify", func() {
    42  
    43  		JustBeforeEach(func() {
    44  			_, err = verifier.Verify(req)
    45  		})
    46  
    47  		Context("when request has no token", func() {
    48  			BeforeEach(func() {
    49  				req.Header.Del("Authorization")
    50  			})
    51  
    52  			It("fails with no token", func() {
    53  				Expect(err).To(Equal(accessor.ErrVerificationNoToken))
    54  			})
    55  		})
    56  
    57  		Context("when request has an invalid auth header", func() {
    58  			BeforeEach(func() {
    59  				req.Header.Set("Authorization", "invalid")
    60  			})
    61  
    62  			It("fails verification", func() {
    63  				Expect(err).To(Equal(accessor.ErrVerificationInvalidToken))
    64  			})
    65  		})
    66  
    67  		Context("when request has an invalid token type", func() {
    68  			BeforeEach(func() {
    69  				req.Header.Set("Authorization", "not-bearer 1234567890")
    70  			})
    71  
    72  			It("fails verification", func() {
    73  				Expect(err).To(Equal(accessor.ErrVerificationInvalidToken))
    74  			})
    75  		})
    76  
    77  		Context("when getting the access token errors", func() {
    78  			BeforeEach(func() {
    79  				accessTokenFetcher.GetAccessTokenReturns(db.AccessToken{}, false, errors.New("db error"))
    80  			})
    81  
    82  			It("errors", func() {
    83  				Expect(err).To(MatchError("db error"))
    84  			})
    85  		})
    86  
    87  		Context("when the token is not found in the DB", func() {
    88  			BeforeEach(func() {
    89  				accessTokenFetcher.GetAccessTokenReturns(db.AccessToken{}, false, nil)
    90  			})
    91  
    92  			It("fails verification", func() {
    93  				Expect(err).To(Equal(accessor.ErrVerificationInvalidToken))
    94  			})
    95  		})
    96  
    97  		Context("when the claims have expired", func() {
    98  			BeforeEach(func() {
    99  				oneHourAgo := jwt.NewNumericDate(time.Now().Add(-1 * time.Hour))
   100  				accessToken.Claims = db.Claims{
   101  					Claims: jwt.Claims{
   102  						Expiry: oneHourAgo,
   103  					},
   104  				}
   105  			})
   106  
   107  			It("fails verification", func() {
   108  				Expect(err).To(Equal(accessor.ErrVerificationTokenExpired))
   109  			})
   110  		})
   111  
   112  		Context("when the claims have invalid audience", func() {
   113  			BeforeEach(func() {
   114  				oneHourFromNow := jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
   115  				accessToken.Claims = db.Claims{
   116  					Claims: jwt.Claims{
   117  						Expiry:   oneHourFromNow,
   118  						Audience: []string{"invalid"},
   119  					},
   120  				}
   121  			})
   122  
   123  			It("fails verification", func() {
   124  				Expect(err).To(Equal(accessor.ErrVerificationInvalidAudience))
   125  			})
   126  		})
   127  
   128  		Context("when the claims are valid", func() {
   129  			BeforeEach(func() {
   130  				oneHourFromNow := jwt.NewNumericDate(time.Now().Add(1 * time.Hour))
   131  				accessToken.Claims = db.Claims{
   132  					Claims: jwt.Claims{
   133  						Expiry:   oneHourFromNow,
   134  						Audience: []string{"some-aud"},
   135  					},
   136  				}
   137  			})
   138  
   139  			It("succeeds", func() {
   140  				Expect(err).ToNot(HaveOccurred())
   141  			})
   142  		})
   143  	})
   144  })