github.com/snyk/vervet/v3@v3.7.0/versionware/validator.go (about) 1 package versionware 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/url" 7 "sort" 8 "time" 9 10 "github.com/getkin/kin-openapi/openapi3" 11 "github.com/getkin/kin-openapi/openapi3filter" 12 "github.com/getkin/kin-openapi/routers/gorillamux" 13 14 "github.com/snyk/vervet/v3" 15 ) 16 17 // Validator provides versioned OpenAPI validation middleware for HTTP requests 18 // and responses. 19 type Validator struct { 20 versions vervet.VersionSlice 21 validators []*openapi3filter.Validator 22 errFunc VersionErrorHandler 23 today func() time.Time 24 } 25 26 // ValidatorConfig defines how a new Validator may be configured. 27 type ValidatorConfig struct { 28 // ServerURL overrides the server URLs in the given OpenAPI specs to match 29 // the URL of requests reaching the backend service. If unset, requests 30 // must match the servers defined in OpenAPI specs. 31 ServerURL string 32 33 // VersionError is called on any error that occurs when trying to resolve the 34 // API version. 35 VersionError VersionErrorHandler 36 37 // Options further configure the request and response validation. See 38 // https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3filter#ValidatorOption 39 // for available options. 40 Options []openapi3filter.ValidatorOption 41 } 42 43 var defaultValidatorConfig = ValidatorConfig{ 44 VersionError: DefaultVersionError, 45 Options: []openapi3filter.ValidatorOption{ 46 openapi3filter.OnErr(func(w http.ResponseWriter, status int, code openapi3filter.ErrCode, _ error) { 47 statusText := http.StatusText(http.StatusInternalServerError) 48 switch code { 49 case openapi3filter.ErrCodeCannotFindRoute: 50 statusText = "Not Found" 51 case openapi3filter.ErrCodeRequestInvalid: 52 statusText = "Bad Request" 53 } 54 http.Error(w, statusText, status) 55 }), 56 }, 57 } 58 59 func today() time.Time { 60 now := time.Now().UTC() 61 return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) 62 } 63 64 // NewValidator returns a new validation middleware, which validates versioned 65 // requests according to the given OpenAPI spec versions. For configuration 66 // defaults, a nil config may be used. 67 func NewValidator(config *ValidatorConfig, docs ...*openapi3.T) (*Validator, error) { 68 if len(docs) == 0 { 69 return nil, fmt.Errorf("no OpenAPI versions provided") 70 } 71 if config == nil { 72 config = &defaultValidatorConfig 73 } 74 if config.ServerURL != "" { 75 serverURL, err := url.Parse(config.ServerURL) 76 if err != nil { 77 return nil, fmt.Errorf("invalid ServerURL: %w", err) 78 } 79 switch serverURL.Scheme { 80 case "http", "https": 81 case "": 82 return nil, fmt.Errorf("invalid ServerURL: missing scheme") 83 default: 84 return nil, fmt.Errorf("invalid ServerURL: unsupported scheme %q (did you forget to specify the scheme://?)", serverURL.Scheme) 85 } 86 for i := range docs { 87 docs[i].Servers = []*openapi3.Server{{URL: serverURL.String()}} 88 } 89 } 90 if config.VersionError == nil { 91 config.VersionError = DefaultVersionError 92 } 93 v := &Validator{ 94 versions: make([]vervet.Version, len(docs)), 95 validators: make([]*openapi3filter.Validator, len(docs)), 96 errFunc: config.VersionError, 97 today: today, 98 } 99 validatorVersions := map[string]*openapi3filter.Validator{} 100 for i := range docs { 101 if config.ServerURL != "" { 102 docs[i].Servers = []*openapi3.Server{{URL: config.ServerURL}} 103 } 104 versionStr, err := vervet.ExtensionString(docs[i].ExtensionProps, vervet.ExtSnykApiVersion) 105 if err != nil { 106 return nil, err 107 } 108 version, err := vervet.ParseVersion(versionStr) 109 if err != nil { 110 return nil, err 111 } 112 v.versions[i] = *version 113 router, err := gorillamux.NewRouter(docs[i]) 114 if err != nil { 115 return nil, err 116 } 117 validatorVersions[version.String()] = openapi3filter.NewValidator(router, config.Options...) 118 } 119 sort.Sort(v.versions) 120 for i := range v.versions { 121 v.validators[i] = validatorVersions[v.versions[i].String()] 122 } 123 return v, nil 124 } 125 126 // Middleware returns an http.Handler which wraps the given handler with 127 // request and response validation according to the requested API version. 128 func (v *Validator) Middleware(h http.Handler) http.Handler { 129 handlers := make([]http.Handler, len(v.validators)) 130 for i := range v.versions { 131 handlers[i] = v.validators[i].Middleware(h) 132 } 133 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 134 versionParam := req.URL.Query().Get("version") 135 if versionParam == "" { 136 v.errFunc(w, req, http.StatusBadRequest, fmt.Errorf("missing required query parameter 'version'")) 137 return 138 } 139 requested, err := vervet.ParseVersion(versionParam) 140 if err != nil { 141 v.errFunc(w, req, http.StatusBadRequest, err) 142 return 143 } 144 if t := v.today(); requested.Date.After(t) { 145 v.errFunc(w, req, http.StatusBadRequest, 146 fmt.Errorf("requested version newer than present date %s", t)) 147 return 148 } 149 resolvedIndex, err := v.versions.ResolveIndex(*requested) 150 if err != nil { 151 v.errFunc(w, req, http.StatusNotFound, err) 152 return 153 } 154 handlers[resolvedIndex].ServeHTTP(w, req) 155 }) 156 }