github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/util.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package codec
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  
    23  	"github.com/cloudwego/kitex/pkg/remote"
    24  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    25  )
    26  
    27  const (
    28  	// FrontMask is used in protocol sniffing.
    29  	FrontMask = 0x0000ffff
    30  )
    31  
    32  // SetOrCheckMethodName is used to set method name to invocation.
    33  func SetOrCheckMethodName(methodName string, message remote.Message) error {
    34  	ri := message.RPCInfo()
    35  	ink := ri.Invocation()
    36  	callMethodName := ink.MethodName()
    37  	if methodName == "" {
    38  		return fmt.Errorf("method name that receive is empty")
    39  	}
    40  	if callMethodName == methodName {
    41  		return nil
    42  	}
    43  	// the server's callMethodName may not be empty if RPCInfo is based on connection multiplexing
    44  	// for the server side callMethodName ! = methodName is normal
    45  	if message.RPCRole() == remote.Client {
    46  		return fmt.Errorf("wrong method name, expect=%s, actual=%s", callMethodName, methodName)
    47  	}
    48  	svcInfo, err := message.SpecifyServiceInfo(ink.ServiceName(), methodName)
    49  	if err != nil {
    50  		return err
    51  	}
    52  	if ink, ok := ink.(rpcinfo.InvocationSetter); ok {
    53  		ink.SetMethodName(methodName)
    54  		ink.SetPackageName(svcInfo.GetPackageName())
    55  		ink.SetServiceName(svcInfo.ServiceName)
    56  	} else {
    57  		return errors.New("the interface Invocation doesn't implement InvocationSetter")
    58  	}
    59  
    60  	// unknown method doesn't set methodName for RPCInfo.To(), or lead inconsistent with old version
    61  	rpcinfo.AsMutableEndpointInfo(ri.To()).SetMethod(methodName)
    62  	return nil
    63  }
    64  
    65  // SetOrCheckSeqID is used to check the sequence ID.
    66  func SetOrCheckSeqID(seqID int32, message remote.Message) error {
    67  	switch message.MessageType() {
    68  	case remote.Call, remote.Oneway:
    69  		if ink, ok := message.RPCInfo().Invocation().(rpcinfo.InvocationSetter); ok {
    70  			ink.SetSeqID(seqID)
    71  		} else {
    72  			return errors.New("the interface Invocation doesn't implement InvocationSetter")
    73  		}
    74  	case remote.Reply:
    75  		expectSeqID := message.RPCInfo().Invocation().SeqID()
    76  		if expectSeqID != seqID {
    77  			methodName := message.RPCInfo().Invocation().MethodName()
    78  			return remote.NewTransErrorWithMsg(remote.BadSequenceID, fmt.Sprintf("method[%s] out of order sequence response, expect[%d], receive[%d]", methodName, expectSeqID, seqID))
    79  		}
    80  	case remote.Exception:
    81  		// don't check, proxy may build Exception with seqID = 0
    82  		// thrift 0.13 check seqID for Exception but thrift 0.9.2 doesn't check
    83  	}
    84  	return nil
    85  }
    86  
    87  // UpdateMsgType updates msg type.
    88  func UpdateMsgType(msgType uint32, message remote.Message) error {
    89  	rpcRole := message.RPCRole()
    90  	mt := remote.MessageType(msgType)
    91  	if mt == message.MessageType() {
    92  		return nil
    93  	}
    94  	if rpcRole == remote.Server {
    95  		if mt != remote.Call && mt != remote.Oneway && mt != remote.Stream {
    96  			return remote.NewTransErrorWithMsg(remote.InvalidMessageTypeException, fmt.Sprintf("server side, invalid message type %d", mt))
    97  		}
    98  	} else {
    99  		if mt != remote.Reply && mt != remote.Exception && mt != remote.Stream {
   100  			return remote.NewTransErrorWithMsg(remote.InvalidMessageTypeException, fmt.Sprintf("client side, invalid message type %d", mt))
   101  		}
   102  	}
   103  
   104  	message.SetMessageType(mt)
   105  	return nil
   106  }
   107  
   108  // NewDataIfNeeded is used to create the data if not exist.
   109  func NewDataIfNeeded(method string, message remote.Message) error {
   110  	if message.Data() != nil {
   111  		return nil
   112  	}
   113  	if message.NewData(method) {
   114  		return nil
   115  	}
   116  	return remote.NewTransErrorWithMsg(remote.InternalError, "message data for codec is nil")
   117  }