github.com/snyk/vervet/v3@v3.7.0/versionware/validator.go (about)

     1  package versionware
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/url"
     7  	"sort"
     8  	"time"
     9  
    10  	"github.com/getkin/kin-openapi/openapi3"
    11  	"github.com/getkin/kin-openapi/openapi3filter"
    12  	"github.com/getkin/kin-openapi/routers/gorillamux"
    13  
    14  	"github.com/snyk/vervet/v3"
    15  )
    16  
// Validator provides versioned OpenAPI validation middleware for HTTP requests
// and responses.
//
// Fields:
//   - versions: the API version of each OpenAPI document given to
//     NewValidator, ordered by sort.Sort on vervet.VersionSlice.
//   - validators: the openapi3filter validator for each document;
//     validators[i] is the validator built for versions[i].
//   - errFunc: called by Middleware when the requested version is missing,
//     unparseable, dated in the future, or cannot be resolved.
//   - today: returns the date against which requested versions are checked;
//     NewValidator sets it to the package-level today function. NOTE(review):
//     presumably a field so tests can substitute a fixed clock — confirm.
type Validator struct {
	versions   vervet.VersionSlice
	validators []*openapi3filter.Validator
	errFunc    VersionErrorHandler
	today      func() time.Time
}
    25  
// ValidatorConfig defines how a new Validator may be configured.
type ValidatorConfig struct {
	// ServerURL overrides the server URLs in the given OpenAPI specs to match
	// the URL of requests reaching the backend service. If unset, requests
	// must match the servers defined in OpenAPI specs. When set, it must parse
	// as a URL with an "http" or "https" scheme, otherwise NewValidator returns
	// an error.
	ServerURL string

	// VersionError is called on any error that occurs when trying to resolve the
	// API version. If nil, DefaultVersionError is used.
	VersionError VersionErrorHandler

	// Options further configure the request and response validation. See
	// https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#ValidatorOption
	// for available options. These options are passed as-is to every
	// per-version validator; when a non-nil config leaves Options empty, no
	// custom error handler (such as the one in the default config) is applied.
	Options []openapi3filter.ValidatorOption
}
    42  
    43  var defaultValidatorConfig = ValidatorConfig{
    44  	VersionError: DefaultVersionError,
    45  	Options: []openapi3filter.ValidatorOption{
    46  		openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, _ error) {
    47  			statusText := http.StatusText(http.StatusInternalServerError)
    48  			switch code {
    49  			case openapi3filter.ErrCodeCannotFindRoute:
    50  				statusText = "Not Found"
    51  			case openapi3filter.ErrCodeRequestInvalid:
    52  				statusText = "Bad Request"
    53  			}
    54  			http.Error(w, statusText, status)
    55  		}),
    56  	},
    57  }
    58  
    59  func today() time.Time {
    60  	now := time.Now().UTC()
    61  	return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
    62  }
    63  
    64  // NewValidator returns a new validation middleware, which validates versioned
    65  // requests according to the given OpenAPI spec versions. For configuration
    66  // defaults, a nil config may be used.
    67  func NewValidator(config *ValidatorConfig, docs ...*openapi3.T) (*Validator, error) {
    68  	if len(docs) == 0 {
    69  		return nil, fmt.Errorf("no OpenAPI versions provided")
    70  	}
    71  	if config == nil {
    72  		config = &defaultValidatorConfig
    73  	}
    74  	if config.ServerURL != "" {
    75  		serverURL, err := url.Parse(config.ServerURL)
    76  		if err != nil {
    77  			return nil, fmt.Errorf("invalid ServerURL: %w", err)
    78  		}
    79  		switch serverURL.Scheme {
    80  		case "http", "https":
    81  		case "":
    82  			return nil, fmt.Errorf("invalid ServerURL: missing scheme")
    83  		default:
    84  			return nil, fmt.Errorf("invalid ServerURL: unsupported scheme %q (did you forget to specify the scheme://?)", serverURL.Scheme)
    85  		}
    86  		for i := range docs {
    87  			docs[i].Servers = []*openapi3.Server{{URL: serverURL.String()}}
    88  		}
    89  	}
    90  	if config.VersionError == nil {
    91  		config.VersionError = DefaultVersionError
    92  	}
    93  	v := &Validator{
    94  		versions:   make([]vervet.Version, len(docs)),
    95  		validators: make([]*openapi3filter.Validator, len(docs)),
    96  		errFunc:    config.VersionError,
    97  		today:      today,
    98  	}
    99  	validatorVersions := map[string]*openapi3filter.Validator{}
   100  	for i := range docs {
   101  		if config.ServerURL != "" {
   102  			docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}}
   103  		}
   104  		versionStr, err := vervet.ExtensionString(docs[i].ExtensionProps, vervet.ExtSnykApiVersion)
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  		version, err := vervet.ParseVersion(versionStr)
   109  		if err != nil {
   110  			return nil, err
   111  		}
   112  		v.versions[i] = *version
   113  		router, err := gorillamux.NewRouter(docs[i])
   114  		if err != nil {
   115  			return nil, err
   116  		}
   117  		validatorVersions[version.String()] = openapi3filter.NewValidator(router, config.Options...)
   118  	}
   119  	sort.Sort(v.versions)
   120  	for i := range v.versions {
   121  		v.validators[i] = validatorVersions[v.versions[i].String()]
   122  	}
   123  	return v, nil
   124  }
   125  
   126  // Middleware returns an http.Handler which wraps the given handler with
   127  // request and response validation according to the requested API version.
   128  func (v *Validator) Middleware(h http.Handler) http.Handler {
   129  	handlers := make([]http.Handler, len(v.validators))
   130  	for i := range v.versions {
   131  		handlers[i] = v.validators[i].Middleware(h)
   132  	}
   133  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   134  		versionParam := req.URL.Query().Get("version")
   135  		if versionParam == "" {
   136  			v.errFunc(w, req, http.StatusBadRequest, fmt.Errorf("missing required query parameter 'version'"))
   137  			return
   138  		}
   139  		requested, err := vervet.ParseVersion(versionParam)
   140  		if err != nil {
   141  			v.errFunc(w, req, http.StatusBadRequest, err)
   142  			return
   143  		}
   144  		if t := v.today(); requested.Date.After(t) {
   145  			v.errFunc(w, req, http.StatusBadRequest,
   146  				fmt.Errorf("requested version newer than present date %s", t))
   147  			return
   148  		}
   149  		resolvedIndex, err := v.versions.ResolveIndex(*requested)
   150  		if err != nil {
   151  			v.errFunc(w, req, http.StatusNotFound, err)
   152  			return
   153  		}
   154  		handlers[resolvedIndex].ServeHTTP(w, req)
   155  	})
   156  }