trpc.group/trpc-go/trpc-go@v1.0.3/restful/serializer.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"net/http"
    18  	"strings"
    19  )
    20  
    // Serializer converts between tRPC messages and HTTP bodies.
    //
    // The serializer registry is keyed by Name, and Content-Type / Accept values are
    // matched against those keys, so in practice Name is a media type.
    type Serializer interface {
    	// Name identifies the Serializer; it is the key used by RegisterSerializer and
    	// GetSerializer. It must not be empty.
    	Name() string
    	// ContentType reports the original media type, i.e. the one carried by the
    	// Content-Type response header (Content-Encoding names compression codings such
    	// as gzip, not media types).
    	ContentType() string
    	// Marshal encodes v, which is either the tRPC message itself or one of its
    	// fields, into an HTTP body.
    	Marshal(v interface{}) ([]byte, error)
    	// Unmarshal decodes the HTTP body data into v, which is either the tRPC message
    	// itself or one of its fields.
    	Unmarshal(data []byte, v interface{}) error
    }
    32  
    33  // jsonpb as default
    34  var defaultSerializer Serializer = &JSONPBSerializer{AllowUnmarshalNil: true}
    35  
    36  // serialization related http header
    37  var (
    38  	headerAccept      = http.CanonicalHeaderKey("Accept")
    39  	headerContentType = http.CanonicalHeaderKey("Content-Type")
    40  )
    41  
    42  var serializers = make(map[string]Serializer)
    43  
    44  // RegisterSerializer registers a Serializer.
    45  // This function is not thread-safe, it should only be called in init() function.
    46  func RegisterSerializer(s Serializer) {
    47  	if s == nil || s.Name() == "" {
    48  		panic("tried to register nil or anonymous serializer")
    49  	}
    50  	serializers[s.Name()] = s
    51  }
    52  
    53  // SetDefaultSerializer sets the default Serializer.
    54  // This function is not thread-safe, it should only be called in init() function.
    55  func SetDefaultSerializer(s Serializer) {
    56  	if s == nil || s.Name() == "" {
    57  		panic("tried to set nil or anonymous serializer as the default serializer")
    58  	}
    59  	defaultSerializer = s
    60  }
    61  
    62  // GetSerializer returns a Serializer by its name.
    63  func GetSerializer(name string) Serializer {
    64  	return serializers[name]
    65  }
    66  
    67  // serializerForTranscoding returns inbound/outbound Serializer for transcoding.
    68  func serializerForTranscoding(contentTypes []string, accepts []string) (Serializer, Serializer) {
    69  	var reqSerializer, respSerializer Serializer // neither should be nil
    70  
    71  	// ContentType => Req Serializer
    72  	for _, contentType := range contentTypes {
    73  		if s := getSerializerWithDirectives(contentType); s != nil {
    74  			reqSerializer = s
    75  			break
    76  		}
    77  	}
    78  
    79  	// Accept => Resp Serializer
    80  	for _, accept := range accepts {
    81  		if s := getSerializerWithDirectives(accept); s != nil {
    82  			respSerializer = s
    83  			break
    84  		}
    85  	}
    86  
    87  	if reqSerializer == nil { // use defaultSerializer if reqSerializer is nil
    88  		reqSerializer = defaultSerializer
    89  	}
    90  	if respSerializer == nil { // use reqSerializer if respSerializer is nil
    91  		respSerializer = reqSerializer
    92  	}
    93  
    94  	return reqSerializer, respSerializer
    95  }
    96  
    97  // getSerializerWithDirectives get Serializer by Content-Type or Accept. The name may have directives after ';'.
    98  // All Serializers are considered the same as the one with only one directive "charset=UTF-8".
    99  // Other directives are not supported, and will cause the function to return nil.
   100  func getSerializerWithDirectives(name string) Serializer {
   101  	if s, ok := serializers[name]; ok {
   102  		return s
   103  	}
   104  	pos := strings.Index(name, ";")
   105  	const charsetUTF8 = "charset=utf-8"
   106  	if pos == -1 || strings.ToLower(strings.TrimSpace(name[pos+1:])) != charsetUTF8 {
   107  		return nil
   108  	}
   109  	if s, ok := serializers[name[:pos]]; ok {
   110  		return s
   111  	}
   112  	return nil
   113  }