trpc.group/trpc-go/trpc-go@v1.0.3/http/transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http_test
    15  
    16  // How the HTTPS certificate files in ../testdata were generated (bash syntax):
    17  // 1. CA certificate:
    18  // openssl genrsa -out ca.key 2048
    19  // openssl req -x509 -new -nodes -key ca.key -subj "/CN=*" -days 5000 -out ca.pem
    20  // 2. Server certificate:
    21  // openssl genrsa -out server.key 2048
    22  // openssl req -new -key server.key -subj "/CN=*" -out server.csr
    23  // openssl x509 -req -in server.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out server.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
    24  // 3. Client certificate:
    25  // openssl genrsa -out client.key 2048
    26  // openssl req -new -key client.key -subj "/CN=*" -out client.csr
    27  // openssl x509 -req -in client.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out client.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
    28  // 4. Show certificate content:
    29  // openssl x509 -text -in server.crt -noout
    30  
    31  import (
    32  	"bytes"
    33  	"context"
    34  	"crypto/tls"
    35  	"errors"
    36  	"fmt"
    37  	"io"
    38  	"mime/multipart"
    39  	"net"
    40  	"net/http"
    41  	"net/http/httptest"
    42  	"net/url"
    43  	"os"
    44  	"path"
    45  	"path/filepath"
    46  	"strconv"
    47  	"strings"
    48  	"testing"
    49  	"time"
    50  
    51  	"github.com/stretchr/testify/require"
    52  	"golang.org/x/net/http2"
    53  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    54  
    55  	"trpc.group/trpc-go/trpc-go/client"
    56  	"trpc.group/trpc-go/trpc-go/codec"
    57  	"trpc.group/trpc-go/trpc-go/errs"
    58  	"trpc.group/trpc-go/trpc-go/filter"
    59  	thttp "trpc.group/trpc-go/trpc-go/http"
    60  	"trpc.group/trpc-go/trpc-go/log"
    61  	"trpc.group/trpc-go/trpc-go/naming/registry"
    62  	"trpc.group/trpc-go/trpc-go/server"
    63  	"trpc.group/trpc-go/trpc-go/testdata/restful/helloworld"
    64  	"trpc.group/trpc-go/trpc-go/transport"
    65  )
    66  
    67  func newNoopStdHTTPServer() *http.Server { return &http.Server{} }
    68  
    69  func TestStartServer(t *testing.T) {
    70  	ctx := context.Background()
    71  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
    72  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    73  	require.Nil(t, err)
    74  	defer ln.Close()
    75  	option := transport.WithListener(ln)
    76  	handler := transport.WithHandler(transport.Handler(&h{}))
    77  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
    78  	require.NotNil(t, tp.ListenAndServe(ctx, transport.WithListenAddress("127.0.0.1:8888"), handler, transport.WithListenNetwork("tcp1")))
    79  	tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "ca1")
    80  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls))
    81  }
    82  
    83  func TestH2C(t *testing.T) {
    84  	ctx := context.Background()
    85  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    86  	require.Nil(t, err)
    87  	defer ln.Close()
    88  	handler := transport.WithHandler(transport.Handler(&h{}))
    89  	tp := thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithReusePort(), thttp.WithEnableH2C())
    90  	require.Nil(t, tp.ListenAndServe(ctx, transport.WithListener(ln), handler))
    91  }
    92  
    93  func TestDisableReusePort(t *testing.T) {
    94  	ctx := context.Background()
    95  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
    96  	ln1, err := net.Listen("tcp", "127.0.0.1:0")
    97  	require.Nil(t, err)
    98  	defer ln1.Close()
    99  	option := transport.WithListener(ln1)
   100  	handler := transport.WithHandler(transport.Handler(&h{}))
   101  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   102  
   103  	option = transport.WithListenAddress(ln1.Addr().String())
   104  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, transport.WithListenNetwork("tcp1")))
   105  
   106  	ln2, err := net.Listen("tcp", "127.0.0.1:0")
   107  	require.Nil(t, err)
   108  	defer ln2.Close()
   109  	option = transport.WithListener(ln2)
   110  	tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "")
   111  	require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls))
   112  
   113  	ln3, err := net.Listen("tcp", "127.0.0.1:0")
   114  	require.Nil(t, err)
   115  	defer ln3.Close()
   116  	option = transport.WithListener(ln3)
   117  	tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "root")
   118  	require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls))
   119  
   120  	ln4, err := net.Listen("tcp", "127.0.0.1:0")
   121  	require.Nil(t, err)
   122  	defer ln4.Close()
   123  	option = transport.WithListener(ln4)
   124  	tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.key")
   125  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls))
   126  }
   127  
   128  func TestStartServerWithNoHandler(t *testing.T) {
   129  	ctx := context.Background()
   130  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   131  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   132  	require.Nil(t, err)
   133  	defer ln.Close()
   134  	option := transport.WithListener(ln)
   135  	require.NotNil(t, tp.ListenAndServe(ctx, option), "http server transport handler empty")
   136  }
   137  
   138  func TestErrHandler(t *testing.T) {
   139  	ctx := context.Background()
   140  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   141  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   142  	require.Nil(t, err)
   143  	defer ln.Close()
   144  	option := transport.WithListener(ln)
   145  	h := transport.WithHandler(transport.Handler(&errHandler{}))
   146  	require.Nil(t, tp.ListenAndServe(ctx, option, h))
   147  
   148  	ct := thttp.NewClientTransport(true)
   149  	ctx, msg := codec.WithNewMessage(ctx)
   150  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   151  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   152  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   153  
   154  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   155  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   156  		transport.WithDialAddress(ln.Addr().String()),
   157  	)
   158  	require.Nil(t, rsp, "roundtrip rsp not empty")
   159  	require.Nil(t, err, "Failed to roundtrip")
   160  }
   161  
   162  func TestErrHeaderHandler(t *testing.T) {
   163  	ctx := context.Background()
   164  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   165  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   166  	require.Nil(t, err)
   167  	defer func() { require.Nil(t, ln.Close()) }()
   168  	err = tp.ListenAndServe(ctx,
   169  		transport.WithHandler(transport.Handler(&errHeaderHandler{})),
   170  		transport.WithListener(ln),
   171  	)
   172  	require.Nil(t, err)
   173  
   174  	ct := thttp.NewClientTransport(true)
   175  	ctx, msg := codec.WithNewMessage(ctx)
   176  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   177  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   178  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   179  
   180  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   181  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   182  		transport.WithDialAddress(ln.Addr().String()),
   183  	)
   184  	require.Nil(t, rsp, "roundtrip rsp not empty")
   185  	require.Nil(t, err, "Failed to roundtrip")
   186  }
   187  
   188  func TestListenAndServeFailedDueToBadCertificationFile(t *testing.T) {
   189  	ctx := context.Background()
   190  	oldLogger := log.DefaultLogger
   191  	defer func() {
   192  		log.DefaultLogger = oldLogger
   193  	}()
   194  	errorCh := make(chan error)
   195  	log.DefaultLogger = &testLog{Logger: oldLogger, errorCh: errorCh}
   196  
   197  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   198  	require.Nil(t, err)
   199  	defer func() { require.Nil(t, ln.Close()) }()
   200  	const badCertFile = "bad-file.cert"
   201  	require.Nil(
   202  		t,
   203  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   204  			ctx,
   205  			transport.WithListener(ln),
   206  			transport.WithHandler(transport.Handler(&h{})),
   207  			transport.WithServeTLS(badCertFile, "../testdata/server.key", ""),
   208  		),
   209  		"failed to new client transport",
   210  	)
   211  
   212  	select {
   213  	case <-time.After(time.Second):
   214  		t.Fatal("listen on a bad cert should log an error")
   215  	case err := <-errorCh:
   216  		require.Contains(t, err.Error(), badCertFile)
   217  	}
   218  }
   219  
   220  func TestStartTLSServerAndNoCheckServer(t *testing.T) {
   221  	ctx := context.Background()
   222  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   223  	require.Nil(t, err)
   224  	defer func() { require.Nil(t, ln.Close()) }()
   225  	// Only enables https server and do not verify client certificate.
   226  	require.Nil(
   227  		t,
   228  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   229  			ctx,
   230  			transport.WithListener(ln),
   231  			transport.WithHandler(transport.Handler(&h{})),
   232  			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   233  		),
   234  		"Failed to new client transport",
   235  	)
   236  
   237  	ct := thttp.NewClientTransport(false)
   238  	ctx, msg := codec.WithNewMessage(ctx)
   239  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   240  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   241  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   242  
   243  	rsp, err := ct.RoundTrip(
   244  		ctx,
   245  		[]byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   246  		transport.WithDialAddress(ln.Addr().String()),
   247  		// Fully trust the https server and do not verify server certificate,
   248  		// can only be used in test env.
   249  		transport.WithDialTLS("", "", "none", ""),
   250  	)
   251  	require.Nil(t, rsp, "roundtrip rsp not empty")
   252  	require.Nil(t, err, "Failed to roundtrip")
   253  }
   254  
   255  func TestServerWithListenerOption(t *testing.T) {
   256  	ln, err := net.Listen("tcp", "localhost:0")
   257  	require.Nil(t, err)
   258  	defer ln.Close()
   259  	service := server.New(
   260  		server.WithServiceName("trpc.http.server.ListenerTest"),
   261  		server.WithNetwork("tcp"),
   262  		server.WithProtocol("http"),
   263  		server.WithListener(ln),
   264  	)
   265  	thttp.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) error {
   266  		fmt.Printf("Protocol: %s\n", r.Proto)
   267  		w.Write([]byte(r.Proto))
   268  		return nil
   269  	})
   270  	thttp.RegisterDefaultService(service)
   271  	s := &server.Server{}
   272  	s.AddService("trpc.http.server.ListenerTest", service)
   273  	go func() {
   274  		require.Nil(t, s.Serve())
   275  	}()
   276  	defer s.Close(nil)
   277  	time.Sleep(100 * time.Millisecond)
   278  
   279  	resp, err := http.Get(fmt.Sprintf("http://%v/index", ln.Addr()))
   280  	require.Nil(t, err)
   281  	defer resp.Body.Close()
   282  	body, err := io.ReadAll(resp.Body)
   283  	require.Nil(t, err)
   284  	require.Equal(t, []byte("HTTP/1.1"), body)
   285  
   286  	const invalidAddr = "localhost:910439"
   287  	resp, err = http.Get(fmt.Sprintf("http://%s/index", invalidAddr))
   288  	require.NotNil(t, err)
   289  	require.Nil(t, resp)
   290  }
   291  
   292  func TestStartDisableKeepAlivesServer(t *testing.T) {
   293  	ln, err := net.Listen("tcp", "localhost:0")
   294  	require.Nil(t, err)
   295  	defer ln.Close()
   296  	s := &server.Server{}
   297  	service := server.New(
   298  		server.WithListener(ln),
   299  		server.WithServiceName("trpc.http.server.ListenerTest"),
   300  		server.WithNetwork("tcp"),
   301  		server.WithProtocol("http"),
   302  		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer)),
   303  		server.WithDisableKeepAlives(true),
   304  	)
   305  	thttp.HandleFunc("/disable-keepalives", func(w http.ResponseWriter, _ *http.Request) error {
   306  		w.Header().Set("Connection", "keep-alive")
   307  		return nil
   308  	})
   309  	thttp.RegisterDefaultService(service)
   310  	s.AddService("trpc.http.server.ListenerTest", service)
   311  	go func() {
   312  		err := s.Serve()
   313  		require.Nil(t, err)
   314  	}()
   315  	defer func() {
   316  		_ = s.Close(nil)
   317  	}()
   318  
   319  	time.Sleep(100 * time.Millisecond)
   320  
   321  	dailCount := 0
   322  	client := &http.Client{
   323  		Transport: &http.Transport{
   324  			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   325  				dailCount++
   326  				conn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
   327  				return conn, err
   328  			},
   329  		},
   330  	}
   331  	num := 3
   332  	url := fmt.Sprintf("http://%s/disable-keepalives", ln.Addr())
   333  	for i := 0; i < num; i++ {
   334  		resp, err := client.Get(url)
   335  		require.Nil(t, err)
   336  		defer resp.Body.Close()
   337  		_, err = io.Copy(io.Discard, resp.Body)
   338  		require.Nil(t, err)
   339  	}
   340  	require.Equal(t, num, dailCount)
   341  }
   342  
   343  func TestStartH2cServer(t *testing.T) {
   344  	ln, err := net.Listen("tcp", "localhost:0")
   345  	require.Nil(t, err)
   346  	defer ln.Close()
   347  	s := &server.Server{}
   348  	service := server.New(
   349  		server.WithListener(ln),
   350  		server.WithServiceName("trpc.h2c.server.Greeter"),
   351  		server.WithNetwork("tcp"),
   352  		server.WithProtocol("http2"),
   353  		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithEnableH2C())),
   354  	)
   355  	thttp.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) error {
   356  		fmt.Printf("Protocol: %s\n", r.Proto)
   357  		w.Write([]byte(r.Proto))
   358  		return nil
   359  	})
   360  	thttp.HandleFunc("/main", func(w http.ResponseWriter, r *http.Request) error {
   361  		fmt.Printf("Protocol: %s\n", r.Proto)
   362  		w.Write([]byte(r.Proto))
   363  		return nil
   364  	})
   365  	thttp.RegisterDefaultService(service)
   366  	s.AddService("trpc.h2c.server.Greeter", service)
   367  
   368  	go func() {
   369  		err := s.Serve()
   370  		require.Nil(t, err)
   371  	}()
   372  
   373  	time.Sleep(100 * time.Millisecond)
   374  
   375  	// h2c client
   376  	h2cClient := http.Client{
   377  		Transport: &http2.Transport{
   378  			AllowHTTP: true,
   379  			DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   380  				return net.Dial(network, addr)
   381  			},
   382  		},
   383  	}
   384  	url := fmt.Sprintf("http://%s/", ln.Addr())
   385  	resp, err := h2cClient.Get(url + "main")
   386  	require.Nil(t, err)
   387  	defer resp.Body.Close()
   388  	body, err := io.ReadAll(resp.Body)
   389  	require.Nil(t, err)
   390  	require.Equal(t, []byte("HTTP/2.0"), body)
   391  
   392  	// http1 client
   393  	resp2, err := http.Get(url)
   394  	require.Nil(t, err)
   395  	defer resp2.Body.Close()
   396  	body, err = io.ReadAll(resp2.Body)
   397  	require.Nil(t, err)
   398  	require.Equal(t, []byte("HTTP/1.1"), body)
   399  	require.Equal(t, http.StatusOK, resp2.StatusCode)
   400  }
   401  
   402  func TestHttp2StartTLSServerAndNoCheckServer(t *testing.T) {
   403  	ctx := context.Background()
   404  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   405  	require.Nil(t, err)
   406  	defer func() { require.Nil(t, ln.Close()) }()
   407  	// Only enables https server and do not verify client certificate.
   408  	require.Nil(
   409  		t,
   410  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   411  			ctx,
   412  			transport.WithListener(ln),
   413  			transport.WithHandler(transport.Handler(&h{})),
   414  			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   415  		),
   416  		"Failed to new client transport",
   417  	)
   418  
   419  	ct := thttp.NewClientTransport(true)
   420  	ctx, msg := codec.WithNewMessage(ctx)
   421  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   422  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   423  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   424  
   425  	rsp, err := ct.RoundTrip(
   426  		ctx,
   427  		[]byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   428  		transport.WithDialAddress(ln.Addr().String()),
   429  		// Fully trust the https server and do not verify server certificate,
   430  		// can only be used in test env.
   431  		transport.WithDialTLS("", "", "none", ""),
   432  	)
   433  	require.Nil(t, rsp, "roundtrip rsp not empty")
   434  	require.Nil(t, err, "Failed to roundtrip")
   435  }
   436  
   437  func TestStartTLSServerAndCheckServer(t *testing.T) {
   438  	ctx := context.Background()
   439  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   440  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   441  	require.Nil(t, err)
   442  	defer func() { require.Nil(t, ln.Close()) }()
   443  	err = tp.ListenAndServe(ctx,
   444  		transport.WithHandler(transport.Handler(&h{})),
   445  		// Only enables https server and do not verify client certificate.
   446  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   447  		transport.WithListener(ln),
   448  	)
   449  	require.Nil(t, err, "Failed to new client transport")
   450  
   451  	ct := thttp.NewClientTransport(false)
   452  	ctx, msg := codec.WithNewMessage(ctx)
   453  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   454  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   455  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   456  
   457  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   458  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   459  		transport.WithDialAddress(ln.Addr().String()),
   460  		// Uses ca public key to verify server certificate.
   461  		transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"),
   462  	)
   463  	require.Nil(t, rsp, "roundtrip rsp not empty")
   464  	require.Nil(t, err, "Failed to roundtrip")
   465  }
   466  
   467  func TestStartTLSServerAndCheckClientNoCert(t *testing.T) {
   468  	ctx := context.Background()
   469  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   470  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   471  	require.Nil(t, err)
   472  	defer func() { require.Nil(t, ln.Close()) }()
   473  	err = tp.ListenAndServe(ctx,
   474  		transport.WithHandler(transport.Handler(&h{})),
   475  		// Enables two-way authentication http server and need to verify client certificate.
   476  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   477  		transport.WithListener(ln),
   478  	)
   479  	require.Nil(t, err, "Failed to new client transport")
   480  
   481  	ct := thttp.NewClientTransport(false)
   482  	ctx, msg := codec.WithNewMessage(ctx)
   483  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   484  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   485  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   486  
   487  	_, err = ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   488  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   489  		transport.WithDialAddress(ln.Addr().String()),
   490  		// If the client's own certificate is not sent, will return TLS verification failed.
   491  		transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"),
   492  	)
   493  	require.NotNil(t, err, "Failed to roundtrip")
   494  }
   495  
   496  func TestStartTLSServerAndCheckClient(t *testing.T) {
   497  	ctx := context.Background()
   498  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   499  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   500  	require.Nil(t, err)
   501  	defer func() { require.Nil(t, ln.Close()) }()
   502  	// Enables two-way authentication http server and need to verify client certificate.
   503  	err = tp.ListenAndServe(ctx,
   504  		transport.WithHandler(transport.Handler(&h{})),
   505  		// Only enables https server and do not verify client certificate.
   506  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   507  		transport.WithListener(ln),
   508  	)
   509  	require.Nil(t, err, "Failed to new client transport")
   510  
   511  	ct := thttp.NewClientTransport(false)
   512  	ctx, msg := codec.WithNewMessage(ctx)
   513  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   514  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   515  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   516  
   517  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   518  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   519  		transport.WithDialAddress(ln.Addr().String()),
   520  		// Need to send the client's own certificate to server.
   521  		transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost"),
   522  	)
   523  	require.Nil(t, rsp, "roundtrip rsp not empty")
   524  	require.Nil(t, err, "Failed to roundtrip")
   525  }
   526  
   527  func TestNewClientTransport(t *testing.T) {
   528  	ct := thttp.NewClientTransport(false)
   529  	require.NotNil(t, ct, "Failed to new client transport")
   530  
   531  	ct2 := thttp.NewClientTransport(true)
   532  	require.NotNil(t, ct2, "Failed to new http2 client transport")
   533  }
   534  
   535  func TestClientRoundTrip(t *testing.T) {
   536  	ctx := context.Background()
   537  	ct := thttp.NewClientTransport(false)
   538  	ctx, msg := codec.WithNewMessage(ctx)
   539  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   540  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   541  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   542  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   543  	require.Nil(t, err)
   544  	defer ln.Close()
   545  	go http.Serve(ln, nil)
   546  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   547  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   548  		transport.WithDialAddress(ln.Addr().String()))
   549  	require.Nil(t, rsp, "roundtrip rsp not empty")
   550  	require.Nil(t, err, "Failed to roundtrip")
   551  }
   552  
   553  func TestClientRoundTripWithNoHead(t *testing.T) {
   554  	ctx := context.Background()
   555  	ct := thttp.NewClientTransport(false)
   556  	ctx, msg := codec.WithNewMessage(ctx)
   557  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   558  
   559  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   560  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   561  		transport.WithDialAddress("127.0.0.1:18080"))
   562  	require.Nil(t, rsp, "no head roundtrip rsp not empty")
   563  	require.NotNil(t, err, "no head roundtrip err nil")
   564  
   565  }
   566  
   567  func TestClientWithSelectorNode(t *testing.T) {
   568  	ctx := context.Background()
   569  	type testCase struct {
   570  		target   string
   571  		address  string
   572  		listener net.Listener
   573  	}
   574  	var tests []testCase
   575  	for i := 0; i < 2; i++ {
   576  		ln, err := net.Listen("tcp", "127.0.0.1:0")
   577  		require.Nil(t, err)
   578  		defer ln.Close()
   579  		addr := ln.Addr().String()
   580  		tests = append(tests, testCase{"ip://" + addr, addr, ln})
   581  	}
   582  	for _, tt := range tests {
   583  		tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   584  		option := transport.WithListener(tt.listener)
   585  		handler := transport.WithHandler(transport.Handler(&h{}))
   586  		err := tp.ListenAndServe(ctx, option, handler)
   587  		require.Nil(t, err, "Failed to new client transport")
   588  
   589  		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   590  			client.WithTarget(tt.target),
   591  			client.WithSerializationType(codec.SerializationTypeNoop),
   592  		)
   593  
   594  		reqBody := &codec.Body{
   595  			Data: []byte("{\"username\":\"xyz\"," +
   596  				"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   597  		}
   598  		rspBody := &codec.Body{}
   599  		n := &registry.Node{}
   600  		require.Nil(t,
   601  			proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, client.WithSelectorNode(n)),
   602  			"Failed to post")
   603  		require.Equal(t, tt.address, n.Address)
   604  	}
   605  }
   606  
   607  func TestClient(t *testing.T) {
   608  	ctx := context.Background()
   609  	old := codec.GetSerializer(codec.SerializationTypeJSON)
   610  	defer func() { codec.RegisterSerializer(codec.SerializationTypeJSON, old) }()
   611  	codec.RegisterSerializer(codec.SerializationTypeJSON, &codec.JSONPBSerialization{})
   612  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   613  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   614  	require.Nil(t, err)
   615  	defer ln.Close()
   616  	option := transport.WithListener(ln)
   617  	handler := transport.WithHandler(transport.Handler(&h{}))
   618  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   619  
   620  	header := &thttp.ClientReqHeader{}
   621  	header.AddHeader("ContentType", "application/json")
   622  
   623  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   624  		client.WithTarget("ip://"+ln.Addr().String()),
   625  		client.WithSerializationType(codec.SerializationTypeNoop),
   626  		client.WithReqHead(header),
   627  		client.WithMetaData("k1", []byte("v1")),
   628  	)
   629  	reqBody := &codec.Body{
   630  		Data: []byte("{\"username\":\"xyz\"," +
   631  			"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   632  	}
   633  	rspBody := &codec.Body{}
   634  
   635  	require.Nil(t, proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to post")
   636  	require.Nil(t, proxy.Put(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to put")
   637  	require.Nil(t, proxy.Delete(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to delete")
   638  	require.Nil(t, proxy.Get(ctx, "/trpc.test.helloworld.Greeter/SayHello", rspBody), "Failed to get")
   639  	require.Nil(t, proxy.Patch(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to patch")
   640  
   641  	// Test client with options.
   642  	proxy = thttp.NewClientProxy("trpc.test.helloworld.Greeter")
   643  	reqBody = &codec.Body{
   644  		Data: []byte("{\"username\":\"xyz\"," +
   645  			"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   646  	}
   647  	rspBody = &codec.Body{}
   648  	require.Nil(t,
   649  		proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody,
   650  			client.WithTarget("ip://"+ln.Addr().String()),
   651  			client.WithSerializationType(codec.SerializationTypeNoop),
   652  			client.WithReqHead(header),
   653  			client.WithMetaData("k1", []byte("v1")),
   654  		), "Failed to post")
   655  
   656  	require.NotNil(t,
   657  		proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody,
   658  			client.WithTarget("ip://127.0.0.1:180"),
   659  		), "Failed to post")
   660  }
   661  
   662  func TestReqHeader(t *testing.T) {
   663  	ctx := context.Background()
   664  	// Invalid url.
   665  	header := &thttp.ClientReqHeader{}
   666  	header.AddHeader("Content-Type", "application/json")
   667  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   668  		client.WithTarget("ip://127.0.0.1:18080:www.baidu.com//"),
   669  		client.WithSerializationType(codec.SerializationTypeNoop),
   670  		client.WithReqHead(header),
   671  	)
   672  	reqBody := &codec.Body{}
   673  	rspBody := &codec.Body{}
   674  	err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
   675  	require.NotNil(t, err)
   676  }
   677  
   678  func TestReqHeaderWithContentType(t *testing.T) {
   679  	ctx := context.Background()
   680  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   681  	require.Nil(t, err)
   682  	defer ln.Close()
   683  	option := transport.WithListener(ln)
   684  	handler := transport.WithHandler(transport.Handler(&h{}))
   685  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   686  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   687  	var tests = []struct {
   688  		expected string
   689  	}{
   690  		{"application/json"},
   691  		{"application/jsonp"},
   692  		{"application/jsonp123"},
   693  		{"application/text123"},
   694  	}
   695  	for _, tt := range tests {
   696  		header := &thttp.ClientReqHeader{}
   697  		header.AddHeader("Content-Type", tt.expected)
   698  		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   699  			client.WithTarget("ip://"+ln.Addr().String()),
   700  			client.WithSerializationType(codec.SerializationTypeForm),
   701  			client.WithReqHead(header),
   702  		)
   703  		reqBody := &codec.Body{}
   704  		rspBody := &codec.Body{}
   705  		err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
   706  		require.Nil(t, err)
   707  	}
   708  }
   709  
   710  func TestHandler(t *testing.T) {
   711  	var (
   712  		handler = func(w http.ResponseWriter, r *http.Request) {
   713  			return
   714  		}
   715  		handlerFunc = func(w http.ResponseWriter, r *http.Request) error {
   716  			return nil
   717  		}
   718  		service = server.New(server.WithProtocol("http"))
   719  	)
   720  
   721  	thttp.Handle("*", http.HandlerFunc(handler))
   722  	thttp.HandleFunc("/path/do/not/equal/to/*", handlerFunc)
   723  	thttp.RegisterDefaultService(service)
   724  
   725  	for _, method := range thttp.ServiceDesc.Methods {
   726  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   727  			return make([]filter.ServerFilter, 0), nil
   728  		})
   729  
   730  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   731  			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
   732  		})
   733  
   734  		header := &thttp.Header{
   735  			Request:  &http.Request{},
   736  			Response: &httptest.ResponseRecorder{},
   737  		}
   738  		ctx := thttp.WithHeader(context.TODO(), header)
   739  		_, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
   740  			return make([]filter.ServerFilter, 0), nil
   741  		})
   742  		require.Nil(t, err)
   743  	}
   744  }
   745  
   746  func TestMux(t *testing.T) {
   747  	var handler = func(w http.ResponseWriter, r *http.Request) {
   748  		return
   749  	}
   750  	mux := http.NewServeMux()
   751  	mux.HandleFunc("/", handler)
   752  
   753  	var service = &mockService{}
   754  	thttp.RegisterServiceMux(service, mux)
   755  	desc, _ := service.desc.(*server.ServiceDesc)
   756  	for _, method := range desc.Methods {
   757  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   758  			return make([]filter.ServerFilter, 0), nil
   759  		})
   760  
   761  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   762  			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
   763  		})
   764  
   765  		req, _ := http.NewRequest("GET", "/", nil)
   766  		header := &thttp.Header{
   767  			Request:  req,
   768  			Response: &httptest.ResponseRecorder{},
   769  		}
   770  		ctx := thttp.WithHeader(context.TODO(), header)
   771  		_, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
   772  			return make([]filter.ServerFilter, 0), nil
   773  		})
   774  		require.Nil(t, err)
   775  	}
   776  }
   777  
   778  // TestCheckRedirect tests set CheckRedirect
   779  func TestCheckRedirect(t *testing.T) {
   780  	ctx := context.Background()
   781  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   782  	require.Nil(t, err)
   783  	defer ln.Close()
   784  	// server
   785  	go func() {
   786  		// real backend
   787  		h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   788  			w.Write([]byte("real"))
   789  		})
   790  		http.Handle("/real", h)
   791  
   792  		// redirect a
   793  		rha := http.RedirectHandler("/b", http.StatusMovedPermanently)
   794  		http.Handle("/a", rha)
   795  
   796  		// redirect b
   797  		rhb := http.RedirectHandler("/real", http.StatusMovedPermanently)
   798  		http.Handle("/b", rhb)
   799  
   800  		http.Serve(ln, nil)
   801  	}()
   802  	time.Sleep(200 * time.Millisecond)
   803  
   804  	// sets CheckRedirect
   805  	checkRedirect := func(_ *http.Request, via []*http.Request) error {
   806  		if len(via) > 1 {
   807  			return errors.New("more than once")
   808  		}
   809  		return nil
   810  	}
   811  	thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = checkRedirect
   812  	defer func() {
   813  		thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = nil
   814  	}()
   815  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   816  		client.WithTarget("ip://"+ln.Addr().String()),
   817  		client.WithSerializationType(codec.SerializationTypeNoop),
   818  	)
   819  	reqBody := &codec.Body{}
   820  	rspBody := &codec.Body{}
   821  	// only redirect once form /b
   822  	require.Nil(t, proxy.Post(ctx, "/b", reqBody, rspBody))
   823  	// redirect twice from /a
   824  	err = proxy.Post(ctx, "/a", reqBody, rspBody)
   825  	require.NotNil(t, err)
   826  	require.Equal(t, true, strings.Contains(err.Error(), "more than once"))
   827  }
   828  
   829  func TestTransportError(t *testing.T) {
   830  	http.HandleFunc("/timeout", func(http.ResponseWriter, *http.Request) {
   831  		time.Sleep(time.Second)
   832  	})
   833  	http.HandleFunc("/cancel", func(http.ResponseWriter, *http.Request) {})
   834  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   835  	require.Nil(t, err)
   836  	defer ln.Close()
   837  	go func() { http.Serve(ln, nil) }()
   838  	time.Sleep(200 * time.Millisecond)
   839  
   840  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   841  		client.WithTarget("ip://"+ln.Addr().String()),
   842  		client.WithSerializationType(codec.SerializationTypeNoop),
   843  		client.WithTimeout(time.Millisecond*500),
   844  	)
   845  	rspBody := &codec.Body{}
   846  
   847  	err = proxy.Get(context.Background(), "/timeout", rspBody)
   848  	terr, ok := err.(*errs.Error)
   849  	require.True(t, ok)
   850  	require.EqualValues(t, terr.Code, int32(errs.RetClientTimeout))
   851  
   852  	ctx, cancel := context.WithCancel(context.Background())
   853  	cancel()
   854  	err = proxy.Get(ctx, "/cancel", rspBody)
   855  	terr, ok = err.(*errs.Error)
   856  	require.True(t, ok)
   857  	require.EqualValues(t, terr.Code, int32(errs.RetClientCanceled))
   858  }
   859  
   860  func TestClientRoundDyeing(t *testing.T) {
   861  	ctx := context.Background()
   862  	ct := thttp.NewClientTransport(false)
   863  	ctx, msg := codec.WithNewMessage(ctx)
   864  	msg.WithDyeing(true)
   865  	dyeingKey := "dyeingkey"
   866  	msg.WithDyeingKey(dyeingKey)
   867  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   868  	req := &http.Request{
   869  		Header: http.Header{},
   870  	}
   871  	reqHeader := &thttp.ClientReqHeader{
   872  		Request: req,
   873  	}
   874  	msg.WithClientReqHead(reqHeader)
   875  	rspHeader := &thttp.ClientRspHeader{}
   876  	msg.WithClientRspHead(rspHeader)
   877  	meta := codec.MetaData{
   878  		thttp.TrpcDyeingKey: []byte(dyeingKey),
   879  	}
   880  	msg.WithClientMetaData(meta)
   881  	_, err := ct.RoundTrip(ctx, nil)
   882  	require.NotNil(t, err)
   883  	require.Equal(t, req.Header.Get(thttp.TrpcMessageType),
   884  		strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)))
   885  }
   886  
   887  func TestClientRoundEnvTransfer(t *testing.T) {
   888  	ctx := context.Background()
   889  	ct := thttp.NewClientTransport(false)
   890  	ctx, msg := codec.WithNewMessage(ctx)
   891  	msg.WithEnvTransfer("feat,master")
   892  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   893  	req := &http.Request{
   894  		Header: http.Header{},
   895  	}
   896  	reqHeader := &thttp.ClientReqHeader{
   897  		Request: req,
   898  	}
   899  	msg.WithClientReqHead(reqHeader)
   900  	rspHeader := &thttp.ClientRspHeader{}
   901  	msg.WithClientRspHead(rspHeader)
   902  	_, err := ct.RoundTrip(ctx, nil)
   903  	require.NotNil(t, err)
   904  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), thttp.TrpcEnv)
   905  }
   906  
   907  func TestDisableBase64EncodeTransInfo(t *testing.T) {
   908  	ctx := context.Background()
   909  	ct := thttp.NewClientTransport(false, transport.WithDisableEncodeTransInfoBase64())
   910  	ctx, msg := codec.WithNewMessage(ctx)
   911  	var (
   912  		envTrans  = "feat,master"
   913  		metaVal   = "value"
   914  		dyeingKey = "dyeingkey"
   915  	)
   916  	msg.WithEnvTransfer(envTrans)
   917  	msg.WithClientMetaData(codec.MetaData{"key": []byte(metaVal)})
   918  	msg.WithDyeing(true)
   919  	msg.WithDyeingKey(dyeingKey)
   920  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   921  	req := &http.Request{
   922  		Header: http.Header{},
   923  	}
   924  	reqHeader := &thttp.ClientReqHeader{
   925  		Request: req,
   926  	}
   927  	msg.WithClientReqHead(reqHeader)
   928  	rspHeader := &thttp.ClientRspHeader{}
   929  	msg.WithClientRspHead(rspHeader)
   930  	_, err := ct.RoundTrip(ctx, nil)
   931  	require.NotNil(t, err)
   932  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), envTrans)
   933  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), metaVal)
   934  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), dyeingKey)
   935  }
   936  
   937  func TestDisableServiceRouterTransInfo(t *testing.T) {
   938  	ctx := context.Background()
   939  	a := require.New(t)
   940  	ct := thttp.NewClientTransport(false)
   941  	ctx, msg := codec.WithNewMessage(ctx)
   942  	msg.WithClientMetaData(codec.MetaData{thttp.TrpcEnv: []byte("orienv")}) // this emulate decode trpc protocol client request
   943  	msg.WithEnvTransfer("feat,master")
   944  	req := &http.Request{
   945  		Header: http.Header{},
   946  	}
   947  	reqHeader := &thttp.ClientReqHeader{
   948  		Request: req,
   949  	}
   950  	msg.WithClientReqHead(reqHeader)
   951  	rspHeader := &thttp.ClientRspHeader{}
   952  	msg.WithClientRspHead(rspHeader)
   953  	_, err := ct.RoundTrip(ctx, nil)
   954  	a.NotNil(err)
   955  	info, err := thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
   956  	a.NoError(err)
   957  	a.Equal(string(info[thttp.TrpcEnv]), "feat,master")
   958  
   959  	msg.WithEnvTransfer("") // DisableServiceRouter would clear EnvTransfer
   960  	_, err = ct.RoundTrip(ctx, nil)
   961  	a.NotNil(err)
   962  	info, err = thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
   963  	a.NoError(err)
   964  	a.Equal(string(info[thttp.TrpcEnv]), "")
   965  }
   966  
   967  func TestHTTPSUseClientVerify(t *testing.T) {
   968  	const (
   969  		network = "tcp"
   970  		address = "127.0.0.1:0"
   971  	)
   972  	ln, err := net.Listen(network, address)
   973  	require.Nil(t, err)
   974  	defer ln.Close()
   975  	serviceName := "trpc.app.server.Service" + t.Name()
   976  	service := server.New(
   977  		server.WithServiceName(serviceName),
   978  		server.WithNetwork(network),
   979  		server.WithProtocol("http_no_protocol"),
   980  		server.WithListener(ln),
   981  		server.WithTLS(
   982  			"../testdata/server.crt",
   983  			"../testdata/server.key",
   984  			"../testdata/ca.pem",
   985  		),
   986  	)
   987  	pattern := "/" + t.Name()
   988  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   989  		w.Write([]byte(t.Name()))
   990  	}))
   991  	s := &server.Server{}
   992  	s.AddService(serviceName, service)
   993  	go s.Serve()
   994  	defer s.Close(nil)
   995  	time.Sleep(100 * time.Millisecond)
   996  
   997  	c := thttp.NewClientProxy(
   998  		serviceName,
   999  		client.WithTarget("ip://"+ln.Addr().String()),
  1000  	)
  1001  	req := &codec.Body{}
  1002  	rsp := &codec.Body{}
  1003  	require.Nil(t,
  1004  		c.Post(context.Background(), pattern, req, rsp,
  1005  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1006  			client.WithSerializationType(codec.SerializationTypeNoop),
  1007  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1008  			client.WithTLS(
  1009  				"../testdata/client.crt",
  1010  				"../testdata/client.key",
  1011  				"../testdata/ca.pem",
  1012  				"localhost",
  1013  			),
  1014  		))
  1015  	require.Equal(t, []byte(t.Name()), rsp.Data)
  1016  }
  1017  
  1018  func TestHTTPSSkipClientVerify(t *testing.T) {
  1019  	const (
  1020  		network = "tcp"
  1021  		address = "127.0.0.1:0"
  1022  	)
  1023  	ln, err := net.Listen(network, address)
  1024  	require.Nil(t, err)
  1025  	defer ln.Close()
  1026  	serviceName := "trpc.app.server.Service" + t.Name()
  1027  	service := server.New(
  1028  		server.WithServiceName(serviceName),
  1029  		server.WithNetwork(network),
  1030  		server.WithProtocol("http_no_protocol"),
  1031  		server.WithListener(ln),
  1032  		server.WithTLS(
  1033  			"../testdata/server.crt",
  1034  			"../testdata/server.key",
  1035  			"",
  1036  		),
  1037  	)
  1038  	pattern := "/" + t.Name()
  1039  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  1040  		w.Write([]byte(t.Name()))
  1041  	}))
  1042  	s := &server.Server{}
  1043  	s.AddService(serviceName, service)
  1044  	go s.Serve()
  1045  	defer s.Close(nil)
  1046  	time.Sleep(100 * time.Millisecond)
  1047  
  1048  	c := thttp.NewClientProxy(
  1049  		serviceName,
  1050  		client.WithTarget("ip://"+ln.Addr().String()),
  1051  	)
  1052  	req := &codec.Body{}
  1053  	rsp := &codec.Body{}
  1054  	require.Nil(t,
  1055  		c.Post(context.Background(), pattern, req, rsp,
  1056  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1057  			client.WithSerializationType(codec.SerializationTypeNoop),
  1058  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1059  			client.WithTLS(
  1060  				"", "", "none", "",
  1061  			),
  1062  		))
  1063  	require.Equal(t, []byte(t.Name()), rsp.Data)
  1064  }
  1065  
  1066  func TestListenAndServeHTTPHead(t *testing.T) {
  1067  	ctx := context.Background()
  1068  	const (
  1069  		network = "tcp"
  1070  		address = "127.0.0.1:0"
  1071  	)
  1072  	ln, err := net.Listen(network, address)
  1073  	require.Nil(t, err)
  1074  	defer ln.Close()
  1075  	st := thttp.NewServerTransport(newNoopStdHTTPServer)
  1076  	require.Nil(t, st.ListenAndServe(ctx,
  1077  		transport.WithHandler(&httpHeadHandler{
  1078  			func(ctx context.Context, _ []byte) (rsp []byte, err error) {
  1079  				head := thttp.Head(ctx)
  1080  				head.Response.WriteHeader(http.StatusOK)
  1081  				head.Response.Write([]byte(fmt.Sprintf("%+v", thttp.Head(head.Request.Context()) != nil)))
  1082  				return
  1083  			}}),
  1084  		transport.WithListener(ln),
  1085  	))
  1086  	time.Sleep(200 * time.Millisecond)
  1087  	rsp, err := http.Get("http://" + ln.Addr().String())
  1088  	require.Nil(t, err)
  1089  	bs, err := io.ReadAll(rsp.Body)
  1090  	require.Nil(t, err)
  1091  	require.Equal(t, fmt.Sprintf("%+v", true), string(bs))
  1092  }
  1093  
  1094  type httpHeadHandler struct {
  1095  	handle func(ctx context.Context, req []byte) (rsp []byte, err error)
  1096  }
  1097  
  1098  func (h *httpHeadHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) {
  1099  	return h.handle(ctx, req)
  1100  }
  1101  
  1102  func TestHTTPStreamFileUpload(t *testing.T) {
  1103  	// Start server.
  1104  	const (
  1105  		network = "tcp"
  1106  		address = "127.0.0.1:0"
  1107  	)
  1108  	ln, err := net.Listen(network, address)
  1109  	require.Nil(t, err)
  1110  	defer ln.Close()
  1111  	go http.Serve(ln, &fileHandler{})
  1112  	// Start client.
  1113  	c := thttp.NewClientProxy(
  1114  		"trpc.app.server.Service_http",
  1115  		client.WithTarget("ip://"+ln.Addr().String()),
  1116  	)
  1117  	// Open and read file.
  1118  	fileDir, err := os.Getwd()
  1119  	require.Nil(t, err)
  1120  	fileName := "README.md"
  1121  	filePath := path.Join(fileDir, fileName)
  1122  	file, err := os.Open(filePath)
  1123  	require.Nil(t, err)
  1124  	defer file.Close()
  1125  	// Construct multipart form file.
  1126  	body := &bytes.Buffer{}
  1127  	writer := multipart.NewWriter(body)
  1128  	part, err := writer.CreateFormFile("field_name", filepath.Base(file.Name()))
  1129  	require.Nil(t, err)
  1130  	io.Copy(part, file)
  1131  	require.Nil(t, writer.Close())
  1132  	// Add multipart form data header.
  1133  	header := http.Header{}
  1134  	header.Add("Content-Type", writer.FormDataContentType())
  1135  	reqHeader := &thttp.ClientReqHeader{
  1136  		Header:  header,
  1137  		ReqBody: body, // Stream send.
  1138  	}
  1139  	req := &codec.Body{}
  1140  	rsp := &codec.Body{}
  1141  	// Upload file.
  1142  	require.Nil(t,
  1143  		c.Post(context.Background(), "/", req, rsp,
  1144  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1145  			client.WithSerializationType(codec.SerializationTypeNoop),
  1146  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1147  			client.WithReqHead(reqHeader),
  1148  		))
  1149  	require.Equal(t, []byte(fileName), rsp.Data)
  1150  }
  1151  
  1152  type fileHandler struct{}
  1153  
  1154  func (*fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1155  	_, h, err := r.FormFile("field_name")
  1156  	if err != nil {
  1157  		w.WriteHeader(http.StatusBadRequest)
  1158  		return
  1159  	}
  1160  	w.WriteHeader(http.StatusOK)
  1161  	// Write back file name.
  1162  	w.Write([]byte(h.Filename))
  1163  	return
  1164  }
  1165  
  1166  func TestHTTPStreamRead(t *testing.T) {
  1167  	// Start server.
  1168  	const (
  1169  		network = "tcp"
  1170  		address = "127.0.0.1:0"
  1171  	)
  1172  	ln, err := net.Listen(network, address)
  1173  	require.Nil(t, err)
  1174  	defer ln.Close()
  1175  	go http.Serve(ln, &fileServer{})
  1176  
  1177  	// Start client.
  1178  	c := thttp.NewClientProxy(
  1179  		"trpc.app.server.Service_http",
  1180  		client.WithTarget("ip://"+ln.Addr().String()),
  1181  	)
  1182  
  1183  	// Enable manual body reading in order to
  1184  	// disable the framework's automatic body reading capability,
  1185  	// so that users can manually do their own client-side streaming reads.
  1186  	rspHead := &thttp.ClientRspHeader{
  1187  		ManualReadBody: true,
  1188  	}
  1189  	req := &codec.Body{}
  1190  	rsp := &codec.Body{}
  1191  	require.Nil(t,
  1192  		c.Post(context.Background(), "/", req, rsp,
  1193  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1194  			client.WithSerializationType(codec.SerializationTypeNoop),
  1195  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1196  			client.WithRspHead(rspHead),
  1197  		))
  1198  	require.Nil(t, rsp.Data)
  1199  	body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body.
  1200  	defer body.Close()            // Do remember to close the body.
  1201  	bs, err := io.ReadAll(body)
  1202  	require.Nil(t, err)
  1203  	require.NotNil(t, bs)
  1204  }
  1205  
  1206  func TestHTTPSendReceiveChunk(t *testing.T) {
  1207  	// HTTP chunked example:
  1208  	//   1. Client sends chunks: Add "chunked" transfer encoding header, and use io.Reader as body.
  1209  	//   2. Client reads chunks: The Go/net/http automatically handles the chunked reading.
  1210  	//                           Users can simply read resp.Body in a loop until io.EOF.
  1211  	//   3. Server reads chunks: Similar to client reads chunks.
  1212  	//   4. Server sends chunks: Assert http.ResponseWriter as http.Flusher, call flusher.Flush() after
  1213  	//         writing a part of data, it will automatically trigger "chunked" encoding to send a chunk.
  1214  
  1215  	// Start server.
  1216  	const (
  1217  		network = "tcp"
  1218  		address = "127.0.0.1:0"
  1219  	)
  1220  	ln, err := net.Listen(network, address)
  1221  	require.Nil(t, err)
  1222  	defer ln.Close()
  1223  	go http.Serve(ln, &chunkedServer{})
  1224  
  1225  	// Start client.
  1226  	c := thttp.NewClientProxy(
  1227  		"trpc.app.server.Service_http",
  1228  		client.WithTarget("ip://"+ln.Addr().String()),
  1229  	)
  1230  
  1231  	// Open and read file.
  1232  	fileDir, err := os.Getwd()
  1233  	require.Nil(t, err)
  1234  	fileName := "README.md"
  1235  	filePath := path.Join(fileDir, fileName)
  1236  	file, err := os.Open(filePath)
  1237  	require.Nil(t, err)
  1238  	defer file.Close()
  1239  
  1240  	// 1. Client sends chunks.
  1241  
  1242  	// Add request headers.
  1243  	header := http.Header{}
  1244  	header.Add("Content-Type", "text/plain")
  1245  	// Add chunked transfer encoding header.
  1246  	header.Add("Transfer-Encoding", "chunked")
  1247  	reqHead := &thttp.ClientReqHeader{
  1248  		Header:  header,
  1249  		ReqBody: file, // Stream send (for chunks).
  1250  	}
  1251  
  1252  	// Enable manual body reading in order to
  1253  	// disable the framework's automatic body reading capability,
  1254  	// so that users can manually do their own client-side streaming reads.
  1255  	rspHead := &thttp.ClientRspHeader{
  1256  		ManualReadBody: true,
  1257  	}
  1258  	req := &codec.Body{}
  1259  	rsp := &codec.Body{}
  1260  	require.Nil(t,
  1261  		c.Post(context.Background(), "/", req, rsp,
  1262  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1263  			client.WithSerializationType(codec.SerializationTypeNoop),
  1264  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1265  			client.WithReqHead(reqHead),
  1266  			client.WithRspHead(rspHead),
  1267  		))
  1268  	require.Nil(t, rsp.Data)
  1269  
  1270  	// 2. Client reads chunks.
  1271  
  1272  	// Do stream reads directly from rspHead.Response.Body.
  1273  	body := rspHead.Response.Body
  1274  	defer body.Close() // Do remember to close the body.
  1275  	buf := make([]byte, 4096)
  1276  	var idx int
  1277  	for {
  1278  		n, err := body.Read(buf)
  1279  		if err == io.EOF {
  1280  			t.Logf("reached io.EOF\n")
  1281  			break
  1282  		}
  1283  		t.Logf("read chunk %d of length %d: %q\n", idx, n, buf[:n])
  1284  		idx++
  1285  	}
  1286  }
  1287  
  1288  type chunkedServer struct{}
  1289  
  1290  func (*chunkedServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1291  	// 3. Server reads chunks.
  1292  
  1293  	// io.ReadAll will read until io.EOF.
  1294  	// Go/net/http will automatically handle chunked body reads.
  1295  	bs, err := io.ReadAll(r.Body)
  1296  	if err != nil {
  1297  		w.WriteHeader(http.StatusInternalServerError)
  1298  		w.Write([]byte(fmt.Sprintf("io.ReadAll err: %+v", err)))
  1299  		return
  1300  	}
  1301  
  1302  	// 4. Server sends chunks.
  1303  
  1304  	// Send HTTP chunks using http.Flusher.
  1305  	// Reference: https://stackoverflow.com/questions/26769626/send-a-chunked-http-response-from-a-go-server.
  1306  	// The "Transfer-Encoding" header will be handled by the writer implicitly, so no need to set it.
  1307  	flusher, ok := w.(http.Flusher)
  1308  	if !ok {
  1309  		w.WriteHeader(http.StatusInternalServerError)
  1310  		w.Write([]byte("expected http.ResponseWriter to be an http.Flusher"))
  1311  		return
  1312  	}
  1313  	chunks := 10
  1314  	chunkSize := (len(bs) + chunks - 1) / chunks
  1315  	for i := 0; i < chunks; i++ {
  1316  		start := i * chunkSize
  1317  		end := (i + 1) * chunkSize
  1318  		if end > len(bs) {
  1319  			end = len(bs)
  1320  		}
  1321  		w.Write(bs[start:end])
  1322  		flusher.Flush() // Trigger "chunked" encoding and send a chunk.
  1323  		time.Sleep(500 * time.Millisecond)
  1324  	}
  1325  	return
  1326  }
  1327  
  1328  type fileServer struct{}
  1329  
  1330  func (*fileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1331  	http.ServeFile(w, r, "./README.md")
  1332  	return
  1333  }
  1334  
  1335  func TestHTTPSendAndReceiveSSE(t *testing.T) {
  1336  	const (
  1337  		network = "tcp"
  1338  		address = "127.0.0.1:0"
  1339  	)
  1340  	ln, err := net.Listen(network, address)
  1341  	require.Nil(t, err)
  1342  	defer ln.Close()
  1343  	serviceName := "trpc.app.server.Service" + t.Name()
  1344  	service := server.New(
  1345  		server.WithServiceName(serviceName),
  1346  		server.WithNetwork(network),
  1347  		server.WithProtocol("http_no_protocol"),
  1348  		server.WithListener(ln),
  1349  	)
  1350  	pattern := "/" + t.Name()
  1351  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1352  		flusher, ok := w.(http.Flusher)
  1353  		if !ok {
  1354  			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
  1355  			return
  1356  		}
  1357  		w.Header().Set("Content-Type", "text/event-stream")
  1358  		w.Header().Set("Cache-Control", "no-cache")
  1359  		w.Header().Set("Connection", "keep-alive")
  1360  		w.Header().Set("Access-Control-Allow-Origin", "*")
  1361  		bs, err := io.ReadAll(r.Body)
  1362  		if err != nil {
  1363  			http.Error(w, err.Error(), http.StatusBadRequest)
  1364  			return
  1365  		}
  1366  		msg := string(bs)
  1367  		for i := 0; i < 3; i++ {
  1368  			msgBytes := []byte("event: message\n\ndata: " + msg + strconv.Itoa(i) + "\n\n")
  1369  			_, err = w.Write(msgBytes)
  1370  			if err != nil {
  1371  				http.Error(w, err.Error(), http.StatusInternalServerError)
  1372  				return
  1373  			}
  1374  			flusher.Flush()
  1375  			time.Sleep(500 * time.Millisecond)
  1376  		}
  1377  		return
  1378  	}))
  1379  	s := &server.Server{}
  1380  	s.AddService(serviceName, service)
  1381  	go s.Serve()
  1382  	defer s.Close(nil)
  1383  	time.Sleep(100 * time.Millisecond)
  1384  
  1385  	c := thttp.NewClientProxy(
  1386  		serviceName,
  1387  		client.WithTarget("ip://"+ln.Addr().String()),
  1388  	)
  1389  	header := http.Header{}
  1390  	header.Set("Cache-Control", "no-cache")
  1391  	header.Set("Accept", "text/event-stream")
  1392  	header.Set("Connection", "keep-alive")
  1393  	reqHeader := &thttp.ClientReqHeader{
  1394  		Header: header,
  1395  	}
  1396  	// Enable manual body reading in order to
  1397  	// disable the framework's automatic body reading capability,
  1398  	// so that users can manually do their own client-side streaming reads.
  1399  	rspHead := &thttp.ClientRspHeader{
  1400  		ManualReadBody: true,
  1401  	}
  1402  	req := &codec.Body{Data: []byte("hello")}
  1403  	rsp := &codec.Body{}
  1404  	require.Nil(t,
  1405  		c.Post(context.Background(), pattern, req, rsp,
  1406  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1407  			client.WithSerializationType(codec.SerializationTypeNoop),
  1408  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1409  			client.WithReqHead(reqHeader),
  1410  			client.WithRspHead(rspHead),
  1411  			client.WithTimeout(time.Minute),
  1412  		))
  1413  	body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body.
  1414  	defer body.Close()            // Do remember to close the body.
  1415  	data := make([]byte, 1024)
  1416  	for {
  1417  		n, err := body.Read(data)
  1418  		if err == io.EOF {
  1419  			break
  1420  		}
  1421  		require.Nil(t, err)
  1422  		t.Logf("Received message: \n%s\n", string(data[:n]))
  1423  	}
  1424  }
  1425  
  1426  func TestHTTPClientReqRspDifferentContentType(t *testing.T) {
  1427  	const (
  1428  		network = "tcp"
  1429  		address = "127.0.0.1:0"
  1430  	)
  1431  	ln, err := net.Listen(network, address)
  1432  	require.Nil(t, err)
  1433  	defer ln.Close()
  1434  	serviceName := "trpc.app.server.Service" + t.Name()
  1435  	service := server.New(
  1436  		server.WithServiceName(serviceName),
  1437  		server.WithNetwork(network),
  1438  		server.WithProtocol("http_no_protocol"),
  1439  		server.WithListener(ln),
  1440  	)
  1441  	const (
  1442  		hello = "hello "
  1443  		key   = "key"
  1444  	)
  1445  	pattern := "/" + t.Name()
  1446  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1447  		bs, err := io.ReadAll(r.Body)
  1448  		if err != nil {
  1449  			w.WriteHeader(http.StatusBadRequest)
  1450  			return
  1451  		}
  1452  		req, err := url.ParseQuery(string(bs))
  1453  		if err != nil {
  1454  			w.WriteHeader(http.StatusBadRequest)
  1455  			return
  1456  		}
  1457  		rsp := &helloworld.HelloReply{Message: hello + req.Get(key)}
  1458  		bs, err = codec.Marshal(codec.SerializationTypePB, rsp)
  1459  		if err != nil {
  1460  			w.WriteHeader(http.StatusInternalServerError)
  1461  			return
  1462  		}
  1463  		w.Header().Add("Content-Type", "application/protobuf")
  1464  		w.Write(bs)
  1465  		return
  1466  	}))
  1467  	s := &server.Server{}
  1468  	s.AddService(serviceName, service)
  1469  	go s.Serve()
  1470  	defer s.Close(nil)
  1471  	time.Sleep(100 * time.Millisecond)
  1472  
  1473  	c := thttp.NewClientProxy(
  1474  		serviceName,
  1475  		client.WithTarget("ip://"+ln.Addr().String()),
  1476  	)
  1477  	req := make(url.Values)
  1478  	req.Add(key, t.Name())
  1479  	rsp := &helloworld.HelloReply{}
  1480  	require.Nil(t,
  1481  		c.Post(context.Background(), pattern, req, rsp,
  1482  			client.WithSerializationType(codec.SerializationTypeForm),
  1483  		))
  1484  	require.Equal(t, hello+t.Name(), rsp.Message)
  1485  }
  1486  
  1487  func TestHTTPGotConnectionRemoteAddr(t *testing.T) {
  1488  	ctx := context.Background()
  1489  	for i := 0; i < 3; i++ {
  1490  		proxy := thttp.NewClientProxy(t.Name(), client.WithTarget("dns://new.qq.com/"))
  1491  		rsp := &codec.Body{}
  1492  		require.Nil(t, proxy.Get(ctx, "/", rsp,
  1493  			client.WithSerializationType(codec.SerializationTypeNoop),
  1494  			client.WithFilter(
  1495  				func(ctx context.Context, req, rsp interface{}, next filter.ClientHandleFunc) error {
  1496  					err := next(ctx, req, rsp)
  1497  					msg := codec.Message(ctx)
  1498  					addr := msg.RemoteAddr()
  1499  					require.NotNil(t, addr, "expect to get remote addr from msg in connection reuse case")
  1500  					t.Logf("addr = %+v\n", addr)
  1501  					return err
  1502  				})))
  1503  	}
  1504  }
  1505  
  1506  type h struct{}
  1507  
  1508  func (*h) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1509  	fmt.Println("recv http req")
  1510  	return nil, nil
  1511  }
  1512  
  1513  type testLog struct {
  1514  	log.Logger
  1515  	errorCh chan error
  1516  }
  1517  
  1518  func (ln *testLog) Errorf(format string, args ...interface{}) {
  1519  	ln.errorCh <- fmt.Errorf(format, args...)
  1520  }
  1521  
  1522  // mockService is a mock service.
  1523  type mockService struct {
  1524  	desc interface{}
  1525  }
  1526  
  1527  // Register registers route information.
  1528  func (m *mockService) Register(serviceDesc interface{}, serviceImpl interface{}) error {
  1529  	m.desc = serviceDesc
  1530  	return nil
  1531  }
  1532  
  1533  // Serve runs service.
  1534  func (m *mockService) Serve() error {
  1535  	return nil
  1536  }
  1537  
  1538  // Close closes service.
  1539  func (m *mockService) Close(chan struct{}) error {
  1540  	return nil
  1541  }
  1542  
  1543  type errHandler struct{}
  1544  
  1545  func (*errHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1546  	return nil, errors.New("mock error")
  1547  }
  1548  
  1549  type errHeaderHandler struct{}
  1550  
  1551  func (*errHeaderHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1552  	return nil, thttp.ErrEncodeMissingHeader
  1553  }