github.com/weaviate/weaviate@v1.24.6/usecases/auth/authentication/oidc/middleware.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package oidc 13 14 import ( 15 "context" 16 "fmt" 17 "strings" 18 19 "github.com/coreos/go-oidc/v3/oidc" 20 errors "github.com/go-openapi/errors" 21 "github.com/weaviate/weaviate/entities/models" 22 "github.com/weaviate/weaviate/usecases/config" 23 ) 24 25 // Client handles the OIDC setup at startup and provides a middleware to be 26 // used with the goswagger API 27 type Client struct { 28 config config.OIDC 29 provider *oidc.Provider 30 verifier *oidc.IDTokenVerifier 31 } 32 33 // New OIDC Client: It tries to retrieve the JWKs at startup (or fails), it 34 // provides a middleware which can be used at runtime with a go-swagger style 35 // API 36 func New(cfg config.Config) (*Client, error) { 37 client := &Client{ 38 config: cfg.Authentication.OIDC, 39 } 40 41 if !client.config.Enabled { 42 // if oidc is not enabled, we are done, no need to setup an actual client. 43 // The "disabled" client is however still valuable to deny any requests 44 // coming in with an OAuth token set. 45 return client, nil 46 } 47 48 if err := client.init(); err != nil { 49 return nil, fmt.Errorf("oidc init: %v", err) 50 } 51 52 return client, nil 53 } 54 55 func (c *Client) init() error { 56 if err := c.validateConfig(); err != nil { 57 return fmt.Errorf("invalid config: %v", err) 58 } 59 60 provider, err := oidc.NewProvider(context.Background(), c.config.Issuer) 61 if err != nil { 62 return fmt.Errorf("could not setup provider: %v", err) 63 } 64 c.provider = provider 65 66 // oauth2 67 68 verifier := provider.Verifier(&oidc.Config{ 69 ClientID: c.config.ClientID, 70 SkipClientIDCheck: c.config.SkipClientIDCheck, 71 }) 72 c.verifier = verifier 73 74 return nil 75 } 76 77 func (c *Client) validateConfig() error { 78 var msgs []string 79 80 if c.config.Issuer == "" { 81 msgs = append(msgs, "missing required field 'issuer'") 82 } 83 84 if c.config.UsernameClaim == "" { 85 msgs = append(msgs, "missing required field 'username_claim'") 86 } 87 88 if !c.config.SkipClientIDCheck && c.config.ClientID == "" { 89 msgs = append(msgs, "missing required field 'client_id': "+ 90 "either set a client_id or explicitly disable the check with 'skip_client_id_check: true'") 91 } 92 93 if len(msgs) == 0 { 94 return nil 95 } 96 97 return fmt.Errorf(strings.Join(msgs, ", ")) 98 } 99 100 // ValidateAndExtract can be used as a middleware for go-swagger 101 func (c *Client) ValidateAndExtract(token string, scopes []string) (*models.Principal, error) { 102 if !c.config.Enabled { 103 return nil, errors.New(401, "oidc auth is not configured, please try another auth scheme or set up weaviate with OIDC configured") 104 } 105 106 parsed, err := c.verifier.Verify(context.Background(), token) 107 if err != nil { 108 return nil, errors.New(401, err.Error()) 109 } 110 111 claims, err := c.extractClaims(parsed) 112 if err != nil { 113 return nil, errors.New(500, fmt.Sprintf("oidc: %v", err)) 114 } 115 116 username, err := c.extractUsername(claims) 117 if err != nil { 118 return nil, errors.New(500, fmt.Sprintf("oidc: %v", err)) 119 } 120 121 groups := c.extractGroups(claims) 122 123 return &models.Principal{ 124 Username: username, 125 Groups: groups, 126 }, nil 127 } 128 129 func (c *Client) extractClaims(token *oidc.IDToken) (map[string]interface{}, error) { 130 var claims map[string]interface{} 131 if err := token.Claims(&claims); err != nil { 132 return nil, fmt.Errorf("could not extract claims from token: %v", err) 133 } 134 135 return claims, nil 136 } 137 138 func (c *Client) extractUsername(claims map[string]interface{}) (string, error) { 139 usernameUntyped, ok := claims[c.config.UsernameClaim] 140 if !ok { 141 return "", fmt.Errorf("token doesn't contain required claim '%s'", c.config.UsernameClaim) 142 } 143 144 username, ok := usernameUntyped.(string) 145 if !ok { 146 return "", fmt.Errorf("claim '%s' is not a string, but %T", c.config.UsernameClaim, usernameUntyped) 147 } 148 149 return username, nil 150 } 151 152 // extractGroups never errors, if groups can't be parsed an empty set of groups 153 // is returned. This is because groups are not a required standard in the OIDC 154 // spec, so we can't error if an OIDC provider does not support them. 155 func (c *Client) extractGroups(claims map[string]interface{}) []string { 156 var groups []string 157 158 groupsUntyped, ok := claims[c.config.GroupsClaim] 159 if !ok { 160 return groups 161 } 162 163 groupsSlice, ok := groupsUntyped.([]interface{}) 164 if !ok { 165 return groups 166 } 167 168 for _, untyped := range groupsSlice { 169 if group, ok := untyped.(string); ok { 170 groups = append(groups, group) 171 } 172 } 173 174 return groups 175 }