github.com/verrazzano/verrazzano@v1.7.0/authproxy/src/cors/cors.go (about)

     1  // Copyright (c) 2023, Oracle and/or its affiliates.
     2  // Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl.
     3  
     4  package cors
     5  
     6  import (
     7  	"fmt"
     8  	"net/http"
     9  	"os"
    10  	"strings"
    11  )
    12  
    13  var allowedOriginsFunc = allowedOrigins // indirection used by originAllowed to fetch the allow list; presumably replaceable in tests - confirm
    14  
    15  func AddCORSHeaders(req *http.Request, rw http.ResponseWriter, ingressHost string) (int, error) {
    16  	var origin string
    17  	var err error
    18  	if origin, err = getOriginHeaderValue(req); err != nil {
    19  		return http.StatusBadRequest, err
    20  	}
    21  	// nothing to do for CORS
    22  	if origin == "" {
    23  		return http.StatusOK, nil
    24  	}
    25  
    26  	allowed := originAllowed(origin, ingressHost)
    27  
    28  	// From https://tools.ietf.org/id/draft-abarth-origin-03.html#server-behavior, if the request Origin is not in
    29  	// allowed origins and the request method is not a safe non state changing (i.e. GET or HEAD), we should
    30  	// abort the request.
    31  	if !allowed && req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
    32  		// TODO forbidden doesn't seem right here, but that's what current authproxy does
    33  		return http.StatusForbidden, fmt.Errorf("Origin %s is not allowed", origin)
    34  	}
    35  
    36  	// Add response headers if it's an allowed origin
    37  	if allowed {
    38  		rw.Header().Set("Access-Control-Allow-Origin", origin)
    39  		rw.Header().Set("Access-Control-Allow-Credentials", "true")
    40  	}
    41  	if req.Method == http.MethodOptions {
    42  		rw.Header().Set("Access-Control-Allow-Headers", "authorization, content-type")
    43  		rw.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, DELETE, OPTIONS, PATCH")
    44  	}
    45  	return http.StatusOK, nil
    46  }
    47  
    48  func getOriginHeaderValue(req *http.Request) (string, error) {
    49  	originValues := req.Header["Origin"]
    50  	if len(originValues) > 1 {
    51  		return "", fmt.Errorf("Origin header must have a single value")
    52  	}
    53  	if len(originValues) == 0 {
    54  		return "", nil
    55  	}
    56  	origin := originValues[0]
    57  	if origin == "*" {
    58  		// not a legit origin, could be intended to trick us into oversharing
    59  		return origin, fmt.Errorf("Invalid Origin header: '*'")
    60  	}
    61  	return origin, nil
    62  }
    63  
    64  // originAllowed validates the origin string and returns true if it is an allowed value
    65  func originAllowed(origin string, ingressHost string) bool {
    66  	// Origin may be set to "null" in private-sensitive contexts as defined by the application,
    67  	// according to https://datatracker.ietf.org/doc/rfc6454
    68  	if origin == "null" {
    69  		return false
    70  	}
    71  
    72  	ingressURL := "https://" + ingressHost
    73  	if origin == ingressURL {
    74  		return true
    75  	}
    76  
    77  	// Check list of allowed origins if provided
    78  	var allowedOriginsStr string
    79  	if allowedOriginsStr = allowedOriginsFunc(); allowedOriginsStr == "" {
    80  		return false
    81  	}
    82  	allowedOrigins := strings.Split(allowedOriginsStr, ",")
    83  	for _, allowed := range allowedOrigins {
    84  		if origin == strings.TrimSpace(allowed) {
    85  			return true
    86  		}
    87  	}
    88  
    89  	return false
    90  }
    91  
    92  // allowedOrigins returns an "allow list" of permitted origins
    93  func allowedOrigins() string {
    94  	return os.Getenv("VZ_API_ALLOWED_ORIGINS")
    95  }