github.com/cloudwego/kitex@v0.9.0/pkg/transmeta/metainfo_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package transmeta
    18  
    19  import (
    20  	"context"
    21  	"testing"
    22  
    23  	"github.com/bytedance/gopkg/cloud/metainfo"
    24  
    25  	"github.com/cloudwego/kitex/internal/mocks"
    26  	"github.com/cloudwego/kitex/internal/test"
    27  	"github.com/cloudwego/kitex/pkg/logid"
    28  	"github.com/cloudwego/kitex/pkg/remote"
    29  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/metadata"
    30  	"github.com/cloudwego/kitex/pkg/remote/transmeta"
    31  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    32  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    33  	"github.com/cloudwego/kitex/transport"
    34  )
    35  
    36  func TestClientReadMetainfo(t *testing.T) {
    37  	ctx := context.Background()
    38  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
    39  	msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)
    40  	hd := map[string]string{
    41  		"hello": "world",
    42  	}
    43  	msg.TransInfo().PutTransStrInfo(hd)
    44  
    45  	var err error
    46  	ctx, err = MetainfoClientHandler.ReadMeta(ctx, msg)
    47  	test.Assert(t, err == nil)
    48  
    49  	kvs := metainfo.RecvAllBackwardValues(ctx)
    50  	test.Assert(t, len(kvs) == 0)
    51  
    52  	ctx = metainfo.WithBackwardValues(context.Background())
    53  	ctx, err = MetainfoClientHandler.ReadMeta(ctx, msg)
    54  	test.Assert(t, err == nil)
    55  
    56  	kvs = metainfo.RecvAllBackwardValues(ctx)
    57  	test.Assert(t, len(kvs) == 1)
    58  	test.Assert(t, kvs["hello"] == "world")
    59  }
    60  
    61  func TestClientWriteMetainfo(t *testing.T) {
    62  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
    63  	ctx := context.Background()
    64  	ctx = metainfo.WithValue(ctx, "tk", "tv")
    65  	ctx = metainfo.WithPersistentValue(ctx, "pk", "pv")
    66  	msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)
    67  
    68  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Thrift))
    69  	ctx, err := MetainfoClientHandler.WriteMeta(ctx, msg)
    70  	test.Assert(t, err == nil)
    71  
    72  	kvs := msg.TransInfo().TransStrInfo()
    73  	test.Assert(t, len(kvs) == 2, kvs)
    74  	test.Assert(t, kvs[metainfo.PrefixTransient+"tk"] == "tv")
    75  	test.Assert(t, kvs[metainfo.PrefixPersistent+"pk"] == "pv")
    76  
    77  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift))
    78  	_, err = MetainfoClientHandler.WriteMeta(ctx, msg)
    79  	test.Assert(t, err == nil)
    80  
    81  	kvs = msg.TransInfo().TransStrInfo()
    82  	test.Assert(t, len(kvs) == 2, kvs)
    83  	test.Assert(t, kvs[metainfo.PrefixTransient+"tk"] == "tv")
    84  	test.Assert(t, kvs[metainfo.PrefixPersistent+"pk"] == "pv")
    85  }
    86  
    87  func TestServerReadMetainfo(t *testing.T) {
    88  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
    89  	ctx0 := context.Background()
    90  	msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)
    91  
    92  	hd := map[string]string{
    93  		"hello":                          "world",
    94  		metainfo.PrefixTransient + "tk":  "tv",
    95  		metainfo.PrefixPersistent + "pk": "pv",
    96  	}
    97  	msg.TransInfo().PutTransStrInfo(hd)
    98  
    99  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Thrift))
   100  	ctx, err := MetainfoServerHandler.ReadMeta(ctx0, msg)
   101  	tvs := metainfo.GetAllValues(ctx)
   102  	pvs := metainfo.GetAllPersistentValues(ctx)
   103  	test.Assert(t, err == nil)
   104  	test.Assert(t, len(tvs) == 1 && tvs["tk"] == "tv", tvs)
   105  	test.Assert(t, len(pvs) == 1 && pvs["pk"] == "pv")
   106  
   107  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift))
   108  	ctx, err = MetainfoServerHandler.ReadMeta(ctx0, msg)
   109  	ctx = metainfo.TransferForward(ctx)
   110  	tvs = metainfo.GetAllValues(ctx)
   111  	pvs = metainfo.GetAllPersistentValues(ctx)
   112  	test.Assert(t, err == nil)
   113  	test.Assert(t, len(tvs) == 1 && tvs["tk"] == "tv", tvs)
   114  	test.Assert(t, len(pvs) == 1 && pvs["pk"] == "pv")
   115  
   116  	ctx = metainfo.TransferForward(ctx)
   117  	tvs = metainfo.GetAllValues(ctx)
   118  	pvs = metainfo.GetAllPersistentValues(ctx)
   119  	test.Assert(t, len(tvs) == 0, len(tvs))
   120  	test.Assert(t, len(pvs) == 1 && pvs["pk"] == "pv")
   121  }
   122  
   123  func TestServerWriteMetainfo(t *testing.T) {
   124  	ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""), nil, rpcinfo.NewRPCStats())
   125  	msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)
   126  
   127  	ctx := context.Background()
   128  	ctx = metainfo.WithBackwardValuesToSend(ctx)
   129  	ctx = metainfo.WithValue(ctx, "tk", "tv")
   130  	ctx = metainfo.WithPersistentValue(ctx, "pk", "pv")
   131  	ok := metainfo.SendBackwardValue(ctx, "bk", "bv")
   132  	test.Assert(t, ok)
   133  
   134  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.PurePayload, serviceinfo.Thrift))
   135  	ctx, err := MetainfoServerHandler.WriteMeta(ctx, msg)
   136  	test.Assert(t, err == nil)
   137  	kvs := msg.TransInfo().TransStrInfo()
   138  	test.Assert(t, len(kvs) == 1 && kvs["bk"] == "bv", kvs)
   139  
   140  	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift))
   141  	_, err = MetainfoServerHandler.WriteMeta(ctx, msg)
   142  	test.Assert(t, err == nil)
   143  	kvs = msg.TransInfo().TransStrInfo()
   144  	test.Assert(t, len(kvs) == 1 && kvs["bk"] == "bv", kvs)
   145  }
   146  
   147  func Test_addStreamID(t *testing.T) {
   148  	t.Run("without-stream-log-id", func(t *testing.T) {
   149  		md := metadata.MD{
   150  			transmeta.HTTPStreamLogID: nil,
   151  		}
   152  		ctx := context.Background()
   153  		ctx = addStreamIDToContext(ctx, md)
   154  		logID := logid.GetStreamLogID(ctx)
   155  		test.Assert(t, logID == "", logID) // won't generate a new one
   156  	})
   157  
   158  	t.Run("with-stream-log-id", func(t *testing.T) {
   159  		md := metadata.MD{
   160  			transmeta.HTTPStreamLogID: []string{"test"},
   161  		}
   162  		ctx := context.Background()
   163  		ctx = addStreamIDToContext(ctx, md)
   164  		logID := logid.GetStreamLogID(ctx)
   165  		test.Assert(t, logID == "test", logID)
   166  	})
   167  }