trpc.group/trpc-go/trpc-go@v1.0.3/http/value_detached_transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"net"
    20  	"net/http"
    21  	"net/http/httptrace"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/require"
    26  
    27  	trpc "trpc.group/trpc-go/trpc-go"
    28  	"trpc.group/trpc-go/trpc-go/client"
    29  	"trpc.group/trpc-go/trpc-go/codec"
    30  )
    31  
    32  func TestValueDetachedTransport(t *testing.T) {
    33  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    34  	require.Nil(t, err)
    35  	defer ln.Close()
    36  	s := &http.Server{Addr: ln.Addr().String(), Handler: &handler{}}
    37  	serveReturn := make(chan error)
    38  	go func() {
    39  		serveReturn <- s.Serve(ln)
    40  	}()
    41  	clientTransport := NewClientTransport(false)
    42  	httpClientTransport := clientTransport.(*ClientTransport)
    43  	ctx, cancel := context.WithTimeout(trpc.BackgroundContext(), time.Second*10)
    44  	defer cancel()
    45  
    46  	type contextType struct{}
    47  	ctx = context.WithValue(ctx, contextType{}, struct{}{})
    48  
    49  	// Detects whether the data has been unloaded in the native RoundTripper.
    50  	vdt := httpClientTransport.Client.Transport.(*valueDetachedTransport)
    51  	vdt.RoundTripper = &testTransport{
    52  		RoundTripper: vdt.RoundTripper,
    53  		assertFunc: func(request *http.Request) {
    54  			ctx := request.Context()
    55  			// Cannot get value.
    56  			if ctx.Value(contextType{}) != nil {
    57  				t.Fatal("valueDetachedTransport not detach the value")
    58  			}
    59  			if httptrace.ContextClientTrace(ctx) == nil {
    60  				t.Fatal("valueDetachedTransport not transmit the httptrace")
    61  			}
    62  		},
    63  	}
    64  
    65  	// Check if data is still retained in http.ClientTransport.
    66  	// Check if httpClientTransport is type of http.ClientTransport or not.
    67  	// Check if httpClientTransport.Client.Transport is type of *testTransport and can get data.
    68  	// Check if httpClientTransport.Client.Transport.RoundTripper is type of *valueDetachedTransport and detached data.
    69  	// Check if httpClientTransport.Client.Transport.RoundTripper.RoundTripper is type of *testTransport and cannot get value.
    70  	// Check if httpClientTransport.Client.Transport.RoundTripper.RoundTripper.RoundTripper is type of *http.Transport,
    71  	// which is equals to StdHTTPTransport.
    72  	httpClientTransport.Client.Transport = &testTransport{
    73  		RoundTripper: httpClientTransport.Client.Transport,
    74  		assertFunc: func(request *http.Request) {
    75  			t.Log(fmt.Sprintf("%+v", request))
    76  			ctx := request.Context()
    77  			// Can get data.
    78  			if ctx.Value(contextType{}) == nil {
    79  				t.Fatal("ClientTransport detach the value")
    80  			}
    81  			if httptrace.ContextClientTrace(ctx) == nil {
    82  				t.Fatal("valueDetachedTransport not transmit the httptrace")
    83  			}
    84  		},
    85  	}
    86  
    87  	target := "ip://" + ln.Addr().String()
    88  	c := NewClientProxy("", client.WithTarget(target), client.WithTransport(clientTransport))
    89  	rsp := &codec.Body{}
    90  	require.Eventually(t,
    91  		func() bool {
    92  			err = c.Get(ctx, "/", rsp,
    93  				client.WithCurrentSerializationType(codec.SerializationTypeNoop),
    94  				client.WithTimeout(10*time.Second))
    95  			return err == nil
    96  		}, time.Second, 10*time.Millisecond,
    97  		"get %s return failure %v", target, err)
    98  	require.Nil(t, s.Shutdown(ctx))
    99  	<-serveReturn
   100  }
   101  
   102  type handler struct{}
   103  
   104  func (h *handler) ServeHTTP(http.ResponseWriter, *http.Request) {
   105  	return
   106  }
   107  
   108  type testTransport struct {
   109  	http.RoundTripper
   110  	assertFunc func(request *http.Request)
   111  }
   112  
   113  // RoundTrip implements http.RoundTripper.
   114  func (rt *testTransport) RoundTrip(request *http.Request) (*http.Response, error) {
   115  	rt.assertFunc(request)
   116  	response, err := rt.RoundTripper.RoundTrip(request)
   117  	return response, err
   118  }