github.com/snyk/vervet/v5@v5.11.1-0.20240202085829-ad4dd7fb6101/versionware/validator.go (about) 1 package versionware 2 3 import ( 4 "errors" 5 "fmt" 6 "net/http" 7 "net/url" 8 "time" 9 10 "github.com/getkin/kin-openapi/openapi3" 11 "github.com/getkin/kin-openapi/openapi3filter" 12 "github.com/getkin/kin-openapi/routers/gorillamux" 13 14 "github.com/snyk/vervet/v5" 15 ) 16 17 // Validator provides versioned OpenAPI validation middleware for HTTP requests 18 // and responses. 19 type Validator struct { 20 versions vervet.VersionIndex 21 validators map[vervet.Version]*openapi3filter.Validator 22 errFunc VersionErrorHandler 23 today func() time.Time 24 } 25 26 // ValidatorConfig defines how a new Validator may be configured. 27 type ValidatorConfig struct { 28 // ServerURL overrides the server URLs in the given OpenAPI specs to match 29 // the URL of requests reaching the backend service. If unset, requests 30 // must match the servers defined in OpenAPI specs. 31 ServerURL string 32 33 // VersionError is called on any error that occurs when trying to resolve the 34 // API version. 35 VersionError VersionErrorHandler 36 37 // Options further configure the request and response validation. See 38 // https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#ValidatorOption 39 // for available options. 40 Options []openapi3filter.ValidatorOption 41 } 42 43 var defaultValidatorConfig = ValidatorConfig{ 44 VersionError: DefaultVersionError, 45 Options: []openapi3filter.ValidatorOption{ 46 openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, _ error) { 47 statusText := http.StatusText(http.StatusInternalServerError) 48 switch code { 49 case openapi3filter.ErrCodeCannotFindRoute: 50 statusText = "Not Found" 51 case openapi3filter.ErrCodeRequestInvalid: 52 statusText = "Bad Request" 53 } 54 http.Error(w, statusText, status) 55 }), 56 }, 57 } 58 59 func today() time.Time { 60 now := time.Now().UTC() 61 return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) 62 } 63 64 // NewValidator returns a new validation middleware, which validates versioned 65 // requests according to the given OpenAPI spec versions. For configuration 66 // defaults, a nil config may be used. 67 func NewValidator(config *ValidatorConfig, docs ...*openapi3.T) (*Validator, error) { 68 if len(docs) == 0 { 69 return nil, fmt.Errorf("no OpenAPI versions provided") 70 } 71 if config == nil { 72 config = &defaultValidatorConfig 73 } 74 if config.ServerURL != "" { 75 serverURL, err := url.Parse(config.ServerURL) 76 if err != nil { 77 return nil, fmt.Errorf("invalid ServerURL: %w", err) 78 } 79 switch serverURL.Scheme { 80 case "http", "https": 81 case "": 82 return nil, errors.New("invalid ServerURL: missing scheme") 83 default: 84 return nil, fmt.Errorf( 85 "invalid ServerURL: unsupported scheme %q (did you forget to specify the scheme://?)", 86 serverURL.Scheme, 87 ) 88 } 89 } 90 if config.VersionError == nil { 91 config.VersionError = DefaultVersionError 92 } 93 v := &Validator{ 94 validators: map[vervet.Version]*openapi3filter.Validator{}, 95 errFunc: config.VersionError, 96 today: today, 97 } 98 serviceVersions := make(vervet.VersionSlice, len(docs)) 99 for i := range docs { 100 if config.ServerURL != "" { 101 docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}} 102 } 103 versionStr, err := vervet.ExtensionString(docs[i].Extensions, vervet.ExtSnykApiVersion) 104 if err != nil { 105 return nil, err 106 } 107 version, err := vervet.ParseVersion(versionStr) 108 if err != nil { 109 return nil, err 110 } 111 serviceVersions[i] = version 112 router, err := gorillamux.NewRouter(docs[i]) 113 if err != nil { 114 return nil, err 115 } 116 v.validators[version] = openapi3filter.NewValidator(router, config.Options...) 117 } 118 v.versions = vervet.NewVersionIndex(serviceVersions) 119 return v, nil 120 } 121 122 // Middleware returns an http.Handler which wraps the given handler with 123 // request and response validation according to the requested API version. 124 func (v *Validator) Middleware(h http.Handler) http.Handler { 125 handlers := map[vervet.Version]http.Handler{} 126 for version, validator := range v.validators { 127 handlers[version] = validator.Middleware(h) 128 } 129 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 130 versionParam := req.URL.Query().Get("version") 131 if versionParam == "" { 132 v.errFunc(w, req, http.StatusBadRequest, errors.New("missing required query parameter 'version'")) 133 return 134 } 135 requested, err := vervet.ParseVersion(versionParam) 136 if err != nil { 137 v.errFunc(w, req, http.StatusBadRequest, err) 138 return 139 } 140 if t := v.today(); requested.Date.After(t) { 141 v.errFunc(w, req, http.StatusBadRequest, 142 fmt.Errorf("requested version newer than present date %s", t)) 143 return 144 } 145 resolvedVersion, err := v.versions.Resolve(requested) 146 if err != nil { 147 v.errFunc(w, req, http.StatusNotFound, err) 148 return 149 } 150 h, ok := handlers[resolvedVersion] 151 if !ok { 152 // Crash noisily, as this indicates a serious bug. Should not 153 // happen if we've initialized our version maps correctly. 154 panic(fmt.Sprintf("missing validator for version %q", resolvedVersion)) 155 } 156 h.ServeHTTP(w, req) 157 }) 158 }