github.com/cloudwego/kitex@v0.9.0/pkg/transmeta/ttheader.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package transmeta
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strconv"
    23  	"time"
    24  
    25  	"github.com/cloudwego/kitex/pkg/kerrors"
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  	"github.com/cloudwego/kitex/pkg/remote/transmeta"
    28  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    29  	"github.com/cloudwego/kitex/pkg/utils"
    30  	"github.com/cloudwego/kitex/transport"
    31  )
    32  
    33  const (
    34  	framedTransportType   = "framed"
    35  	unframedTransportType = "unframed"
    36  
    37  	// for biz error
    38  	bizStatus  = "biz-status"
    39  	bizMessage = "biz-message"
    40  	bizExtra   = "biz-extra"
    41  )
    42  
    43  // TTHeader handlers.
    44  var (
    45  	ClientTTHeaderHandler remote.MetaHandler = &clientTTHeaderHandler{}
    46  	ServerTTHeaderHandler remote.MetaHandler = &serverTTHeaderHandler{}
    47  )
    48  
// clientTTHeaderHandler implements remote.MetaHandler.
// It is stateless; its WriteMeta and ReadMeta methods only act on messages
// whose transport protocol has the TTHeader flag set.
type clientTTHeaderHandler struct{}
    51  
    52  // WriteMeta of clientTTHeaderHandler writes headers of TTHeader protocol to transport
    53  func (ch *clientTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
    54  	if !isTTHeader(msg) {
    55  		return ctx, nil
    56  	}
    57  	ri := msg.RPCInfo()
    58  	transInfo := msg.TransInfo()
    59  
    60  	hd := map[uint16]string{
    61  		transmeta.FromService: ri.From().ServiceName(),
    62  		transmeta.FromMethod:  ri.From().Method(),
    63  		transmeta.ToService:   ri.To().ServiceName(),
    64  		transmeta.ToMethod:    ri.To().Method(),
    65  		transmeta.MsgType:     strconv.Itoa(int(msg.MessageType())),
    66  	}
    67  	if msg.ProtocolInfo().TransProto&transport.Framed == transport.Framed {
    68  		hd[transmeta.TransportType] = framedTransportType
    69  	} else {
    70  		hd[transmeta.TransportType] = unframedTransportType
    71  	}
    72  
    73  	cfg := rpcinfo.AsMutableRPCConfig(ri.Config())
    74  	if cfg.IsRPCTimeoutLocked() {
    75  		hd[transmeta.RPCTimeout] = strconv.Itoa(int(ri.Config().RPCTimeout().Milliseconds()))
    76  	}
    77  
    78  	transInfo.PutTransIntInfo(hd)
    79  	transInfo.PutTransStrInfo(map[string]string{transmeta.HeaderIDLServiceName: ri.Invocation().ServiceName()})
    80  	return ctx, nil
    81  }
    82  
    83  // ReadMeta of clientTTHeaderHandler reads headers of TTHeader protocol from transport
    84  func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
    85  	if !isTTHeader(msg) {
    86  		return ctx, nil
    87  	}
    88  	ri := msg.RPCInfo()
    89  	transInfo := msg.TransInfo()
    90  	strInfo := transInfo.TransStrInfo()
    91  
    92  	if code, err := strconv.Atoi(strInfo[bizStatus]); err == nil && code != 0 {
    93  		if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok {
    94  			if bizExtra := strInfo[bizExtra]; bizExtra != "" {
    95  				extra, err := utils.JSONStr2Map(bizExtra)
    96  				if err != nil {
    97  					return ctx, fmt.Errorf("malformed header info, extra: %s", bizExtra)
    98  				}
    99  				setter.SetBizStatusErr(kerrors.NewBizStatusErrorWithExtra(int32(code), strInfo[bizMessage], extra))
   100  			} else {
   101  				setter.SetBizStatusErr(kerrors.NewBizStatusError(int32(code), strInfo[bizMessage]))
   102  			}
   103  		}
   104  	}
   105  	return ctx, nil
   106  }
   107  
// serverTTHeaderHandler implements remote.MetaHandler.
// It is stateless; its ReadMeta and WriteMeta methods only act on messages
// whose transport protocol has the TTHeader flag set.
type serverTTHeaderHandler struct{}
   110  
   111  // ReadMeta of serverTTHeaderHandler reads headers of TTHeader protocol to transport
   112  func (sh *serverTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
   113  	if !isTTHeader(msg) {
   114  		return ctx, nil
   115  	}
   116  	ri := msg.RPCInfo()
   117  	transInfo := msg.TransInfo()
   118  	intInfo := transInfo.TransIntInfo()
   119  
   120  	ci := rpcinfo.AsMutableEndpointInfo(ri.From())
   121  	if ci != nil {
   122  		if v := intInfo[transmeta.FromService]; v != "" {
   123  			ci.SetServiceName(v)
   124  		}
   125  		if v := intInfo[transmeta.FromMethod]; v != "" {
   126  			ci.SetMethod(v)
   127  		}
   128  	}
   129  
   130  	if cfg := rpcinfo.AsMutableRPCConfig(ri.Config()); cfg != nil {
   131  		timeout := intInfo[transmeta.RPCTimeout]
   132  		if timeoutMS, err := strconv.Atoi(timeout); err == nil {
   133  			cfg.SetRPCTimeout(time.Duration(timeoutMS) * time.Millisecond)
   134  		}
   135  	}
   136  	return ctx, nil
   137  }
   138  
   139  // WriteMeta of serverTTHeaderHandler writes headers of TTHeader protocol to transport
   140  func (sh *serverTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Message) (context.Context, error) {
   141  	if !isTTHeader(msg) {
   142  		return ctx, nil
   143  	}
   144  	ri := msg.RPCInfo()
   145  	transInfo := msg.TransInfo()
   146  	intInfo := transInfo.TransIntInfo()
   147  	strInfo := transInfo.TransStrInfo()
   148  
   149  	intInfo[transmeta.MsgType] = strconv.Itoa(int(msg.MessageType()))
   150  
   151  	if bizErr := ri.Invocation().BizStatusErr(); bizErr != nil {
   152  		strInfo[bizStatus] = strconv.Itoa(int(bizErr.BizStatusCode()))
   153  		strInfo[bizMessage] = bizErr.BizMessage()
   154  		if len(bizErr.BizExtra()) != 0 {
   155  			strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra())
   156  		}
   157  	}
   158  
   159  	return ctx, nil
   160  }
   161  
   162  func isTTHeader(msg remote.Message) bool {
   163  	transProto := msg.ProtocolInfo().TransProto
   164  	return transProto&transport.TTHeader == transport.TTHeader
   165  }