go.uber.org/yarpc@v1.72.1/transport/tchannel/header.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"context"
    25  	"encoding/binary"
    26  	"io"
    27  	"strings"
    28  
    29  	"github.com/uber/tchannel-go"
    30  	"go.uber.org/yarpc/api/transport"
    31  	"go.uber.org/yarpc/transport/tchannel/internal"
    32  	"go.uber.org/yarpc/yarpcerrors"
    33  )
    34  
    35  const (
    36  	/** Response headers **/
    37  
    38  	// ErrorCodeHeaderKey is the response header key for the error code.
    39  	ErrorCodeHeaderKey = "$rpc$-error-code"
    40  	// ErrorNameHeaderKey is the response header key for the error name.
    41  	ErrorNameHeaderKey = "$rpc$-error-name"
    42  	// ErrorMessageHeaderKey is the response header key for the error message.
    43  	ErrorMessageHeaderKey = "$rpc$-error-message"
    44  	// ServiceHeaderKey is the response header key for the respond service
    45  	ServiceHeaderKey = "$rpc$-service"
    46  	// ApplicationErrorNameHeaderKey is the response header key for the application error name.
    47  	ApplicationErrorNameHeaderKey = "$rpc$-application-error-name"
    48  	// ApplicationErrorDetailsHeaderKey is the response header key for the
    49  	// application error details string.
    50  	ApplicationErrorDetailsHeaderKey = "$rpc$-application-error-details"
    51  	// ApplicationErrorCodeHeaderKey is the response header key for the application error code.
    52  	ApplicationErrorCodeHeaderKey = "$rpc$-application-error-code"
    53  
    54  	/** Request headers **/
    55  
    56  	// CallerProcedureHeader is the header key for the procedure of the caller making the request.
    57  	CallerProcedureHeader = "rpc-caller-procedure"
    58  )
    59  
    60  var _reservedHeaderKeys = map[string]struct{}{
    61  	ErrorCodeHeaderKey:               {},
    62  	ErrorNameHeaderKey:               {},
    63  	ErrorMessageHeaderKey:            {},
    64  	ServiceHeaderKey:                 {},
    65  	ApplicationErrorNameHeaderKey:    {},
    66  	ApplicationErrorDetailsHeaderKey: {},
    67  	ApplicationErrorCodeHeaderKey:    {},
    68  	CallerProcedureHeader:            {},
    69  }
    70  
    71  func isReservedHeaderKey(key string) bool {
    72  	_, ok := _reservedHeaderKeys[strings.ToLower(key)]
    73  	return ok
    74  }
    75  
    76  // readRequestHeaders reads headers and baggage from an incoming request.
    77  func readRequestHeaders(
    78  	ctx context.Context,
    79  	format tchannel.Format,
    80  	getReader func() (tchannel.ArgReader, error),
    81  ) (context.Context, transport.Headers, error) {
    82  	headers, err := readHeaders(format, getReader)
    83  	if err != nil {
    84  		return ctx, headers, err
    85  	}
    86  	return ctx, headers, nil
    87  }
    88  
    89  // readHeaders reads headers using the given function to get the arg reader.
    90  //
    91  // This may be used with the Arg2Reader functions on InboundCall and
    92  // OutboundCallResponse.
    93  //
    94  // If the format is JSON, the headers are expected to be JSON encoded.
    95  //
    96  // This function always returns a non-nil Headers object in case of success.
    97  func readHeaders(format tchannel.Format, getReader func() (tchannel.ArgReader, error)) (transport.Headers, error) {
    98  	if format == tchannel.JSON {
    99  		// JSON is special
   100  		var headers map[string]string
   101  		err := tchannel.NewArgReader(getReader()).ReadJSON(&headers)
   102  		return transport.HeadersFromMap(headers), err
   103  	}
   104  
   105  	r, err := getReader()
   106  	if err != nil {
   107  		return transport.Headers{}, err
   108  	}
   109  
   110  	headers, err := decodeHeaders(r)
   111  	if err != nil {
   112  		return headers, err
   113  	}
   114  
   115  	return headers, r.Close()
   116  }
   117  
   118  var emptyMap = map[string]string{}
   119  
   120  // writeHeaders writes the given headers using the given function to get the
   121  // arg writer.
   122  //
   123  // This may be used with the Arg2Writer functions on OutboundCall and
   124  // InboundCallResponse.
   125  //
   126  // If the format is JSON, the headers are JSON encoded.
   127  func writeHeaders(format tchannel.Format, headers map[string]string, tracingBaggage map[string]string, getWriter func() (tchannel.ArgWriter, error)) error {
   128  	merged := mergeHeaders(headers, tracingBaggage)
   129  	if format == tchannel.JSON {
   130  		// JSON is special
   131  		if merged == nil {
   132  			// We want to write "{}", not "null" for empty map.
   133  			merged = emptyMap
   134  		}
   135  		return tchannel.NewArgWriter(getWriter()).WriteJSON(merged)
   136  	}
   137  	return tchannel.NewArgWriter(getWriter()).Write(encodeHeaders(merged))
   138  }
   139  
   140  // mergeHeaders will keep the last value if the same key appears multiple times
   141  func mergeHeaders(m1, m2 map[string]string) map[string]string {
   142  	if len(m1) == 0 {
   143  		return m2
   144  	}
   145  	if len(m2) == 0 {
   146  		return m1
   147  	}
   148  	// merge and return
   149  	merged := make(map[string]string, len(m1)+len(m2))
   150  	for k, v := range m1 {
   151  		merged[k] = v
   152  	}
   153  	for k, v := range m2 {
   154  		merged[k] = v
   155  	}
   156  	return merged
   157  }
   158  
   159  // decodeHeaders decodes headers using the format:
   160  //
   161  // 	nh:2 (k~2 v~2){nh}
   162  func decodeHeaders(r io.Reader) (transport.Headers, error) {
   163  	reader := internal.NewReader(r)
   164  
   165  	count := reader.ReadUint16()
   166  	if count == 0 {
   167  		return transport.Headers{}, reader.Err()
   168  	}
   169  
   170  	headers := transport.NewHeadersWithCapacity(int(count))
   171  	for i := 0; i < int(count) && reader.Err() == nil; i++ {
   172  		k := reader.ReadLen16String()
   173  		v := reader.ReadLen16String()
   174  		headers = headers.With(k, v)
   175  	}
   176  
   177  	return headers, reader.Err()
   178  }
   179  
   180  // headerCallerProcedureToRequest copies callerProcedure from headers to req.CallerProcedure
   181  // and then deletes it from headers.
   182  func headerCallerProcedureToRequest(req *transport.Request, headers *transport.Headers) *transport.Request {
   183  	if callerProcedure, ok := headers.Get(CallerProcedureHeader); ok {
   184  		req.CallerProcedure = callerProcedure
   185  		headers.Del(CallerProcedureHeader)
   186  		return req
   187  	}
   188  	return req
   189  }
   190  
   191  // requestCallerProcedureToHeader add callerProcedure header as an application header.
   192  func requestCallerProcedureToHeader(req *transport.Request, reqHeaders map[string]string) map[string]string {
   193  	if req.CallerProcedure == "" {
   194  		return reqHeaders
   195  	}
   196  
   197  	if reqHeaders == nil {
   198  		reqHeaders = make(map[string]string)
   199  	}
   200  	reqHeaders[CallerProcedureHeader] = req.CallerProcedure
   201  	return reqHeaders
   202  }
   203  
   204  // encodeHeaders encodes headers using the format:
   205  //
   206  // 	nh:2 (k~2 v~2){nh}
   207  func encodeHeaders(hs map[string]string) []byte {
   208  	if len(hs) == 0 {
   209  		return []byte{0, 0} // nh = 2
   210  	}
   211  
   212  	size := 2 // nh:2
   213  	for k, v := range hs {
   214  		size += len(k) + 2 // k~2
   215  		size += len(v) + 2 // v~2
   216  	}
   217  
   218  	out := make([]byte, size)
   219  
   220  	i := 2
   221  	binary.BigEndian.PutUint16(out, uint16(len(hs))) // nh:2
   222  	for k, v := range hs {
   223  		i += _putStr16(k, out[i:]) // k~2
   224  		i += _putStr16(v, out[i:]) // v~2
   225  	}
   226  
   227  	return out
   228  }
   229  
   230  func headerMap(hs transport.Headers, headerCase headerCase) map[string]string {
   231  	switch headerCase {
   232  	case originalHeaderCase:
   233  		return hs.OriginalItems()
   234  	default:
   235  		return hs.Items()
   236  	}
   237  }
   238  
   239  func deleteReservedHeaders(headers transport.Headers) {
   240  	for headerKey := range _reservedHeaderKeys {
   241  		headers.Del(headerKey)
   242  	}
   243  }
   244  
   245  // this check ensures that the service we're issuing a request to is the one
   246  // responding
   247  func validateServiceName(requestService, responseService string) error {
   248  	// an empty service string means that we're talking to an older YARPC
   249  	// TChannel client
   250  	if responseService == "" || requestService == responseService {
   251  		return nil
   252  	}
   253  	return yarpcerrors.InternalErrorf(
   254  		"service name sent from the request does not match the service name "+
   255  			"received in the response: sent %q, got: %q", requestService, responseService)
   256  }
   257  
   258  // _putStr16 writes the bytes `in` into `out` using the encoding `s~2`.
   259  func _putStr16(in string, out []byte) int {
   260  	binary.BigEndian.PutUint16(out, uint16(len(in)))
   261  	return copy(out[2:], in) + 2
   262  }