github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/transports/middlewares/cors/cors.go (about)

     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   * 	http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    17  
    18  package cors
    19  
    20  import (
    21  	"bytes"
    22  	"github.com/aacfactory/errors"
    23  	"github.com/aacfactory/fns/commons/bytex"
    24  	"github.com/aacfactory/fns/commons/wildcard"
    25  	"github.com/aacfactory/fns/transports"
    26  	"net/http"
    27  	"slices"
    28  	"strconv"
    29  	"strings"
    30  )
    31  
    32  func New() transports.Middleware {
    33  	return &corsMiddleware{}
    34  }
    35  
    36  type corsMiddleware struct {
    37  	allowedOrigins      [][]byte
    38  	allowedWOrigins     []*wildcard.Wildcard
    39  	allowedOriginsAll   bool
    40  	allowedHeaders      [][]byte
    41  	allowedHeadersAll   bool
    42  	allowedMethods      [][]byte
    43  	exposedHeaders      [][]byte
    44  	maxAge              int
    45  	allowCredentials    bool
    46  	allowPrivateNetwork bool
    47  	preflightVary       [][]byte
    48  	handler             transports.Handler
    49  }
    50  
    51  func (c *corsMiddleware) Name() string {
    52  	return "cors"
    53  }
    54  
    55  func (c *corsMiddleware) Construct(options transports.MiddlewareOptions) (err error) {
    56  	config := Config{}
    57  	err = options.Config.As(&config)
    58  	if err != nil {
    59  		err = errors.Warning("fns: build cors middleware failed").WithCause(err)
    60  		return
    61  	}
    62  	allowedOrigins := make([][]byte, 0, 1)
    63  	allowedWOrigins := make([]*wildcard.Wildcard, 0, 1)
    64  	allowedOriginsAll := false
    65  	if config.AllowedHeaders == nil {
    66  		config.AllowedHeaders = make([]string, 0, 1)
    67  	}
    68  	if len(config.AllowedHeaders) == 0 || config.AllowedHeaders[0] != "*" {
    69  		defaultAllowedHeaders := []string{
    70  			string(transports.OriginHeaderName), string(transports.AcceptHeaderName), string(transports.ContentTypeHeaderName),
    71  			string(transports.AcceptEncodingHeaderName),
    72  			string(transports.XRequestedWithHeaderName),
    73  			string(transports.ConnectionHeaderName), string(transports.UpgradeHeaderName),
    74  			string(transports.XForwardedForHeaderName), string(transports.TrueClientIpHeaderName), string(transports.XRealIpHeaderName),
    75  			string(transports.DeviceIpHeaderName), string(transports.DeviceIdHeaderName),
    76  			string(transports.RequestIdHeaderName),
    77  			string(transports.RequestTimeoutHeaderName), string(transports.RequestVersionsHeaderName),
    78  			string(transports.CacheControlHeaderIfNonMatch), string(transports.CacheControlHeaderName),
    79  			string(transports.SignatureHeaderName),
    80  		}
    81  		for _, header := range defaultAllowedHeaders {
    82  			if !slices.Contains(config.AllowedHeaders, header) {
    83  				config.AllowedHeaders = append(config.AllowedHeaders, header)
    84  			}
    85  		}
    86  	}
    87  	if len(config.AllowedOrigins) == 0 {
    88  		config.AllowedOrigins = []string{"*"}
    89  	}
    90  	for _, origin := range config.AllowedOrigins {
    91  		origin = strings.ToLower(origin)
    92  		if origin == "*" {
    93  			allowedOriginsAll = true
    94  			allowedOrigins = nil
    95  			allowedWOrigins = nil
    96  			break
    97  		} else if i := strings.IndexByte(origin, '*'); i >= 0 {
    98  			w := wildcard.New(bytex.FromString(origin))
    99  			allowedWOrigins = append(allowedWOrigins, w)
   100  		} else {
   101  			allowedOrigins = append(allowedOrigins, bytex.FromString(origin))
   102  		}
   103  	}
   104  	allowedHeadersAll := false
   105  	allowedHeaders := make([][]byte, 0, 1)
   106  	for _, header := range config.AllowedHeaders {
   107  		allowedHeaders = append(allowedHeaders, bytex.FromString(header))
   108  	}
   109  	allowedHeaders = convert(allowedHeaders, http.CanonicalHeaderKey)
   110  	for _, h := range config.AllowedHeaders {
   111  		if h == "*" {
   112  			allowedHeadersAll = true
   113  			allowedHeaders = nil
   114  			break
   115  		}
   116  	}
   117  
   118  	exposedHeaders := make([][]byte, 0, 1)
   119  	if config.ExposedHeaders == nil {
   120  		config.ExposedHeaders = make([]string, 0, 1)
   121  	}
   122  	defaultExposedHeaders := []string{
   123  		string(transports.VaryHeaderName),
   124  		string(transports.DeviceIdHeaderName),
   125  		string(transports.EndpointIdHeaderName), string(transports.EndpointVersionHeaderName),
   126  		string(transports.ContentEncodingHeaderName),
   127  		string(transports.RequestIdHeaderName), string(transports.HandleLatencyHeaderName),
   128  		string(transports.CacheControlHeaderName), string(transports.ETagHeaderName), string(transports.ClearSiteDataHeaderName), string(transports.AgeHeaderName),
   129  		string(transports.ResponseRetryAfterHeaderName), string(transports.SignatureHeaderName),
   130  		string(transports.DeprecatedHeaderName),
   131  	}
   132  	for _, header := range defaultExposedHeaders {
   133  		if !slices.Contains(config.ExposedHeaders, header) {
   134  			config.ExposedHeaders = append(config.ExposedHeaders, header)
   135  		}
   136  	}
   137  	for _, header := range config.ExposedHeaders {
   138  		exposedHeaders = append(exposedHeaders, bytex.FromString(header))
   139  	}
   140  	exposedHeaders = convert(exposedHeaders, http.CanonicalHeaderKey)
   141  
   142  	c.allowedOrigins = allowedOrigins
   143  	c.allowedWOrigins = allowedWOrigins
   144  	c.allowedOriginsAll = allowedOriginsAll
   145  	c.allowedHeaders = allowedHeaders
   146  	c.allowedHeadersAll = allowedHeadersAll
   147  	c.allowedMethods = [][]byte{methodGet, methodPost, methodHead}
   148  	c.exposedHeaders = exposedHeaders
   149  	c.maxAge = config.MaxAge
   150  	c.allowCredentials = config.AllowCredentials
   151  	c.allowPrivateNetwork = config.AllowPrivateNetwork
   152  
   153  	if c.allowPrivateNetwork {
   154  		c.preflightVary = [][]byte{[]byte("Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network")}
   155  	} else {
   156  		c.preflightVary = [][]byte{[]byte("Origin, Access-Control-Request-Method, Access-Control-Request-Headers")}
   157  	}
   158  	return
   159  }
   160  
   161  func (c *corsMiddleware) Handler(next transports.Handler) transports.Handler {
   162  	c.handler = next
   163  	return c
   164  }
   165  
   166  func (c *corsMiddleware) Close() (err error) {
   167  	return
   168  }
   169  
   170  func (c *corsMiddleware) Handle(w transports.ResponseWriter, r transports.Request) {
   171  	if bytes.Equal(r.Method(), methodOptions) && len(r.Header().Get(accessControlRequestMethodHeader)) > 0 {
   172  		c.handlePreflight(w, r)
   173  		w.SetStatus(http.StatusNoContent)
   174  	} else {
   175  		c.handleActualRequest(w, r)
   176  		c.handler.Handle(w, r)
   177  	}
   178  }
   179  
   180  func (c *corsMiddleware) handlePreflight(w transports.ResponseWriter, r transports.Request) {
   181  	headers := w.Header()
   182  	origin := r.Header().Get(originHeader)
   183  
   184  	if !bytes.Equal(r.Method(), methodOptions) {
   185  		return
   186  	}
   187  
   188  	if vary := headers.Get(varyHeader); len(vary) > 0 {
   189  		headers.Add(varyHeader, c.preflightVary[0])
   190  	} else {
   191  		for _, preflightVary := range c.preflightVary {
   192  			headers.Add(varyHeader, preflightVary)
   193  		}
   194  	}
   195  
   196  	if len(origin) == 0 {
   197  		return
   198  	}
   199  	if !c.isOriginAllowed(origin) {
   200  		return
   201  	}
   202  
   203  	reqMethod := r.Header().Get(accessControlRequestMethodHeader)
   204  	if !c.isMethodAllowed(reqMethod) {
   205  		return
   206  	}
   207  	reqHeadersRaw := r.Header().Values(accessControlRequestHeadersHeader)
   208  	reqHeaders, reqHeadersEdited := parseHeaderList(reqHeadersRaw)
   209  	if !c.areHeadersAllowed(reqHeaders) {
   210  		return
   211  	}
   212  	if c.allowedOriginsAll {
   213  		headers.Set(accessControlAllowOriginHeader, all)
   214  	} else {
   215  		origins := w.Header().Values(originHeader)
   216  		for _, ori := range origins {
   217  			headers.Add(accessControlAllowOriginHeader, ori)
   218  		}
   219  	}
   220  	headers.Set(accessControlAllowMethodsHeader, bytes.ToUpper(reqMethod))
   221  	if len(reqHeaders) > 0 {
   222  		if reqHeadersEdited || len(reqHeaders) != len(reqHeadersRaw) {
   223  			headers.Set(accessControlAllowHeadersHeader, bytes.Join(reqHeaders, joinBytes))
   224  		} else {
   225  			for _, raw := range reqHeadersRaw {
   226  				headers.Add(accessControlAllowHeadersHeader, raw)
   227  			}
   228  		}
   229  	}
   230  	if c.allowCredentials {
   231  		headers.Set(accessControlAllowCredentialsHeader, trueBytes)
   232  	}
   233  
   234  	if c.allowPrivateNetwork && bytes.Equal(r.Header().Get(accessControlRequestPrivateNetworkHeader), trueBytes) {
   235  		headers.Set(accessControlAllowPrivateNetworkHeader, trueBytes)
   236  	}
   237  
   238  	if c.maxAge > 0 {
   239  		headers.Set(accessControlMaxAgeHeader, bytex.FromString(strconv.Itoa(c.maxAge)))
   240  	}
   241  }
   242  
   243  func (c *corsMiddleware) handleActualRequest(w transports.ResponseWriter, r transports.Request) {
   244  	headers := w.Header()
   245  	origin := r.Header().Get(originHeader)
   246  
   247  	if len(origin) == 0 {
   248  		return
   249  	}
   250  	if !c.isOriginAllowed(origin) {
   251  		return
   252  	}
   253  
   254  	if !c.isMethodAllowed(r.Method()) {
   255  		return
   256  	}
   257  	if c.allowedOriginsAll {
   258  		headers.Set(accessControlAllowOriginHeader, all)
   259  	} else {
   260  		origins := w.Header().Values(originHeader)
   261  		for _, ori := range origins {
   262  			headers.Add(accessControlAllowOriginHeader, ori)
   263  		}
   264  	}
   265  	if len(c.exposedHeaders) > 0 {
   266  		for _, exposedHeader := range c.exposedHeaders {
   267  			headers.Add(accessControlExposeHeadersHeader, exposedHeader)
   268  		}
   269  	}
   270  	if c.allowCredentials {
   271  		headers.Set(accessControlAllowCredentialsHeader, trueBytes)
   272  	}
   273  }
   274  
   275  func (c *corsMiddleware) isOriginAllowed(origin []byte) bool {
   276  	if c.allowedOriginsAll {
   277  		return true
   278  	}
   279  	origin = bytes.ToLower(origin)
   280  	for _, o := range c.allowedOrigins {
   281  		if bytes.Equal(o, origin) {
   282  			return true
   283  		}
   284  	}
   285  	for _, w := range c.allowedWOrigins {
   286  		if w.Match(origin) {
   287  			return true
   288  		}
   289  	}
   290  	return false
   291  }
   292  
   293  func (c *corsMiddleware) isMethodAllowed(method []byte) bool {
   294  	if len(c.allowedMethods) == 0 {
   295  		return false
   296  	}
   297  	ms := bytes.ToUpper(method)
   298  	if bytes.Equal(ms, methodOptions) {
   299  		return true
   300  	}
   301  	for _, m := range c.allowedMethods {
   302  		if bytes.Equal(ms, m) {
   303  			return true
   304  		}
   305  	}
   306  	return false
   307  }
   308  
   309  func (c *corsMiddleware) areHeadersAllowed(requestedHeaders [][]byte) bool {
   310  	if c.allowedHeadersAll || len(requestedHeaders) == 0 {
   311  		return true
   312  	}
   313  	for _, header := range requestedHeaders {
   314  		hs := bytex.FromString(http.CanonicalHeaderKey(bytex.ToString(header)))
   315  		found := false
   316  		for _, h := range c.allowedHeaders {
   317  			if bytes.Equal(hs, h) {
   318  				found = true
   319  				break
   320  			}
   321  			if bytes.Index(hs, transports.UserHeaderNamePrefix) == 0 {
   322  				found = true
   323  				break
   324  			}
   325  		}
   326  		if !found {
   327  			return false
   328  		}
   329  	}
   330  	return true
   331  }