github.com/anycable/anycable-go@v1.5.1/server/request_info.go (about)

     1  package server
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  
     9  	nanoid "github.com/matoous/go-nanoid"
    10  )
    11  
    12  const (
    13  	remoteAddrHeader = "REMOTE_ADDR"
    14  )
    15  
    16  type RequestInfo struct {
    17  	UID     string
    18  	URL     string
    19  	Headers *map[string]string
    20  
    21  	anycableHeaders map[string]string
    22  	params          map[string]string
    23  }
    24  
    25  func NewRequestInfo(r *http.Request, extractor HeadersExtractor) (*RequestInfo, error) {
    26  	var headers map[string]string
    27  
    28  	if extractor == nil {
    29  		headers = make(map[string]string)
    30  	} else {
    31  		headers = extractor.FromRequest(r)
    32  	}
    33  
    34  	anycableHeaders := make(map[string]string)
    35  
    36  	// Extract headers prefixed with `X-AnyCable-` from request headers
    37  	for k, v := range r.Header {
    38  		if strings.HasPrefix(strings.ToLower(k), "x-anycable-") {
    39  			anycableHeaders[strings.ToLower(k)] = v[len(v)-1]
    40  		}
    41  	}
    42  
    43  	uid, err := FetchUID(r)
    44  
    45  	if err != nil {
    46  		return nil, errors.New("failed to retrieve connection uid")
    47  	}
    48  
    49  	url := r.URL.String()
    50  
    51  	if !r.URL.IsAbs() {
    52  		// See https://github.com/golang/go/issues/28940#issuecomment-441749380
    53  		scheme := "http://"
    54  		if r.TLS != nil {
    55  			scheme = "https://"
    56  		}
    57  		url = fmt.Sprintf("%s%s%s", scheme, r.Host, url)
    58  	}
    59  
    60  	params := make(map[string]string)
    61  	urlParams := r.URL.Query()
    62  
    63  	for k, v := range urlParams {
    64  		params[k] = v[len(v)-1]
    65  	}
    66  
    67  	return &RequestInfo{UID: uid, Headers: &headers, URL: url, params: params, anycableHeaders: anycableHeaders}, nil
    68  }
    69  
    70  func (i *RequestInfo) Param(key string) string {
    71  	if i.params == nil {
    72  		return ""
    73  	}
    74  
    75  	return i.params[key]
    76  }
    77  
    78  func (i *RequestInfo) AnyCableHeader(key string) string {
    79  	if i.anycableHeaders == nil {
    80  		return ""
    81  	}
    82  
    83  	return i.anycableHeaders[strings.ToLower(key)]
    84  }
    85  
    86  // FetchUID safely extracts uid from `X-Request-ID` header or generates a new one
    87  func FetchUID(r *http.Request) (string, error) {
    88  	requestID := r.Header.Get("X-Request-ID")
    89  	if requestID == "" {
    90  		return nanoid.Nanoid()
    91  	}
    92  
    93  	return requestID, nil
    94  }
    95  
    96  func parseCookies(value string, cookieFilter []string) string {
    97  	if len(cookieFilter) == 0 {
    98  		return value
    99  	}
   100  
   101  	filter := make(map[string]bool)
   102  	for _, cookie := range cookieFilter {
   103  		filter[cookie] = true
   104  	}
   105  
   106  	result := ""
   107  	cookies := strings.Split(value, ";")
   108  	for _, cookie := range cookies {
   109  		parts := strings.Split(cookie, "=")
   110  		if len(parts) != 2 {
   111  			continue
   112  		}
   113  
   114  		if filter[parts[0]] {
   115  			result += cookie + ";"
   116  		}
   117  	}
   118  
   119  	return result
   120  }