github.com/cloudwego/kitex@v0.9.0/pkg/remote/bound/transmeta_bound.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package bound
    18  
    19  import (
    20  	"context"
    21  	"net"
    22  
    23  	"github.com/bytedance/gopkg/cloud/metainfo"
    24  
    25  	"github.com/cloudwego/kitex/pkg/consts"
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  )
    28  
    29  // NewTransMetaHandler to build transMetaHandler that handle transport info
    30  func NewTransMetaHandler(mhs []remote.MetaHandler) remote.DuplexBoundHandler {
    31  	return &transMetaHandler{mhs: mhs}
    32  }
    33  
    34  type transMetaHandler struct {
    35  	mhs []remote.MetaHandler
    36  }
    37  
    38  // Write exec before encode
    39  func (h *transMetaHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (context.Context, error) {
    40  	var err error
    41  	for _, hdlr := range h.mhs {
    42  		ctx, err = hdlr.WriteMeta(ctx, sendMsg)
    43  		if err != nil {
    44  			return ctx, err
    45  		}
    46  	}
    47  	return ctx, nil
    48  }
    49  
    50  // OnMessage exec after decode
    51  func (h *transMetaHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
    52  	var err error
    53  	msg, isServer := getValidMsg(args, result)
    54  	if msg == nil {
    55  		return ctx, nil
    56  	}
    57  
    58  	for _, hdlr := range h.mhs {
    59  		ctx, err = hdlr.ReadMeta(ctx, msg)
    60  		if err != nil {
    61  			return ctx, err
    62  		}
    63  	}
    64  	if isServer && result.MessageType() != remote.Exception {
    65  		// Pass through method name using ctx, the method name will be used as from method in the client.
    66  		ctx = context.WithValue(ctx, consts.CtxKeyMethod, msg.RPCInfo().To().Method())
    67  		// TransferForward converts transient values to transient-upstream values and filters out original transient-upstream values.
    68  		// It should be used before the context is passing from server to client.
    69  		// reference https://github.com/bytedance/gopkg/tree/main/cloud/metainfo
    70  		// Notice, it should be after ReadMeta().
    71  		ctx = metainfo.TransferForward(ctx)
    72  	}
    73  	return ctx, nil
    74  }
    75  
    76  // Onactive implements the remote.InboundHandler interface.
    77  func (h *transMetaHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
    78  	return ctx, nil
    79  }
    80  
    81  // OnRead implements the remote.InboundHandler interface.
    82  func (h *transMetaHandler) OnRead(ctx context.Context, conn net.Conn) (context.Context, error) {
    83  	return ctx, nil
    84  }
    85  
    86  // OnInactive implements the remote.InboundHandler interface.
    87  func (h *transMetaHandler) OnInactive(ctx context.Context, conn net.Conn) context.Context {
    88  	return ctx
    89  }
    90  
    91  func getValidMsg(args, result remote.Message) (msg remote.Message, isServer bool) {
    92  	if args != nil && args.RPCRole() == remote.Server {
    93  		// server side, read arg
    94  		return args, true
    95  	}
    96  	// client side, read result
    97  	return result, false
    98  }