trpc.group/trpc-go/trpc-go@v1.0.3/restful/serialize_proto.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"errors"
    18  	"reflect"
    19  
    20  	"google.golang.org/protobuf/proto"
    21  )
    22  
    23  func init() {
    24  	RegisterSerializer(&ProtoSerializer{})
    25  }
    26  
    27  var (
    28  	errNotProtoMessageType = errors.New("type is not proto.Message")
    29  )
    30  
    31  // ProtoSerializer is used for content-Type: application/octet-stream.
    32  type ProtoSerializer struct{}
    33  
    34  // Marshal implements Serializer.
    35  func (*ProtoSerializer) Marshal(v interface{}) ([]byte, error) {
    36  	msg, ok := assertProtoMessage(v)
    37  	if !ok {
    38  		return nil, errNotProtoMessageType
    39  	}
    40  	return proto.Marshal(msg)
    41  }
    42  
    43  // Unmarshal implements Serializer.
    44  func (*ProtoSerializer) Unmarshal(data []byte, v interface{}) error {
    45  	msg, ok := assertProtoMessage(v)
    46  	if !ok {
    47  		return errNotProtoMessageType
    48  	}
    49  	return proto.Unmarshal(data, msg)
    50  }
    51  
    52  // assertProtoMessage asserts the type of the input is proto.Message
    53  // or a chain of pointers to proto.Message.
    54  func assertProtoMessage(v interface{}) (proto.Message, bool) {
    55  	msg, ok := v.(proto.Message)
    56  	if ok {
    57  		return msg, true
    58  	}
    59  	// proto reflection
    60  	rv := reflect.ValueOf(v)
    61  	// get the value
    62  	for rv.Kind() == reflect.Ptr {
    63  		if rv.IsNil() { // if the pointer points to nil,New an object
    64  			rv.Set(reflect.New(rv.Type().Elem()))
    65  		}
    66  		// if the type is proto message,return it
    67  		if msg, ok := rv.Interface().(proto.Message); ok {
    68  			return msg, true
    69  		}
    70  		rv = rv.Elem()
    71  	}
    72  	return nil, false
    73  }
    74  
    75  // Name implements Serializer.
    76  func (*ProtoSerializer) Name() string {
    77  	return "application/octet-stream"
    78  }
    79  
    80  // ContentType implements Serializer.
    81  func (*ProtoSerializer) ContentType() string {
    82  	return "application/octet-stream"
    83  }