github.com/weaviate/weaviate@v1.24.6/usecases/auth/authentication/oidc/middleware.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package oidc
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"strings"
    18  
    19  	"github.com/coreos/go-oidc/v3/oidc"
    20  	errors "github.com/go-openapi/errors"
    21  	"github.com/weaviate/weaviate/entities/models"
    22  	"github.com/weaviate/weaviate/usecases/config"
    23  )
    24  
    25  // Client holds the OIDC configuration together with the provider and ID token
    26  // verifier created by init; ValidateAndExtract is its go-swagger middleware.
    27  type Client struct {
    28  	config   config.OIDC // copied from cfg.Authentication.OIDC by New
    29  	provider *oidc.Provider // set by init; stays nil when OIDC is disabled
    30  	verifier *oidc.IDTokenVerifier // set by init; used by ValidateAndExtract to verify tokens
    31  }
    32  
    33  // New OIDC Client: It tries to retrieve the JWKs at startup (or fails), it
    34  // provides a middleware which can be used at runtime with a go-swagger style
    35  // API
    36  func New(cfg config.Config) (*Client, error) {
    37  	client := &Client{
    38  		config: cfg.Authentication.OIDC,
    39  	}
    40  
    41  	if !client.config.Enabled {
    42  		// if oidc is not enabled, we are done, no need to setup an actual client.
    43  		// The "disabled" client is however still valuable to deny any requests
    44  		// coming in with an OAuth token set.
    45  		return client, nil
    46  	}
    47  
    48  	if err := client.init(); err != nil {
    49  		return nil, fmt.Errorf("oidc init: %v", err)
    50  	}
    51  
    52  	return client, nil
    53  }
    54  
    55  func (c *Client) init() error {
    56  	if err := c.validateConfig(); err != nil {
    57  		return fmt.Errorf("invalid config: %v", err)
    58  	}
    59  
    60  	provider, err := oidc.NewProvider(context.Background(), c.config.Issuer)
    61  	if err != nil {
    62  		return fmt.Errorf("could not setup provider: %v", err)
    63  	}
    64  	c.provider = provider
    65  
    66  	// oauth2
    67  
    68  	verifier := provider.Verifier(&oidc.Config{
    69  		ClientID:          c.config.ClientID,
    70  		SkipClientIDCheck: c.config.SkipClientIDCheck,
    71  	})
    72  	c.verifier = verifier
    73  
    74  	return nil
    75  }
    76  
    77  func (c *Client) validateConfig() error {
    78  	var msgs []string
    79  
    80  	if c.config.Issuer == "" {
    81  		msgs = append(msgs, "missing required field 'issuer'")
    82  	}
    83  
    84  	if c.config.UsernameClaim == "" {
    85  		msgs = append(msgs, "missing required field 'username_claim'")
    86  	}
    87  
    88  	if !c.config.SkipClientIDCheck && c.config.ClientID == "" {
    89  		msgs = append(msgs, "missing required field 'client_id': "+
    90  			"either set a client_id or explicitly disable the check with 'skip_client_id_check: true'")
    91  	}
    92  
    93  	if len(msgs) == 0 {
    94  		return nil
    95  	}
    96  
    97  	return fmt.Errorf(strings.Join(msgs, ", "))
    98  }
    99  
   100  // ValidateAndExtract can be used as a middleware for go-swagger
   101  func (c *Client) ValidateAndExtract(token string, scopes []string) (*models.Principal, error) {
   102  	if !c.config.Enabled {
   103  		return nil, errors.New(401, "oidc auth is not configured, please try another auth scheme or set up weaviate with OIDC configured")
   104  	}
   105  
   106  	parsed, err := c.verifier.Verify(context.Background(), token)
   107  	if err != nil {
   108  		return nil, errors.New(401, err.Error())
   109  	}
   110  
   111  	claims, err := c.extractClaims(parsed)
   112  	if err != nil {
   113  		return nil, errors.New(500, fmt.Sprintf("oidc: %v", err))
   114  	}
   115  
   116  	username, err := c.extractUsername(claims)
   117  	if err != nil {
   118  		return nil, errors.New(500, fmt.Sprintf("oidc: %v", err))
   119  	}
   120  
   121  	groups := c.extractGroups(claims)
   122  
   123  	return &models.Principal{
   124  		Username: username,
   125  		Groups:   groups,
   126  	}, nil
   127  }
   128  
   129  func (c *Client) extractClaims(token *oidc.IDToken) (map[string]interface{}, error) {
   130  	var claims map[string]interface{}
   131  	if err := token.Claims(&claims); err != nil {
   132  		return nil, fmt.Errorf("could not extract claims from token: %v", err)
   133  	}
   134  
   135  	return claims, nil
   136  }
   137  
   138  func (c *Client) extractUsername(claims map[string]interface{}) (string, error) {
   139  	usernameUntyped, ok := claims[c.config.UsernameClaim]
   140  	if !ok {
   141  		return "", fmt.Errorf("token doesn't contain required claim '%s'", c.config.UsernameClaim)
   142  	}
   143  
   144  	username, ok := usernameUntyped.(string)
   145  	if !ok {
   146  		return "", fmt.Errorf("claim '%s' is not a string, but %T", c.config.UsernameClaim, usernameUntyped)
   147  	}
   148  
   149  	return username, nil
   150  }
   151  
   152  // extractGroups never errors, if groups can't be parsed an empty set of groups
   153  // is returned. This is because groups are not a required standard in the OIDC
   154  // spec, so we can't error if an OIDC provider does not support them.
   155  func (c *Client) extractGroups(claims map[string]interface{}) []string {
   156  	var groups []string
   157  
   158  	groupsUntyped, ok := claims[c.config.GroupsClaim]
   159  	if !ok {
   160  		return groups
   161  	}
   162  
   163  	groupsSlice, ok := groupsUntyped.([]interface{})
   164  	if !ok {
   165  		return groups
   166  	}
   167  
   168  	for _, untyped := range groupsSlice {
   169  		if group, ok := untyped.(string); ok {
   170  			groups = append(groups, group)
   171  		}
   172  	}
   173  
   174  	return groups
   175  }