github.com/m3db/m3@v1.5.0/src/x/net/http/cors/cors.go (about) 1 // Copyright (c) 2018 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 // Derived from https://github.com/etcd-io/etcd/tree/v3.2.10/pkg/cors under 22 // http://www.apache.org/licenses/LICENSE-2.0#redistribution . 23 // See https://github.com/m3db/m3/blob/master/NOTICES.txt for the original copyright. 24 25 // Package cors handles cross-origin HTTP requests (CORS). 26 package cors 27 28 import ( 29 "fmt" 30 "net/http" 31 "net/url" 32 "sort" 33 "strings" 34 ) 35 36 // Info represents a set of allowed origins. 
// Info is keyed by CORS origin: a true value marks that origin as allowed, and
// the special key "*" allows every origin.
type Info map[string]bool

// Set implements the flag.Value interface to allow users to define a list of
// CORS origins as a comma-separated string. Whitespace around each entry is
// trimmed and empty entries are skipped. Every entry other than "*" must be
// accepted by url.Parse. On success the previous contents of ci are replaced;
// on error ci is left unchanged.
func (ci *Info) Set(s string) error {
	origins := make(map[string]bool)
	for _, v := range strings.Split(s, ",") {
		v = strings.TrimSpace(v)
		if v == "" {
			continue
		}
		if v != "*" {
			if _, err := url.Parse(v); err != nil {
				return fmt.Errorf("invalid CORS origin %q: %w", v, err)
			}
		}
		origins[v] = true
	}
	*ci = Info(origins)
	return nil
}

// String implements flag.Value, returning the configured origins sorted and
// joined with commas. It is safe to call on a nil *Info, which the flag
// package may do (e.g. when printing defaults).
func (ci *Info) String() string {
	if ci == nil {
		return ""
	}
	origins := make([]string, 0, len(*ci))
	for k := range *ci {
		origins = append(origins, k)
	}
	sort.Strings(origins)
	return strings.Join(origins, ",")
}

// OriginAllowed determines whether the server will allow a given CORS origin.
// A "*" entry allows every origin.
func (ci Info) OriginAllowed(origin string) bool {
	return ci["*"] || ci[origin]
}

// Handler wraps an http.Handler instance to provide configurable CORS support.
// CORS headers are added to responses whose request Origin is allowed by Info
// (or to every response when Info allows "*"). OPTIONS requests are answered
// directly with 200 OK and never reach the wrapped Handler. A nil Info allows
// no origins.
type Handler struct {
	Handler http.Handler
	Info    *Info
}

// addHeader sets the CORS response headers, advertising origin as the allowed
// origin. Header().Set is used rather than Add so that nesting the handler, or
// an earlier writer of these headers, cannot produce duplicate
// Access-Control-Allow-Origin values, which browsers reject.
func (h *Handler) addHeader(w http.ResponseWriter, origin string) {
	header := w.Header()
	header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
	header.Set("Access-Control-Allow-Origin", origin)
	header.Set("Access-Control-Allow-Headers", "accept, content-type, authorization")
}

// ServeHTTP adds the correct CORS headers based on the request's Origin header.
// For OPTIONS requests it then returns immediately with 200 OK; all other
// requests are delegated to the wrapped Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	switch {
	case h.Info == nil:
		// No origins configured: emit no CORS headers (previously this panicked).
	case h.Info.OriginAllowed("*"):
		h.addHeader(w, "*")
	default:
		// The response now depends on the Origin request header, so caches must
		// key on it; otherwise one origin's CORS grant could be served to another.
		w.Header().Add("Vary", "Origin")
		if origin := req.Header.Get("Origin"); origin != "" && h.Info.OriginAllowed(origin) {
			h.addHeader(w, origin)
		}
	}

	if req.Method == http.MethodOptions {
		w.WriteHeader(http.StatusOK)
		return
	}

	h.Handler.ServeHTTP(w, req)
}