github.com/snyk/vervet/v5@v5.11.1-0.20240202085829-ad4dd7fb6101/versionware/validator.go (about)

     1  package versionware
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"time"
     9  
    10  	"github.com/getkin/kin-openapi/openapi3"
    11  	"github.com/getkin/kin-openapi/openapi3filter"
    12  	"github.com/getkin/kin-openapi/routers/gorillamux"
    13  
    14  	"github.com/snyk/vervet/v5"
    15  )
    16  
    17  // Validator provides versioned OpenAPI validation middleware for HTTP requests
    18  // and responses.
    19  type Validator struct {
    20  	versions   vervet.VersionIndex
    21  	validators map[vervet.Version]*openapi3filter.Validator
    22  	errFunc    VersionErrorHandler
    23  	today      func() time.Time
    24  }
    25  
    26  // ValidatorConfig defines how a new Validator may be configured.
    27  type ValidatorConfig struct {
    28  	// ServerURL overrides the server URLs in the given OpenAPI specs to match
    29  	// the URL of requests reaching the backend service. If unset, requests
    30  	// must match the servers defined in OpenAPI specs.
    31  	ServerURL string
    32  
    33  	// VersionError is called on any error that occurs when trying to resolve the
    34  	// API version.
    35  	VersionError VersionErrorHandler
    36  
    37  	// Options further configure the request and response validation. See
    38  	// https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#ValidatorOption
    39  	// for available options.
    40  	Options []openapi3filter.ValidatorOption
    41  }
    42  
    43  var defaultValidatorConfig = ValidatorConfig{
    44  	VersionError: DefaultVersionError,
    45  	Options: []openapi3filter.ValidatorOption{
    46  		openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, _ error) {
    47  			statusText := http.StatusText(http.StatusInternalServerError)
    48  			switch code {
    49  			case openapi3filter.ErrCodeCannotFindRoute:
    50  				statusText = "Not Found"
    51  			case openapi3filter.ErrCodeRequestInvalid:
    52  				statusText = "Bad Request"
    53  			}
    54  			http.Error(w, statusText, status)
    55  		}),
    56  	},
    57  }
    58  
    59  func today() time.Time {
    60  	now := time.Now().UTC()
    61  	return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
    62  }
    63  
    64  // NewValidator returns a new validation middleware, which validates versioned
    65  // requests according to the given OpenAPI spec versions. For configuration
    66  // defaults, a nil config may be used.
    67  func NewValidator(config *ValidatorConfig, docs ...*openapi3.T) (*Validator, error) {
    68  	if len(docs) == 0 {
    69  		return nil, fmt.Errorf("no OpenAPI versions provided")
    70  	}
    71  	if config == nil {
    72  		config = &defaultValidatorConfig
    73  	}
    74  	if config.ServerURL != "" {
    75  		serverURL, err := url.Parse(config.ServerURL)
    76  		if err != nil {
    77  			return nil, fmt.Errorf("invalid ServerURL: %w", err)
    78  		}
    79  		switch serverURL.Scheme {
    80  		case "http", "https":
    81  		case "":
    82  			return nil, errors.New("invalid ServerURL: missing scheme")
    83  		default:
    84  			return nil, fmt.Errorf(
    85  				"invalid ServerURL: unsupported scheme %q (did you forget to specify the scheme://?)",
    86  				serverURL.Scheme,
    87  			)
    88  		}
    89  	}
    90  	if config.VersionError == nil {
    91  		config.VersionError = DefaultVersionError
    92  	}
    93  	v := &Validator{
    94  		validators: map[vervet.Version]*openapi3filter.Validator{},
    95  		errFunc:    config.VersionError,
    96  		today:      today,
    97  	}
    98  	serviceVersions := make(vervet.VersionSlice, len(docs))
    99  	for i := range docs {
   100  		if config.ServerURL != "" {
   101  			docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}}
   102  		}
   103  		versionStr, err := vervet.ExtensionString(docs[i].Extensions, vervet.ExtSnykApiVersion)
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  		version, err := vervet.ParseVersion(versionStr)
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  		serviceVersions[i] = version
   112  		router, err := gorillamux.NewRouter(docs[i])
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		v.validators[version] = openapi3filter.NewValidator(router, config.Options...)
   117  	}
   118  	v.versions = vervet.NewVersionIndex(serviceVersions)
   119  	return v, nil
   120  }
   121  
   122  // Middleware returns an http.Handler which wraps the given handler with
   123  // request and response validation according to the requested API version.
   124  func (v *Validator) Middleware(h http.Handler) http.Handler {
   125  	handlers := map[vervet.Version]http.Handler{}
   126  	for version, validator := range v.validators {
   127  		handlers[version] = validator.Middleware(h)
   128  	}
   129  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   130  		versionParam := req.URL.Query().Get("version")
   131  		if versionParam == "" {
   132  			v.errFunc(w, req, http.StatusBadRequest, errors.New("missing required query parameter 'version'"))
   133  			return
   134  		}
   135  		requested, err := vervet.ParseVersion(versionParam)
   136  		if err != nil {
   137  			v.errFunc(w, req, http.StatusBadRequest, err)
   138  			return
   139  		}
   140  		if t := v.today(); requested.Date.After(t) {
   141  			v.errFunc(w, req, http.StatusBadRequest,
   142  				fmt.Errorf("requested version newer than present date %s", t))
   143  			return
   144  		}
   145  		resolvedVersion, err := v.versions.Resolve(requested)
   146  		if err != nil {
   147  			v.errFunc(w, req, http.StatusNotFound, err)
   148  			return
   149  		}
   150  		h, ok := handlers[resolvedVersion]
   151  		if !ok {
   152  			// Crash noisily, as this indicates a serious bug. Should not
   153  			// happen if we've initialized our version maps correctly.
   154  			panic(fmt.Sprintf("missing validator for version %q", resolvedVersion))
   155  		}
   156  		h.ServeHTTP(w, req)
   157  	})
   158  }