github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/protobuf/protobuf.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package protobuf
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  
    24  	"github.com/cloudwego/fastpb"
    25  
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  	"github.com/cloudwego/kitex/pkg/remote/codec"
    28  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    29  )
    30  
    31  /**
    32   *  Kitex Protobuf Protocol
    33   *  |----------Len--------|--------------------------------MetaInfo--------------------------------|
    34   *  |---------4Byte-------|----2Byte----|----2Byte----|---------String-------|---------4Byte-------|
    35   *	+----------------------------------------------------------------------------------------------+
    36   *	|      PayloadLen     |    Magic    |   MsgType   |      MethodName      |        SeqID        |
    37   *	+----------------------------------------------------------------------------------------------+
    38   *	|  									 												           |
    39   *	|                         Protobuf  Argument/Result/Error   			                       |
    40   *	|   							 													           |
    41   *	+----------------------------------------------------------------------------------------------+
    42   */
    43  
    44  const (
    45  	metaInfoFixLen = 8
    46  )
    47  
    48  // NewProtobufCodec ...
    49  func NewProtobufCodec() remote.PayloadCodec {
    50  	return &protobufCodec{}
    51  }
    52  
// protobufCodec implements remote.PayloadCodec for the Kitex Protobuf protocol.
// It is stateless; all per-call state lives in the remote.Message being processed.
type protobufCodec struct{}
    55  
    56  // Len encode outside not here
    57  func (c protobufCodec) Marshal(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
    58  	// 1. prepare info
    59  	methodName := message.RPCInfo().Invocation().MethodName()
    60  	if methodName == "" {
    61  		return errors.New("empty methodName in protobuf Marshal")
    62  	}
    63  	data, err := getValidData(methodName, message)
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	// 3. encode metainfo
    69  	// 3.1 magic && msgType
    70  	if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(message.MessageType()), out); err != nil {
    71  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write meta info failed: %s", err.Error()))
    72  	}
    73  	// 3.2 methodName
    74  	if _, err := codec.WriteString(methodName, out); err != nil {
    75  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write method name failed: %s", err.Error()))
    76  	}
    77  	// 3.3 seqID
    78  	if err := codec.WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out); err != nil {
    79  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write seqID failed: %s", err.Error()))
    80  	}
    81  
    82  	// 4. write actual message buf
    83  	msg, ok := data.(ProtobufMsgCodec)
    84  	if !ok {
    85  		// If Using Generics
    86  		// if data is a MessageWriterWithContext
    87  		// Do msg.WritePb(ctx context.Context, out remote.ByteBuffer)
    88  		genmsg, isgen := data.(MessageWriterWithContext)
    89  		if isgen {
    90  			actualMsg, err := genmsg.WritePb(ctx)
    91  			if err != nil {
    92  				return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error()))
    93  			}
    94  			actualMsgBuf, ok := actualMsg.([]byte)
    95  			if !ok {
    96  				return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error()))
    97  			}
    98  			_, err = out.WriteBinary(actualMsgBuf)
    99  			if err != nil {
   100  				return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write message buffer failed: %s", err.Error()))
   101  			}
   102  			return nil
   103  		}
   104  		// return error otherwise
   105  		return remote.NewTransErrorWithMsg(remote.InvalidProtocol, "encode failed, codec msg type not match with protobufCodec")
   106  	}
   107  
   108  	// 2. encode pb struct
   109  	// fast write
   110  	if msg, ok := data.(fastpb.Writer); ok {
   111  		msgsize := msg.Size()
   112  		actualMsgBuf, err := out.Malloc(msgsize)
   113  		if err != nil {
   114  			return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf malloc size %d failed: %s", msgsize, err.Error()))
   115  		}
   116  		msg.FastWrite(actualMsgBuf)
   117  		return nil
   118  	}
   119  
   120  	var actualMsgBuf []byte
   121  	if actualMsgBuf, err = msg.Marshal(actualMsgBuf); err != nil {
   122  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf marshal message failed: %s", err.Error()))
   123  	}
   124  	if _, err = out.WriteBinary(actualMsgBuf); err != nil {
   125  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write message buffer failed: %s", err.Error()))
   126  	}
   127  	return nil
   128  }
   129  
   130  func (c protobufCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   131  	payloadLen := message.PayloadLen()
   132  	magicAndMsgType, err := codec.ReadUint32(in)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	if magicAndMsgType&codec.MagicMask != codec.ProtobufV1Magic {
   137  		return perrors.NewProtocolErrorWithType(perrors.BadVersion, "Bad version in protobuf Unmarshal")
   138  	}
   139  	msgType := magicAndMsgType & codec.FrontMask
   140  	if err := codec.UpdateMsgType(msgType, message); err != nil {
   141  		return err
   142  	}
   143  
   144  	methodName, methodFieldLen, err := codec.ReadString(in)
   145  	if err != nil {
   146  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read method name failed: %s", err.Error()))
   147  	}
   148  	if err = codec.SetOrCheckMethodName(methodName, message); err != nil && msgType != uint32(remote.Exception) {
   149  		return err
   150  	}
   151  	seqID, err := codec.ReadUint32(in)
   152  	if err != nil {
   153  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read seqID failed: %s", err.Error()))
   154  	}
   155  	if err = codec.SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) {
   156  		return err
   157  	}
   158  	actualMsgLen := payloadLen - metaInfoFixLen - methodFieldLen
   159  	actualMsgBuf, err := in.Next(actualMsgLen)
   160  	if err != nil {
   161  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf unmarshal, read message buffer failed: %s", err.Error()))
   162  	}
   163  	// exception message
   164  	if message.MessageType() == remote.Exception {
   165  		var exception pbError
   166  		if err := exception.Unmarshal(actualMsgBuf); err != nil {
   167  			return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf unmarshal Exception failed: %s", err.Error()))
   168  		}
   169  		return remote.NewTransError(exception.TypeID(), &exception)
   170  	}
   171  
   172  	if err = codec.NewDataIfNeeded(methodName, message); err != nil {
   173  		return err
   174  	}
   175  	data := message.Data()
   176  
   177  	// fast read
   178  	if msg, ok := data.(fastpb.Reader); ok {
   179  		if len(actualMsgBuf) == 0 {
   180  			// if all fields of a struct is default value, actualMsgLen will be zero and actualMsgBuf will be nil
   181  			// In the implementation of fastpb, if actualMsgBuf is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected.
   182  			// So, when actualMsgBuf is nil, use default protobuf unmarshal method to decode the struct.
   183  			// todo: fix fastpb
   184  		} else {
   185  			_, err := fastpb.ReadMessage(actualMsgBuf, fastpb.SkipTypeCheck, msg)
   186  			if err != nil {
   187  				return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error())
   188  			}
   189  			return nil
   190  		}
   191  	}
   192  
   193  	// JSONPB Generic Case
   194  	if msg, ok := data.(MessageReaderWithMethodWithContext); ok {
   195  		err := msg.ReadPb(ctx, methodName, actualMsgBuf)
   196  		if err != nil {
   197  			return err
   198  		}
   199  		return nil
   200  	}
   201  
   202  	msg, ok := data.(ProtobufMsgCodec)
   203  	if !ok {
   204  		return remote.NewTransErrorWithMsg(remote.InvalidProtocol, "decode failed, codec msg type not match with protobufCodec")
   205  	}
   206  	if err = msg.Unmarshal(actualMsgBuf); err != nil {
   207  		return remote.NewTransErrorWithMsg(remote.ProtocolError, err.Error())
   208  	}
   209  	return err
   210  }
   211  
   212  func (c protobufCodec) Name() string {
   213  	return "protobuf"
   214  }
   215  
// MessageWriterWithContext is implemented by payloads that serialize themselves.
// WritePb returns the encoded message; protobufCodec.Marshal type-asserts the
// returned value to []byte and writes it to the output buffer.
type MessageWriterWithContext interface {
	WritePb(ctx context.Context) (interface{}, error)
}
   220  
// MessageReaderWithMethodWithContext is implemented by payloads that decode
// themselves. protobufCodec.Unmarshal calls ReadPb with the method name read
// from the meta info and the raw bytes of the message body.
type MessageReaderWithMethodWithContext interface {
	ReadPb(ctx context.Context, method string, in []byte) error
}
   225  
// ProtobufMsgCodec is the interface of messages that protobufCodec can encode
// with Marshal and decode with Unmarshal.
type ProtobufMsgCodec interface {
	// Marshal returns the encoded message; protobufCodec passes a nil slice as out.
	Marshal(out []byte) ([]byte, error)
	// Unmarshal decodes the message from in.
	Unmarshal(in []byte) error
}
   230  
   231  func getValidData(methodName string, message remote.Message) (interface{}, error) {
   232  	if err := codec.NewDataIfNeeded(methodName, message); err != nil {
   233  		return nil, err
   234  	}
   235  	data := message.Data()
   236  	if message.MessageType() != remote.Exception {
   237  		return data, nil
   238  	}
   239  	transErr, isTransErr := data.(*remote.TransError)
   240  	if !isTransErr {
   241  		if err, isError := data.(error); isError {
   242  			encodeErr := NewPbError(remote.InternalError, err.Error())
   243  			return encodeErr, nil
   244  		}
   245  		return nil, errors.New("exception relay need error type data")
   246  	}
   247  	encodeErr := NewPbError(transErr.TypeID(), transErr.Error())
   248  	return encodeErr, nil
   249  }