github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/jwt/provider/provider_test.go (about) 1 package provider 2 3 import ( 4 "net/http" 5 "testing" 6 7 jwt "github.com/dgrijalva/jwt-go" 8 "github.com/hellofresh/janus/pkg/config" 9 "github.com/stretchr/testify/assert" 10 ) 11 12 type mockProvider struct{} 13 14 func (p *mockProvider) Build(config config.Credentials) Provider { 15 return &mockProvider{} 16 } 17 func (p *mockProvider) Verify(r *http.Request, httpClient *http.Client) (bool, error) { 18 return true, nil 19 } 20 func (p *mockProvider) GetClaims(httpClient *http.Client) (jwt.MapClaims, error) { 21 return jwt.MapClaims{}, nil 22 } 23 24 type defaultProvider struct{} 25 26 func (p *defaultProvider) Build(config config.Credentials) Provider { 27 return &defaultProvider{} 28 } 29 func (p *defaultProvider) Verify(r *http.Request, httpClient *http.Client) (bool, error) { 30 return true, nil 31 } 32 func (p *defaultProvider) GetClaims(httpClient *http.Client) (jwt.MapClaims, error) { 33 return jwt.MapClaims{}, nil 34 } 35 func TestProviders(t *testing.T) { 36 tests := []struct { 37 scenario string 38 function func(*testing.T, *Factory) 39 }{ 40 { 41 scenario: "it should build providers properly", 42 function: testFactoryCanBuildProvider, 43 }, 44 { 45 scenario: "when given a wrong provider, it should get the default", 46 function: testFactoryCantFindProvider, 47 }, 48 } 49 50 for _, test := range tests { 51 t.Run(test.scenario, func(t *testing.T) { 52 t.Parallel() 53 Register("test", &mockProvider{}) 54 Register("basic", &defaultProvider{}) 55 56 f := &Factory{} 57 test.function(t, f) 58 }) 59 } 60 } 61 62 func testFactoryCanBuildProvider(t *testing.T, f *Factory) { 63 p := f.Build("test", config.Credentials{}) 64 65 assert.Implements(t, (*Provider)(nil), p) 66 assert.IsType(t, (*mockProvider)(nil), p) 67 } 68 69 func testFactoryCantFindProvider(t *testing.T, f *Factory) { 70 p := f.Build("wrong", config.Credentials{}) 71 72 assert.Implements(t, (*Provider)(nil), p) 73 assert.IsType(t, (*defaultProvider)(nil), p) 74 } 75 76 func testCountProvider(t *testing.T, f *Factory) { 77 assert.Len(t, GetProviders(), 2) 78 }