github.com/TBD54566975/ftl@v0.219.0/internal/rpc/headers/headers.go (about)

     1  package headers
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  
     7  	"github.com/alecthomas/types/optional"
     8  
     9  	"github.com/TBD54566975/ftl/backend/schema"
    10  	"github.com/TBD54566975/ftl/internal/model"
    11  )
    12  
    13  // Headers used by the internal RPC system.
    14  const (
    15  	DirectRoutingHeader = "FTL-Direct"
    16  	// VerbHeader is the header used to pass the module.verb of the current request.
    17  	//
    18  	// One header will be present for each hop in the request path.
    19  	VerbHeader = "FTL-Verb"
    20  	// RequestIDHeader is the header used to pass the inbound request ID.
    21  	RequestIDHeader = "FTL-Request-ID"
    22  )
    23  
    24  func IsDirectRouted(header http.Header) bool {
    25  	return header.Get(DirectRoutingHeader) != ""
    26  }
    27  
    28  func SetDirectRouted(header http.Header) {
    29  	header.Set(DirectRoutingHeader, "1")
    30  }
    31  
    32  func SetRequestKey(header http.Header, key model.RequestKey) {
    33  	header.Set(RequestIDHeader, key.String())
    34  }
    35  
    36  // GetRequestKey from an incoming request.
    37  //
    38  // Will return ("", false, nil) if no request key is present.
    39  func GetRequestKey(header http.Header) (model.RequestKey, bool, error) {
    40  	keyStr := header.Get(RequestIDHeader)
    41  	if keyStr == "" {
    42  		return model.RequestKey{}, false, nil
    43  	}
    44  
    45  	key, err := model.ParseRequestKey(keyStr)
    46  	if err != nil {
    47  		return model.RequestKey{}, false, err
    48  	}
    49  	return key, true, nil
    50  }
    51  
    52  // GetCallers history from an incoming request.
    53  func GetCallers(header http.Header) ([]*schema.Ref, error) {
    54  	headers := header.Values(VerbHeader)
    55  	if len(headers) == 0 {
    56  		return nil, nil
    57  	}
    58  	refs := make([]*schema.Ref, len(headers))
    59  	for i, header := range headers {
    60  		ref, err := schema.ParseRef(header)
    61  		if err != nil {
    62  			return nil, fmt.Errorf("invalid %s header %q: %w", VerbHeader, header, err)
    63  		}
    64  		refs[i] = ref
    65  	}
    66  	return refs, nil
    67  }
    68  
    69  // GetCaller returns the module.verb of the caller, if any.
    70  //
    71  // Will return an error if the header is malformed.
    72  func GetCaller(header http.Header) (optional.Option[*schema.Ref], error) {
    73  	headers := header.Values(VerbHeader)
    74  	if len(headers) == 0 {
    75  		return optional.None[*schema.Ref](), nil
    76  	}
    77  	ref, err := schema.ParseRef(headers[len(headers)-1])
    78  	if err != nil {
    79  		return optional.None[*schema.Ref](), err
    80  	}
    81  	return optional.Some(ref), nil
    82  }
    83  
    84  // AddCaller to an outgoing request.
    85  func AddCaller(header http.Header, ref *schema.Ref) {
    86  	refStr := ref.String()
    87  	if values := header.Values(VerbHeader); len(values) > 0 {
    88  		if values[len(values)-1] == refStr {
    89  			return
    90  		}
    91  	}
    92  	header.Add(VerbHeader, refStr)
    93  }
    94  
    95  func SetCallers(header http.Header, refs []*schema.Ref) {
    96  	header.Del(VerbHeader)
    97  	for _, ref := range refs {
    98  		AddCaller(header, ref)
    99  	}
   100  }