github.com/astaxie/beego@v1.12.3/plugins/cors/cors.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package cors provides a beego filter that adds CORS (Cross-Origin Resource
// Sharing) headers to responses and answers CORS preflight requests.
//
// Usage:
//
//	import (
//		"github.com/astaxie/beego"
//		"github.com/astaxie/beego/plugins/cors"
//	)
//
//	func main() {
//		// CORS for https://*.foo.com origins, allowing:
//		// - PUT and PATCH methods
//		// - Origin header
//		// - Content-Length exposed to scripts
//		// - Credentials share
//		beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
//			AllowOrigins:     []string{"https://*.foo.com"},
//			AllowMethods:     []string{"PUT", "PATCH"},
//			AllowHeaders:     []string{"Origin"},
//			ExposeHeaders:    []string{"Content-Length"},
//			AllowCredentials: true,
//		}))
//		beego.Run()
//	}
    36  package cors
    37  
    38  import (
    39  	"net/http"
    40  	"regexp"
    41  	"strconv"
    42  	"strings"
    43  	"time"
    44  
    45  	"github.com/astaxie/beego"
    46  	"github.com/astaxie/beego/context"
    47  )
    48  
    49  const (
    50  	headerAllowOrigin      = "Access-Control-Allow-Origin"
    51  	headerAllowCredentials = "Access-Control-Allow-Credentials"
    52  	headerAllowHeaders     = "Access-Control-Allow-Headers"
    53  	headerAllowMethods     = "Access-Control-Allow-Methods"
    54  	headerExposeHeaders    = "Access-Control-Expose-Headers"
    55  	headerMaxAge           = "Access-Control-Max-Age"
    56  
    57  	headerOrigin         = "Origin"
    58  	headerRequestMethod  = "Access-Control-Request-Method"
    59  	headerRequestHeaders = "Access-Control-Request-Headers"
    60  )
    61  
    62  var (
    63  	defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}
    64  	// Regex patterns are generated from AllowOrigins. These are used and generated internally.
    65  	allowOriginPatterns = []string{}
    66  )
    67  
// Options describes the CORS policy applied by Header, PreflightHeader and
// the filter returned by Allow.
type Options struct {
	// AllowAllOrigins accepts every origin; responses then carry
	// Access-Control-Allow-Origin: * instead of echoing the request's origin.
	AllowAllOrigins bool
	// AllowOrigins lists the accepted origins. An entry must match the whole
	// Origin value: "*" stands for any run of characters, "?" for exactly one
	// character, and every other character is literal
	// (e.g. "https://*.foo.com").
	AllowOrigins []string
	// AllowCredentials is sent as Access-Control-Allow-Credentials ("true" or
	// "false") on every response for an accepted origin; when true it lets
	// credentials such as cookies be shared.
	// NOTE(review): browsers are expected to reject credentials paired with a
	// wildcard origin (AllowAllOrigins); confirm before enabling both.
	AllowCredentials bool
	// AllowMethods lists the accepted HTTP methods, joined with "," into
	// Access-Control-Allow-Methods. Preflight responses advertise them only
	// when the requested method is one of them (case-sensitive comparison).
	AllowMethods []string
	// AllowHeaders lists the accepted request headers; preflights compare
	// requested headers against it case-insensitively. Allow fills an empty
	// list with Origin, Accept, Content-Type and Authorization.
	AllowHeaders []string
	// ExposeHeaders lists the response headers announced in
	// Access-Control-Expose-Headers.
	ExposeHeaders []string
	// MaxAge, when positive, is sent as Access-Control-Max-Age in whole
	// seconds (any fraction of a second is truncated).
	MaxAge time.Duration
}
    85  
    86  // Header converts options into CORS headers.
    87  func (o *Options) Header(origin string) (headers map[string]string) {
    88  	headers = make(map[string]string)
    89  	// if origin is not allowed, don't extend the headers
    90  	// with CORS headers.
    91  	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
    92  		return
    93  	}
    94  
    95  	// add allow origin
    96  	if o.AllowAllOrigins {
    97  		headers[headerAllowOrigin] = "*"
    98  	} else {
    99  		headers[headerAllowOrigin] = origin
   100  	}
   101  
   102  	// add allow credentials
   103  	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
   104  
   105  	// add allow methods
   106  	if len(o.AllowMethods) > 0 {
   107  		headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
   108  	}
   109  
   110  	// add allow headers
   111  	if len(o.AllowHeaders) > 0 {
   112  		headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
   113  	}
   114  
   115  	// add exposed header
   116  	if len(o.ExposeHeaders) > 0 {
   117  		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
   118  	}
   119  	// add a max age header
   120  	if o.MaxAge > time.Duration(0) {
   121  		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
   122  	}
   123  	return
   124  }
   125  
   126  // PreflightHeader converts options into CORS headers for a preflight response.
   127  func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
   128  	headers = make(map[string]string)
   129  	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
   130  		return
   131  	}
   132  	// verify if requested method is allowed
   133  	for _, method := range o.AllowMethods {
   134  		if method == rMethod {
   135  			headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
   136  			break
   137  		}
   138  	}
   139  
   140  	// verify if requested headers are allowed
   141  	var allowed []string
   142  	for _, rHeader := range strings.Split(rHeaders, ",") {
   143  		rHeader = strings.TrimSpace(rHeader)
   144  	lookupLoop:
   145  		for _, allowedHeader := range o.AllowHeaders {
   146  			if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
   147  				allowed = append(allowed, rHeader)
   148  				break lookupLoop
   149  			}
   150  		}
   151  	}
   152  
   153  	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
   154  	// add allow origin
   155  	if o.AllowAllOrigins {
   156  		headers[headerAllowOrigin] = "*"
   157  	} else {
   158  		headers[headerAllowOrigin] = origin
   159  	}
   160  
   161  	// add allowed headers
   162  	if len(allowed) > 0 {
   163  		headers[headerAllowHeaders] = strings.Join(allowed, ",")
   164  	}
   165  
   166  	// add exposed headers
   167  	if len(o.ExposeHeaders) > 0 {
   168  		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
   169  	}
   170  	// add a max age header
   171  	if o.MaxAge > time.Duration(0) {
   172  		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
   173  	}
   174  	return
   175  }
   176  
   177  // IsOriginAllowed looks up if the origin matches one of the patterns
   178  // generated from Options.AllowOrigins patterns.
   179  func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
   180  	for _, pattern := range allowOriginPatterns {
   181  		allowed, _ = regexp.MatchString(pattern, origin)
   182  		if allowed {
   183  			return
   184  		}
   185  	}
   186  	return
   187  }
   188  
   189  // Allow enables CORS for requests those match the provided options.
   190  func Allow(opts *Options) beego.FilterFunc {
   191  	// Allow default headers if nothing is specified.
   192  	if len(opts.AllowHeaders) == 0 {
   193  		opts.AllowHeaders = defaultAllowHeaders
   194  	}
   195  
   196  	for _, origin := range opts.AllowOrigins {
   197  		pattern := regexp.QuoteMeta(origin)
   198  		pattern = strings.Replace(pattern, "\\*", ".*", -1)
   199  		pattern = strings.Replace(pattern, "\\?", ".", -1)
   200  		allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
   201  	}
   202  
   203  	return func(ctx *context.Context) {
   204  		var (
   205  			origin           = ctx.Input.Header(headerOrigin)
   206  			requestedMethod  = ctx.Input.Header(headerRequestMethod)
   207  			requestedHeaders = ctx.Input.Header(headerRequestHeaders)
   208  			// additional headers to be added
   209  			// to the response.
   210  			headers map[string]string
   211  		)
   212  
   213  		if ctx.Input.Method() == "OPTIONS" &&
   214  			(requestedMethod != "" || requestedHeaders != "") {
   215  			headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
   216  			for key, value := range headers {
   217  				ctx.Output.Header(key, value)
   218  			}
   219  			ctx.ResponseWriter.WriteHeader(http.StatusOK)
   220  			return
   221  		}
   222  		headers = opts.Header(origin)
   223  
   224  		for key, value := range headers {
   225  			ctx.Output.Header(key, value)
   226  		}
   227  	}
   228  }