trpc.group/trpc-go/trpc-go@v1.0.2/http/transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http_test
    15  
    16  // https certificate file generation method:
    17  // 1. ca certificate:
    18  // openssl genrsa -out ca.key 2048
    19  // openssl req -x509 -new -nodes -key ca.key -subj "/CN=*" -days 5000 -out ca.pem
    20  // 2. server certificate:
    21  // openssl genrsa -out server.key 2048
    22  // openssl req -new -key server.key -subj "/CN=*" -out server.csr
    23  // openssl x509 -req -in server.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out server.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
    24  // 3. client certificate:
    25  // openssl genrsa -out client.key 2048
    26  // openssl req -new -key client.key -subj "/CN=*" -out client.csr
    27  // openssl x509 -req -in client.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out client.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
    28  // 4. show certificate content:
    29  // openssl x509 -text -in server.crt -noout
    30  
    31  import (
    32  	"bytes"
    33  	"context"
    34  	"crypto/tls"
    35  	"errors"
    36  	"fmt"
    37  	"io"
    38  	"mime/multipart"
    39  	"net"
    40  	"net/http"
    41  	"net/http/httptest"
    42  	"net/url"
    43  	"os"
    44  	"path"
    45  	"path/filepath"
    46  	"strconv"
    47  	"strings"
    48  	"testing"
    49  	"time"
    50  
    51  	"github.com/stretchr/testify/require"
    52  	"golang.org/x/net/http2"
    53  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    54  
    55  	"trpc.group/trpc-go/trpc-go/client"
    56  	"trpc.group/trpc-go/trpc-go/codec"
    57  	"trpc.group/trpc-go/trpc-go/errs"
    58  	"trpc.group/trpc-go/trpc-go/filter"
    59  	thttp "trpc.group/trpc-go/trpc-go/http"
    60  	"trpc.group/trpc-go/trpc-go/log"
    61  	"trpc.group/trpc-go/trpc-go/naming/registry"
    62  	"trpc.group/trpc-go/trpc-go/server"
    63  	"trpc.group/trpc-go/trpc-go/testdata/restful/helloworld"
    64  	"trpc.group/trpc-go/trpc-go/transport"
    65  )
    66  
    67  func newNoopStdHTTPServer() *http.Server { return &http.Server{} }
    68  
    69  func TestStartServer(t *testing.T) {
    70  	ctx := context.Background()
    71  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
    72  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    73  	require.Nil(t, err)
    74  	defer ln.Close()
    75  	option := transport.WithListener(ln)
    76  	handler := transport.WithHandler(transport.Handler(&h{}))
    77  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
    78  	require.NotNil(t, tp.ListenAndServe(ctx, transport.WithListenAddress("127.0.0.1:8888"), handler, transport.WithListenNetwork("tcp1")))
    79  	tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "ca1")
    80  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls))
    81  }
    82  
    83  func TestH2C(t *testing.T) {
    84  	ctx := context.Background()
    85  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    86  	require.Nil(t, err)
    87  	defer ln.Close()
    88  	handler := transport.WithHandler(transport.Handler(&h{}))
    89  	tp := thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithReusePort(), thttp.WithEnableH2C())
    90  	require.Nil(t, tp.ListenAndServe(ctx, transport.WithListener(ln), handler))
    91  }
    92  
    93  func TestDisableReusePort(t *testing.T) {
    94  	ctx := context.Background()
    95  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
    96  	ln1, err := net.Listen("tcp", "127.0.0.1:0")
    97  	require.Nil(t, err)
    98  	defer ln1.Close()
    99  	option := transport.WithListener(ln1)
   100  	handler := transport.WithHandler(transport.Handler(&h{}))
   101  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   102  
   103  	option = transport.WithListenAddress(ln1.Addr().String())
   104  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, transport.WithListenNetwork("tcp1")))
   105  
   106  	ln2, err := net.Listen("tcp", "127.0.0.1:0")
   107  	require.Nil(t, err)
   108  	defer ln2.Close()
   109  	option = transport.WithListener(ln2)
   110  	tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "")
   111  	require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls))
   112  
   113  	ln3, err := net.Listen("tcp", "127.0.0.1:0")
   114  	require.Nil(t, err)
   115  	defer ln3.Close()
   116  	option = transport.WithListener(ln3)
   117  	tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "root")
   118  	require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls))
   119  
   120  	ln4, err := net.Listen("tcp", "127.0.0.1:0")
   121  	require.Nil(t, err)
   122  	defer ln4.Close()
   123  	option = transport.WithListener(ln4)
   124  	tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.key")
   125  	require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls))
   126  }
   127  
   128  func TestStartServerWithNoHandler(t *testing.T) {
   129  	ctx := context.Background()
   130  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   131  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   132  	require.Nil(t, err)
   133  	defer ln.Close()
   134  	option := transport.WithListener(ln)
   135  	require.NotNil(t, tp.ListenAndServe(ctx, option), "http server transport handler empty")
   136  }
   137  
   138  func TestErrHandler(t *testing.T) {
   139  	ctx := context.Background()
   140  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   141  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   142  	require.Nil(t, err)
   143  	defer ln.Close()
   144  	option := transport.WithListener(ln)
   145  	h := transport.WithHandler(transport.Handler(&errHandler{}))
   146  	require.Nil(t, tp.ListenAndServe(ctx, option, h))
   147  
   148  	ct := thttp.NewClientTransport(true)
   149  	ctx, msg := codec.WithNewMessage(ctx)
   150  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   151  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   152  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   153  
   154  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   155  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   156  		transport.WithDialAddress(ln.Addr().String()),
   157  	)
   158  	require.Nil(t, rsp, "roundtrip rsp not empty")
   159  	require.Nil(t, err, "Failed to roundtrip")
   160  }
   161  
   162  func TestErrHeaderHandler(t *testing.T) {
   163  	ctx := context.Background()
   164  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   165  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   166  	require.Nil(t, err)
   167  	defer func() { require.Nil(t, ln.Close()) }()
   168  	err = tp.ListenAndServe(ctx,
   169  		transport.WithHandler(transport.Handler(&errHeaderHandler{})),
   170  		transport.WithListener(ln),
   171  	)
   172  	require.Nil(t, err)
   173  
   174  	ct := thttp.NewClientTransport(true)
   175  	ctx, msg := codec.WithNewMessage(ctx)
   176  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   177  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   178  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   179  
   180  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   181  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   182  		transport.WithDialAddress(ln.Addr().String()),
   183  	)
   184  	require.Nil(t, rsp, "roundtrip rsp not empty")
   185  	require.Nil(t, err, "Failed to roundtrip")
   186  }
   187  
   188  func TestListenAndServeFailedDueToBadCertificationFile(t *testing.T) {
   189  	ctx := context.Background()
   190  	oldLogger := log.DefaultLogger
   191  	defer func() {
   192  		log.DefaultLogger = oldLogger
   193  	}()
   194  	errorCh := make(chan error)
   195  	log.DefaultLogger = &testLog{Logger: oldLogger, errorCh: errorCh}
   196  
   197  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   198  	require.Nil(t, err)
   199  	defer func() { require.Nil(t, ln.Close()) }()
   200  	const badCertFile = "bad-file.cert"
   201  	require.Nil(
   202  		t,
   203  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   204  			ctx,
   205  			transport.WithListener(ln),
   206  			transport.WithHandler(transport.Handler(&h{})),
   207  			transport.WithServeTLS(badCertFile, "../testdata/server.key", ""),
   208  		),
   209  		"failed to new client transport",
   210  	)
   211  
   212  	select {
   213  	case <-time.After(time.Second):
   214  		t.Fatal("listen on a bad cert should log an error")
   215  	case err := <-errorCh:
   216  		require.Contains(t, err.Error(), badCertFile)
   217  	}
   218  }
   219  
   220  func TestStartTLSServerAndNoCheckServer(t *testing.T) {
   221  	ctx := context.Background()
   222  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   223  	require.Nil(t, err)
   224  	defer func() { require.Nil(t, ln.Close()) }()
   225  	// Only enables https server and do not verify client certificate.
   226  	require.Nil(
   227  		t,
   228  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   229  			ctx,
   230  			transport.WithListener(ln),
   231  			transport.WithHandler(transport.Handler(&h{})),
   232  			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   233  		),
   234  		"Failed to new client transport",
   235  	)
   236  
   237  	ct := thttp.NewClientTransport(false)
   238  	ctx, msg := codec.WithNewMessage(ctx)
   239  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   240  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   241  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   242  
   243  	rsp, err := ct.RoundTrip(
   244  		ctx,
   245  		[]byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   246  		transport.WithDialAddress(ln.Addr().String()),
   247  		// Fully trust the https server and do not verify server certificate,
   248  		// can only be used in test env.
   249  		transport.WithDialTLS("", "", "none", ""),
   250  	)
   251  	require.Nil(t, rsp, "roundtrip rsp not empty")
   252  	require.Nil(t, err, "Failed to roundtrip")
   253  }
   254  
   255  func TestServerWithListenerOption(t *testing.T) {
   256  	ln, err := net.Listen("tcp", "localhost:0")
   257  	require.Nil(t, err)
   258  	defer ln.Close()
   259  	service := server.New(
   260  		server.WithServiceName("trpc.http.server.ListenerTest"),
   261  		server.WithNetwork("tcp"),
   262  		server.WithProtocol("http"),
   263  		server.WithListener(ln),
   264  	)
   265  	thttp.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) error {
   266  		fmt.Printf("Protocol: %s\n", r.Proto)
   267  		w.Write([]byte(r.Proto))
   268  		return nil
   269  	})
   270  	thttp.RegisterDefaultService(service)
   271  	s := &server.Server{}
   272  	s.AddService("trpc.http.server.ListenerTest", service)
   273  	go func() {
   274  		require.Nil(t, s.Serve())
   275  	}()
   276  	defer s.Close(nil)
   277  	time.Sleep(100 * time.Millisecond)
   278  
   279  	resp, err := http.Get(fmt.Sprintf("http://%v/index", ln.Addr()))
   280  	require.Nil(t, err)
   281  	defer resp.Body.Close()
   282  	body, err := io.ReadAll(resp.Body)
   283  	require.Nil(t, err)
   284  	require.Equal(t, []byte("HTTP/1.1"), body)
   285  
   286  	const invalidAddr = "localhost:910439"
   287  	resp, err = http.Get(fmt.Sprintf("http://%s/index", invalidAddr))
   288  	require.NotNil(t, err)
   289  	require.Nil(t, resp)
   290  }
   291  
   292  func TestStartDisableKeepAlivesServer(t *testing.T) {
   293  	ln, err := net.Listen("tcp", "localhost:0")
   294  	require.Nil(t, err)
   295  	defer ln.Close()
   296  	s := &server.Server{}
   297  	service := server.New(
   298  		server.WithListener(ln),
   299  		server.WithServiceName("trpc.http.server.ListenerTest"),
   300  		server.WithNetwork("tcp"),
   301  		server.WithProtocol("http"),
   302  		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer)),
   303  		server.WithDisableKeepAlives(true),
   304  	)
   305  	thttp.HandleFunc("/disable-keepalives", func(w http.ResponseWriter, _ *http.Request) error {
   306  		w.Header().Set("Connection", "keep-alive")
   307  		return nil
   308  	})
   309  	thttp.RegisterDefaultService(service)
   310  	s.AddService("trpc.http.server.ListenerTest", service)
   311  	go func() {
   312  		err := s.Serve()
   313  		require.Nil(t, err)
   314  	}()
   315  	defer func() {
   316  		_ = s.Close(nil)
   317  	}()
   318  
   319  	time.Sleep(100 * time.Millisecond)
   320  
   321  	dailCount := 0
   322  	client := &http.Client{
   323  		Transport: &http.Transport{
   324  			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   325  				dailCount++
   326  				conn, err := (&net.Dialer{}).DialContext(ctx, network, addr)
   327  				return conn, err
   328  			},
   329  		},
   330  	}
   331  	num := 3
   332  	url := fmt.Sprintf("http://%s/disable-keepalives", ln.Addr())
   333  	for i := 0; i < num; i++ {
   334  		resp, err := client.Get(url)
   335  		require.Nil(t, err)
   336  		defer resp.Body.Close()
   337  		_, err = io.Copy(io.Discard, resp.Body)
   338  		require.Nil(t, err)
   339  	}
   340  	require.Equal(t, num, dailCount)
   341  }
   342  
   343  func TestStartH2cServer(t *testing.T) {
   344  	ln, err := net.Listen("tcp", "localhost:0")
   345  	require.Nil(t, err)
   346  	defer ln.Close()
   347  	s := &server.Server{}
   348  	service := server.New(
   349  		server.WithListener(ln),
   350  		server.WithServiceName("trpc.h2c.server.Greeter"),
   351  		server.WithNetwork("tcp"),
   352  		server.WithProtocol("http2"),
   353  		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithEnableH2C())),
   354  	)
   355  	thttp.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) error {
   356  		fmt.Printf("Protocol: %s\n", r.Proto)
   357  		w.Write([]byte(r.Proto))
   358  		return nil
   359  	})
   360  	thttp.HandleFunc("/main", func(w http.ResponseWriter, r *http.Request) error {
   361  		fmt.Printf("Protocol: %s\n", r.Proto)
   362  		w.Write([]byte(r.Proto))
   363  		return nil
   364  	})
   365  	thttp.RegisterDefaultService(service)
   366  	s.AddService("trpc.h2c.server.Greeter", service)
   367  
   368  	go func() {
   369  		err := s.Serve()
   370  		require.Nil(t, err)
   371  	}()
   372  
   373  	time.Sleep(100 * time.Millisecond)
   374  
   375  	// h2c client
   376  	h2cClient := http.Client{
   377  		Transport: &http2.Transport{
   378  			AllowHTTP: true,
   379  			DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   380  				return net.Dial(network, addr)
   381  			},
   382  		},
   383  	}
   384  	url := fmt.Sprintf("http://%s/", ln.Addr())
   385  	resp, err := h2cClient.Get(url + "main")
   386  	require.Nil(t, err)
   387  	defer resp.Body.Close()
   388  	body, err := io.ReadAll(resp.Body)
   389  	require.Nil(t, err)
   390  	require.Equal(t, []byte("HTTP/2.0"), body)
   391  
   392  	// http1 client
   393  	resp2, err := http.Get(url)
   394  	require.Nil(t, err)
   395  	defer resp2.Body.Close()
   396  	body, err = io.ReadAll(resp2.Body)
   397  	require.Nil(t, err)
   398  	require.Equal(t, []byte("HTTP/1.1"), body)
   399  	require.Equal(t, http.StatusOK, resp2.StatusCode)
   400  }
   401  
   402  func TestHttp2StartTLSServerAndNoCheckServer(t *testing.T) {
   403  	ctx := context.Background()
   404  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   405  	require.Nil(t, err)
   406  	defer func() { require.Nil(t, ln.Close()) }()
   407  	// Only enables https server and do not verify client certificate.
   408  	require.Nil(
   409  		t,
   410  		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
   411  			ctx,
   412  			transport.WithListener(ln),
   413  			transport.WithHandler(transport.Handler(&h{})),
   414  			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   415  		),
   416  		"Failed to new client transport",
   417  	)
   418  
   419  	ct := thttp.NewClientTransport(true)
   420  	ctx, msg := codec.WithNewMessage(ctx)
   421  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   422  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   423  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   424  
   425  	rsp, err := ct.RoundTrip(
   426  		ctx,
   427  		[]byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   428  		transport.WithDialAddress(ln.Addr().String()),
   429  		// Fully trust the https server and do not verify server certificate,
   430  		// can only be used in test env.
   431  		transport.WithDialTLS("", "", "none", ""),
   432  	)
   433  	require.Nil(t, rsp, "roundtrip rsp not empty")
   434  	require.Nil(t, err, "Failed to roundtrip")
   435  }
   436  
   437  func TestStartTLSServerAndCheckServer(t *testing.T) {
   438  	ctx := context.Background()
   439  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   440  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   441  	require.Nil(t, err)
   442  	defer func() { require.Nil(t, ln.Close()) }()
   443  	err = tp.ListenAndServe(ctx,
   444  		transport.WithHandler(transport.Handler(&h{})),
   445  		// Only enables https server and do not verify client certificate.
   446  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
   447  		transport.WithListener(ln),
   448  	)
   449  	require.Nil(t, err, "Failed to new client transport")
   450  
   451  	ct := thttp.NewClientTransport(false)
   452  	ctx, msg := codec.WithNewMessage(ctx)
   453  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   454  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   455  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   456  
   457  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   458  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   459  		transport.WithDialAddress(ln.Addr().String()),
   460  		// Uses ca public key to verify server certificate.
   461  		transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"),
   462  	)
   463  	require.Nil(t, rsp, "roundtrip rsp not empty")
   464  	require.Nil(t, err, "Failed to roundtrip")
   465  }
   466  
   467  func TestStartTLSServerAndCheckClientNoCert(t *testing.T) {
   468  	ctx := context.Background()
   469  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   470  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   471  	require.Nil(t, err)
   472  	defer func() { require.Nil(t, ln.Close()) }()
   473  	err = tp.ListenAndServe(ctx,
   474  		transport.WithHandler(transport.Handler(&h{})),
   475  		// Enables two-way authentication http server and need to verify client certificate.
   476  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   477  		transport.WithListener(ln),
   478  	)
   479  	require.Nil(t, err, "Failed to new client transport")
   480  
   481  	ct := thttp.NewClientTransport(false)
   482  	ctx, msg := codec.WithNewMessage(ctx)
   483  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   484  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   485  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   486  
   487  	_, err = ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   488  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   489  		transport.WithDialAddress(ln.Addr().String()),
   490  		// If the client's own certificate is not sent, will return TLS verification failed.
   491  		transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"),
   492  	)
   493  	require.NotNil(t, err, "Failed to roundtrip")
   494  }
   495  
   496  func TestStartTLSServerAndCheckClient(t *testing.T) {
   497  	ctx := context.Background()
   498  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   499  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   500  	require.Nil(t, err)
   501  	defer func() { require.Nil(t, ln.Close()) }()
   502  	// Enables two-way authentication http server and need to verify client certificate.
   503  	err = tp.ListenAndServe(ctx,
   504  		transport.WithHandler(transport.Handler(&h{})),
   505  		// Only enables https server and do not verify client certificate.
   506  		transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   507  		transport.WithListener(ln),
   508  	)
   509  	require.Nil(t, err, "Failed to new client transport")
   510  
   511  	ct := thttp.NewClientTransport(false)
   512  	ctx, msg := codec.WithNewMessage(ctx)
   513  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   514  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   515  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   516  
   517  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   518  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   519  		transport.WithDialAddress(ln.Addr().String()),
   520  		// Need to send the client's own certificate to server.
   521  		transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost"),
   522  	)
   523  	require.Nil(t, rsp, "roundtrip rsp not empty")
   524  	require.Nil(t, err, "Failed to roundtrip")
   525  }
   526  
   527  func TestNewClientTransport(t *testing.T) {
   528  	ct := thttp.NewClientTransport(false)
   529  	require.NotNil(t, ct, "Failed to new client transport")
   530  
   531  	ct2 := thttp.NewClientTransport(true)
   532  	require.NotNil(t, ct2, "Failed to new http2 client transport")
   533  }
   534  
   535  func TestClientRoundTrip(t *testing.T) {
   536  	ctx := context.Background()
   537  	ct := thttp.NewClientTransport(false)
   538  	ctx, msg := codec.WithNewMessage(ctx)
   539  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   540  	msg.WithClientReqHead(&thttp.ClientReqHeader{})
   541  	msg.WithClientRspHead(&thttp.ClientRspHeader{})
   542  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   543  	require.Nil(t, err)
   544  	defer ln.Close()
   545  	go http.Serve(ln, nil)
   546  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   547  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   548  		transport.WithDialAddress(ln.Addr().String()))
   549  	require.Nil(t, rsp, "roundtrip rsp not empty")
   550  	require.Nil(t, err, "Failed to roundtrip")
   551  }
   552  
   553  func TestClientRoundTripWithNoHead(t *testing.T) {
   554  	ctx := context.Background()
   555  	ct := thttp.NewClientTransport(false)
   556  	ctx, msg := codec.WithNewMessage(ctx)
   557  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   558  
   559  	rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+
   560  		"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   561  		transport.WithDialAddress("127.0.0.1:18080"))
   562  	require.Nil(t, rsp, "no head roundtrip rsp not empty")
   563  	require.NotNil(t, err, "no head roundtrip err nil")
   564  
   565  }
   566  
   567  func TestClientWithSelectorNode(t *testing.T) {
   568  	ctx := context.Background()
   569  	type testCase struct {
   570  		target   string
   571  		address  string
   572  		listener net.Listener
   573  	}
   574  	var tests []testCase
   575  	for i := 0; i < 2; i++ {
   576  		ln, err := net.Listen("tcp", "127.0.0.1:0")
   577  		require.Nil(t, err)
   578  		defer ln.Close()
   579  		addr := ln.Addr().String()
   580  		tests = append(tests, testCase{"ip://" + addr, addr, ln})
   581  	}
   582  	for _, tt := range tests {
   583  		tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   584  		option := transport.WithListener(tt.listener)
   585  		handler := transport.WithHandler(transport.Handler(&h{}))
   586  		err := tp.ListenAndServe(ctx, option, handler)
   587  		require.Nil(t, err, "Failed to new client transport")
   588  
   589  		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   590  			client.WithTarget(tt.target),
   591  			client.WithSerializationType(codec.SerializationTypeNoop),
   592  		)
   593  
   594  		reqBody := &codec.Body{
   595  			Data: []byte("{\"username\":\"xyz\"," +
   596  				"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   597  		}
   598  		rspBody := &codec.Body{}
   599  		n := &registry.Node{}
   600  		require.Nil(t,
   601  			proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, client.WithSelectorNode(n)),
   602  			"Failed to post")
   603  		require.Equal(t, tt.address, n.Address)
   604  	}
   605  }
   606  
   607  func TestClient(t *testing.T) {
   608  	ctx := context.Background()
   609  	old := codec.GetSerializer(codec.SerializationTypeJSON)
   610  	defer func() { codec.RegisterSerializer(codec.SerializationTypeJSON, old) }()
   611  	codec.RegisterSerializer(codec.SerializationTypeJSON, &codec.JSONPBSerialization{})
   612  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   613  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   614  	require.Nil(t, err)
   615  	defer ln.Close()
   616  	option := transport.WithListener(ln)
   617  	handler := transport.WithHandler(transport.Handler(&h{}))
   618  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   619  
   620  	header := &thttp.ClientReqHeader{}
   621  	header.AddHeader("ContentType", "application/json")
   622  
   623  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   624  		client.WithTarget("ip://"+ln.Addr().String()),
   625  		client.WithSerializationType(codec.SerializationTypeNoop),
   626  		client.WithReqHead(header),
   627  		client.WithMetaData("k1", []byte("v1")),
   628  	)
   629  	reqBody := &codec.Body{
   630  		Data: []byte("{\"username\":\"xyz\"," +
   631  			"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   632  	}
   633  	rspBody := &codec.Body{}
   634  
   635  	require.Nil(t, proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to post")
   636  	require.Nil(t, proxy.Put(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to put")
   637  	require.Nil(t, proxy.Delete(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to delete")
   638  	require.Nil(t, proxy.Get(ctx, "/trpc.test.helloworld.Greeter/SayHello", rspBody), "Failed to get")
   639  	require.Nil(t, proxy.Patch(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to patch")
   640  
   641  	// Test client with options.
   642  	proxy = thttp.NewClientProxy("trpc.test.helloworld.Greeter")
   643  	reqBody = &codec.Body{
   644  		Data: []byte("{\"username\":\"xyz\"," +
   645  			"\"password\":\"xyz\",\"from\":\"xyz\"}"),
   646  	}
   647  	rspBody = &codec.Body{}
   648  	require.Nil(t,
   649  		proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody,
   650  			client.WithTarget("ip://"+ln.Addr().String()),
   651  			client.WithSerializationType(codec.SerializationTypeNoop),
   652  			client.WithReqHead(header),
   653  			client.WithMetaData("k1", []byte("v1")),
   654  		), "Failed to post")
   655  
   656  	require.NotNil(t,
   657  		proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody,
   658  			client.WithTarget("ip://127.0.0.1:180"),
   659  		), "Failed to post")
   660  }
   661  
   662  func TestReqHeader(t *testing.T) {
   663  	ctx := context.Background()
   664  	// Invalid url.
   665  	header := &thttp.ClientReqHeader{}
   666  	header.AddHeader("Content-Type", "application/json")
   667  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   668  		client.WithTarget("ip://127.0.0.1:18080:www.baidu.com//"),
   669  		client.WithSerializationType(codec.SerializationTypeNoop),
   670  		client.WithReqHead(header),
   671  	)
   672  	reqBody := &codec.Body{}
   673  	rspBody := &codec.Body{}
   674  	err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
   675  	require.NotNil(t, err)
   676  }
   677  
   678  func TestReqHeaderWithContentType(t *testing.T) {
   679  	ctx := context.Background()
   680  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   681  	require.Nil(t, err)
   682  	defer ln.Close()
   683  	option := transport.WithListener(ln)
   684  	handler := transport.WithHandler(transport.Handler(&h{}))
   685  	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
   686  	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
   687  	var tests = []struct {
   688  		expected string
   689  	}{
   690  		{"application/json"},
   691  		{"application/jsonp"},
   692  		{"application/jsonp123"},
   693  		{"application/text123"},
   694  	}
   695  	for _, tt := range tests {
   696  		header := &thttp.ClientReqHeader{}
   697  		header.AddHeader("Content-Type", tt.expected)
   698  		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   699  			client.WithTarget("ip://"+ln.Addr().String()),
   700  			client.WithSerializationType(codec.SerializationTypeForm),
   701  			client.WithReqHead(header),
   702  		)
   703  		reqBody := &codec.Body{}
   704  		rspBody := &codec.Body{}
   705  		err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
   706  		require.Nil(t, err)
   707  	}
   708  }
   709  
   710  func TestHandler(t *testing.T) {
   711  	var (
   712  		handler = func(w http.ResponseWriter, r *http.Request) {
   713  			return
   714  		}
   715  		handlerFunc = func(w http.ResponseWriter, r *http.Request) error {
   716  			return nil
   717  		}
   718  		service = server.New(server.WithProtocol("http"))
   719  	)
   720  
   721  	thttp.Handle("*", http.HandlerFunc(handler))
   722  	thttp.HandleFunc("/path/do/not/equal/to/*", handlerFunc)
   723  	thttp.RegisterDefaultService(service)
   724  
   725  	for _, method := range thttp.ServiceDesc.Methods {
   726  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   727  			return make([]filter.ServerFilter, 0), nil
   728  		})
   729  
   730  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   731  			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
   732  		})
   733  
   734  		header := &thttp.Header{
   735  			Request:  &http.Request{},
   736  			Response: &httptest.ResponseRecorder{},
   737  		}
   738  		ctx := thttp.WithHeader(context.TODO(), header)
   739  		_, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
   740  			return make([]filter.ServerFilter, 0), nil
   741  		})
   742  		require.Nil(t, err)
   743  	}
   744  }
   745  
   746  func TestMux(t *testing.T) {
   747  	var handler = func(w http.ResponseWriter, r *http.Request) {
   748  		return
   749  	}
   750  	mux := http.NewServeMux()
   751  	mux.HandleFunc("/", handler)
   752  
   753  	var service = &mockService{}
   754  	thttp.RegisterServiceMux(service, mux)
   755  	desc, _ := service.desc.(*server.ServiceDesc)
   756  	for _, method := range desc.Methods {
   757  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   758  			return make([]filter.ServerFilter, 0), nil
   759  		})
   760  
   761  		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
   762  			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
   763  		})
   764  
   765  		req, _ := http.NewRequest("GET", "/", nil)
   766  		header := &thttp.Header{
   767  			Request:  req,
   768  			Response: &httptest.ResponseRecorder{},
   769  		}
   770  		ctx := thttp.WithHeader(context.TODO(), header)
   771  		_, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
   772  			return make([]filter.ServerFilter, 0), nil
   773  		})
   774  		require.Nil(t, err)
   775  	}
   776  }
   777  
   778  // TestCheckRedirect tests set CheckRedirect
   779  func TestCheckRedirect(t *testing.T) {
   780  	ctx := context.Background()
   781  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   782  	require.Nil(t, err)
   783  	defer ln.Close()
   784  	// server
   785  	go func() {
   786  		// real backend
   787  		h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   788  			w.Write([]byte("real"))
   789  		})
   790  		http.Handle("/real", h)
   791  
   792  		// redirect a
   793  		rha := http.RedirectHandler("/b", http.StatusMovedPermanently)
   794  		http.Handle("/a", rha)
   795  
   796  		// redirect b
   797  		rhb := http.RedirectHandler("/real", http.StatusMovedPermanently)
   798  		http.Handle("/b", rhb)
   799  
   800  		http.Serve(ln, nil)
   801  	}()
   802  	time.Sleep(200 * time.Millisecond)
   803  
   804  	// sets CheckRedirect
   805  	checkRedirect := func(_ *http.Request, via []*http.Request) error {
   806  		if len(via) > 1 {
   807  			return errors.New("more than once")
   808  		}
   809  		return nil
   810  	}
   811  	thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = checkRedirect
   812  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   813  		client.WithTarget("ip://"+ln.Addr().String()),
   814  		client.WithSerializationType(codec.SerializationTypeNoop),
   815  	)
   816  	reqBody := &codec.Body{}
   817  	rspBody := &codec.Body{}
   818  	// only redirect once form /b
   819  	require.Nil(t, proxy.Post(ctx, "/b", reqBody, rspBody))
   820  	// redirect twice from /a
   821  	err = proxy.Post(ctx, "/a", reqBody, rspBody)
   822  	require.NotNil(t, err)
   823  	require.Equal(t, true, strings.Contains(err.Error(), "more than once"))
   824  }
   825  
   826  func TestTransportError(t *testing.T) {
   827  	http.HandleFunc("/timeout", func(http.ResponseWriter, *http.Request) {
   828  		time.Sleep(time.Second)
   829  	})
   830  	http.HandleFunc("/cancel", func(http.ResponseWriter, *http.Request) {})
   831  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   832  	require.Nil(t, err)
   833  	defer ln.Close()
   834  	go func() { http.Serve(ln, nil) }()
   835  	time.Sleep(200 * time.Millisecond)
   836  
   837  	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
   838  		client.WithTarget("ip://"+ln.Addr().String()),
   839  		client.WithSerializationType(codec.SerializationTypeNoop),
   840  		client.WithTimeout(time.Millisecond*500),
   841  	)
   842  	rspBody := &codec.Body{}
   843  
   844  	err = proxy.Get(context.Background(), "/timeout", rspBody)
   845  	terr, ok := err.(*errs.Error)
   846  	require.True(t, ok)
   847  	require.EqualValues(t, terr.Code, int32(errs.RetClientTimeout))
   848  
   849  	ctx, cancel := context.WithCancel(context.Background())
   850  	cancel()
   851  	err = proxy.Get(ctx, "/cancel", rspBody)
   852  	terr, ok = err.(*errs.Error)
   853  	require.True(t, ok)
   854  	require.EqualValues(t, terr.Code, int32(errs.RetClientCanceled))
   855  }
   856  
   857  func TestClientRoundDyeing(t *testing.T) {
   858  	ctx := context.Background()
   859  	ct := thttp.NewClientTransport(false)
   860  	ctx, msg := codec.WithNewMessage(ctx)
   861  	msg.WithDyeing(true)
   862  	dyeingKey := "dyeingkey"
   863  	msg.WithDyeingKey(dyeingKey)
   864  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   865  	req := &http.Request{
   866  		Header: http.Header{},
   867  	}
   868  	reqHeader := &thttp.ClientReqHeader{
   869  		Request: req,
   870  	}
   871  	msg.WithClientReqHead(reqHeader)
   872  	rspHeader := &thttp.ClientRspHeader{}
   873  	msg.WithClientRspHead(rspHeader)
   874  	meta := codec.MetaData{
   875  		thttp.TrpcDyeingKey: []byte(dyeingKey),
   876  	}
   877  	msg.WithClientMetaData(meta)
   878  	_, err := ct.RoundTrip(ctx, nil)
   879  	require.NotNil(t, err)
   880  	require.Equal(t, req.Header.Get(thttp.TrpcMessageType),
   881  		strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)))
   882  }
   883  
   884  func TestClientRoundEnvTransfer(t *testing.T) {
   885  	ctx := context.Background()
   886  	ct := thttp.NewClientTransport(false)
   887  	ctx, msg := codec.WithNewMessage(ctx)
   888  	msg.WithEnvTransfer("feat,master")
   889  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   890  	req := &http.Request{
   891  		Header: http.Header{},
   892  	}
   893  	reqHeader := &thttp.ClientReqHeader{
   894  		Request: req,
   895  	}
   896  	msg.WithClientReqHead(reqHeader)
   897  	rspHeader := &thttp.ClientRspHeader{}
   898  	msg.WithClientRspHead(rspHeader)
   899  	_, err := ct.RoundTrip(ctx, nil)
   900  	require.NotNil(t, err)
   901  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), thttp.TrpcEnv)
   902  }
   903  
   904  func TestDisableBase64EncodeTransInfo(t *testing.T) {
   905  	ctx := context.Background()
   906  	ct := thttp.NewClientTransport(false, transport.WithDisableEncodeTransInfoBase64())
   907  	ctx, msg := codec.WithNewMessage(ctx)
   908  	var (
   909  		envTrans  = "feat,master"
   910  		metaVal   = "value"
   911  		dyeingKey = "dyeingkey"
   912  	)
   913  	msg.WithEnvTransfer(envTrans)
   914  	msg.WithClientMetaData(codec.MetaData{"key": []byte(metaVal)})
   915  	msg.WithDyeing(true)
   916  	msg.WithDyeingKey(dyeingKey)
   917  	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
   918  	req := &http.Request{
   919  		Header: http.Header{},
   920  	}
   921  	reqHeader := &thttp.ClientReqHeader{
   922  		Request: req,
   923  	}
   924  	msg.WithClientReqHead(reqHeader)
   925  	rspHeader := &thttp.ClientRspHeader{}
   926  	msg.WithClientRspHead(rspHeader)
   927  	_, err := ct.RoundTrip(ctx, nil)
   928  	require.NotNil(t, err)
   929  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), envTrans)
   930  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), metaVal)
   931  	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), dyeingKey)
   932  }
   933  
   934  func TestDisableServiceRouterTransInfo(t *testing.T) {
   935  	ctx := context.Background()
   936  	a := require.New(t)
   937  	ct := thttp.NewClientTransport(false)
   938  	ctx, msg := codec.WithNewMessage(ctx)
   939  	msg.WithClientMetaData(codec.MetaData{thttp.TrpcEnv: []byte("orienv")}) // this emulate decode trpc protocol client request
   940  	msg.WithEnvTransfer("feat,master")
   941  	req := &http.Request{
   942  		Header: http.Header{},
   943  	}
   944  	reqHeader := &thttp.ClientReqHeader{
   945  		Request: req,
   946  	}
   947  	msg.WithClientReqHead(reqHeader)
   948  	rspHeader := &thttp.ClientRspHeader{}
   949  	msg.WithClientRspHead(rspHeader)
   950  	_, err := ct.RoundTrip(ctx, nil)
   951  	a.NotNil(err)
   952  	info, err := thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
   953  	a.NoError(err)
   954  	a.Equal(string(info[thttp.TrpcEnv]), "feat,master")
   955  
   956  	msg.WithEnvTransfer("") // DisableServiceRouter would clear EnvTransfer
   957  	_, err = ct.RoundTrip(ctx, nil)
   958  	a.NotNil(err)
   959  	info, err = thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
   960  	a.NoError(err)
   961  	a.Equal(string(info[thttp.TrpcEnv]), "")
   962  }
   963  
   964  func TestHTTPSUseClientVerify(t *testing.T) {
   965  	const (
   966  		network = "tcp"
   967  		address = "127.0.0.1:0"
   968  	)
   969  	ln, err := net.Listen(network, address)
   970  	require.Nil(t, err)
   971  	defer ln.Close()
   972  	serviceName := "trpc.app.server.Service" + t.Name()
   973  	service := server.New(
   974  		server.WithServiceName(serviceName),
   975  		server.WithNetwork(network),
   976  		server.WithProtocol("http_no_protocol"),
   977  		server.WithListener(ln),
   978  		server.WithTLS(
   979  			"../testdata/server.crt",
   980  			"../testdata/server.key",
   981  			"../testdata/ca.pem",
   982  		),
   983  	)
   984  	pattern := "/" + t.Name()
   985  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
   986  		w.Write([]byte(t.Name()))
   987  	}))
   988  	s := &server.Server{}
   989  	s.AddService(serviceName, service)
   990  	go s.Serve()
   991  	defer s.Close(nil)
   992  	time.Sleep(100 * time.Millisecond)
   993  
   994  	c := thttp.NewClientProxy(
   995  		serviceName,
   996  		client.WithTarget("ip://"+ln.Addr().String()),
   997  	)
   998  	req := &codec.Body{}
   999  	rsp := &codec.Body{}
  1000  	require.Nil(t,
  1001  		c.Post(context.Background(), pattern, req, rsp,
  1002  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1003  			client.WithSerializationType(codec.SerializationTypeNoop),
  1004  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1005  			client.WithTLS(
  1006  				"../testdata/client.crt",
  1007  				"../testdata/client.key",
  1008  				"../testdata/ca.pem",
  1009  				"localhost",
  1010  			),
  1011  		))
  1012  	require.Equal(t, []byte(t.Name()), rsp.Data)
  1013  }
  1014  
  1015  func TestHTTPSSkipClientVerify(t *testing.T) {
  1016  	const (
  1017  		network = "tcp"
  1018  		address = "127.0.0.1:0"
  1019  	)
  1020  	ln, err := net.Listen(network, address)
  1021  	require.Nil(t, err)
  1022  	defer ln.Close()
  1023  	serviceName := "trpc.app.server.Service" + t.Name()
  1024  	service := server.New(
  1025  		server.WithServiceName(serviceName),
  1026  		server.WithNetwork(network),
  1027  		server.WithProtocol("http_no_protocol"),
  1028  		server.WithListener(ln),
  1029  		server.WithTLS(
  1030  			"../testdata/server.crt",
  1031  			"../testdata/server.key",
  1032  			"",
  1033  		),
  1034  	)
  1035  	pattern := "/" + t.Name()
  1036  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  1037  		w.Write([]byte(t.Name()))
  1038  	}))
  1039  	s := &server.Server{}
  1040  	s.AddService(serviceName, service)
  1041  	go s.Serve()
  1042  	defer s.Close(nil)
  1043  	time.Sleep(100 * time.Millisecond)
  1044  
  1045  	c := thttp.NewClientProxy(
  1046  		serviceName,
  1047  		client.WithTarget("ip://"+ln.Addr().String()),
  1048  	)
  1049  	req := &codec.Body{}
  1050  	rsp := &codec.Body{}
  1051  	require.Nil(t,
  1052  		c.Post(context.Background(), pattern, req, rsp,
  1053  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1054  			client.WithSerializationType(codec.SerializationTypeNoop),
  1055  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1056  			client.WithTLS(
  1057  				"", "", "none", "",
  1058  			),
  1059  		))
  1060  	require.Equal(t, []byte(t.Name()), rsp.Data)
  1061  }
  1062  
  1063  func TestListenAndServeHTTPHead(t *testing.T) {
  1064  	ctx := context.Background()
  1065  	const (
  1066  		network = "tcp"
  1067  		address = "127.0.0.1:0"
  1068  	)
  1069  	ln, err := net.Listen(network, address)
  1070  	require.Nil(t, err)
  1071  	defer ln.Close()
  1072  	st := thttp.NewServerTransport(newNoopStdHTTPServer)
  1073  	require.Nil(t, st.ListenAndServe(ctx,
  1074  		transport.WithHandler(&httpHeadHandler{
  1075  			func(ctx context.Context, _ []byte) (rsp []byte, err error) {
  1076  				head := thttp.Head(ctx)
  1077  				head.Response.WriteHeader(http.StatusOK)
  1078  				head.Response.Write([]byte(fmt.Sprintf("%+v", thttp.Head(head.Request.Context()) != nil)))
  1079  				return
  1080  			}}),
  1081  		transport.WithListener(ln),
  1082  	))
  1083  	time.Sleep(200 * time.Millisecond)
  1084  	rsp, err := http.Get("http://" + ln.Addr().String())
  1085  	require.Nil(t, err)
  1086  	bs, err := io.ReadAll(rsp.Body)
  1087  	require.Nil(t, err)
  1088  	require.Equal(t, fmt.Sprintf("%+v", true), string(bs))
  1089  }
  1090  
  1091  type httpHeadHandler struct {
  1092  	handle func(ctx context.Context, req []byte) (rsp []byte, err error)
  1093  }
  1094  
  1095  func (h *httpHeadHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) {
  1096  	return h.handle(ctx, req)
  1097  }
  1098  
  1099  func TestHTTPStreamFileUpload(t *testing.T) {
  1100  	// Start server.
  1101  	const (
  1102  		network = "tcp"
  1103  		address = "127.0.0.1:0"
  1104  	)
  1105  	ln, err := net.Listen(network, address)
  1106  	require.Nil(t, err)
  1107  	defer ln.Close()
  1108  	go http.Serve(ln, &fileHandler{})
  1109  	// Start client.
  1110  	c := thttp.NewClientProxy(
  1111  		"trpc.app.server.Service_http",
  1112  		client.WithTarget("ip://"+ln.Addr().String()),
  1113  	)
  1114  	// Open and read file.
  1115  	fileDir, err := os.Getwd()
  1116  	require.Nil(t, err)
  1117  	fileName := "README.md"
  1118  	filePath := path.Join(fileDir, fileName)
  1119  	file, err := os.Open(filePath)
  1120  	require.Nil(t, err)
  1121  	defer file.Close()
  1122  	// Construct multipart form file.
  1123  	body := &bytes.Buffer{}
  1124  	writer := multipart.NewWriter(body)
  1125  	part, err := writer.CreateFormFile("field_name", filepath.Base(file.Name()))
  1126  	require.Nil(t, err)
  1127  	io.Copy(part, file)
  1128  	require.Nil(t, writer.Close())
  1129  	// Add multipart form data header.
  1130  	header := http.Header{}
  1131  	header.Add("Content-Type", writer.FormDataContentType())
  1132  	reqHeader := &thttp.ClientReqHeader{
  1133  		Header:  header,
  1134  		ReqBody: body, // Stream send.
  1135  	}
  1136  	req := &codec.Body{}
  1137  	rsp := &codec.Body{}
  1138  	// Upload file.
  1139  	require.Nil(t,
  1140  		c.Post(context.Background(), "/", req, rsp,
  1141  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1142  			client.WithSerializationType(codec.SerializationTypeNoop),
  1143  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1144  			client.WithReqHead(reqHeader),
  1145  		))
  1146  	require.Equal(t, []byte(fileName), rsp.Data)
  1147  }
  1148  
  1149  type fileHandler struct{}
  1150  
  1151  func (*fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1152  	_, h, err := r.FormFile("field_name")
  1153  	if err != nil {
  1154  		w.WriteHeader(http.StatusBadRequest)
  1155  		return
  1156  	}
  1157  	w.WriteHeader(http.StatusOK)
  1158  	// Write back file name.
  1159  	w.Write([]byte(h.Filename))
  1160  	return
  1161  }
  1162  
  1163  func TestHTTPStreamRead(t *testing.T) {
  1164  	// Start server.
  1165  	const (
  1166  		network = "tcp"
  1167  		address = "127.0.0.1:0"
  1168  	)
  1169  	ln, err := net.Listen(network, address)
  1170  	require.Nil(t, err)
  1171  	defer ln.Close()
  1172  	go http.Serve(ln, &fileServer{})
  1173  
  1174  	// Start client.
  1175  	c := thttp.NewClientProxy(
  1176  		"trpc.app.server.Service_http",
  1177  		client.WithTarget("ip://"+ln.Addr().String()),
  1178  	)
  1179  
  1180  	// Enable manual body reading in order to
  1181  	// disable the framework's automatic body reading capability,
  1182  	// so that users can manually do their own client-side streaming reads.
  1183  	rspHead := &thttp.ClientRspHeader{
  1184  		ManualReadBody: true,
  1185  	}
  1186  	req := &codec.Body{}
  1187  	rsp := &codec.Body{}
  1188  	require.Nil(t,
  1189  		c.Post(context.Background(), "/", req, rsp,
  1190  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1191  			client.WithSerializationType(codec.SerializationTypeNoop),
  1192  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1193  			client.WithRspHead(rspHead),
  1194  		))
  1195  	require.Nil(t, rsp.Data)
  1196  	body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body.
  1197  	defer body.Close()            // Do remember to close the body.
  1198  	bs, err := io.ReadAll(body)
  1199  	require.Nil(t, err)
  1200  	require.NotNil(t, bs)
  1201  }
  1202  
  1203  func TestHTTPSendReceiveChunk(t *testing.T) {
  1204  	// HTTP chunked example:
  1205  	//   1. Client sends chunks: Add "chunked" transfer encoding header, and use io.Reader as body.
  1206  	//   2. Client reads chunks: The Go/net/http automatically handles the chunked reading.
  1207  	//                           Users can simply read resp.Body in a loop until io.EOF.
  1208  	//   3. Server reads chunks: Similar to client reads chunks.
  1209  	//   4. Server sends chunks: Assert http.ResponseWriter as http.Flusher, call flusher.Flush() after
  1210  	//         writing a part of data, it will automatically trigger "chunked" encoding to send a chunk.
  1211  
  1212  	// Start server.
  1213  	const (
  1214  		network = "tcp"
  1215  		address = "127.0.0.1:0"
  1216  	)
  1217  	ln, err := net.Listen(network, address)
  1218  	require.Nil(t, err)
  1219  	defer ln.Close()
  1220  	go http.Serve(ln, &chunkedServer{})
  1221  
  1222  	// Start client.
  1223  	c := thttp.NewClientProxy(
  1224  		"trpc.app.server.Service_http",
  1225  		client.WithTarget("ip://"+ln.Addr().String()),
  1226  	)
  1227  
  1228  	// Open and read file.
  1229  	fileDir, err := os.Getwd()
  1230  	require.Nil(t, err)
  1231  	fileName := "README.md"
  1232  	filePath := path.Join(fileDir, fileName)
  1233  	file, err := os.Open(filePath)
  1234  	require.Nil(t, err)
  1235  	defer file.Close()
  1236  
  1237  	// 1. Client sends chunks.
  1238  
  1239  	// Add request headers.
  1240  	header := http.Header{}
  1241  	header.Add("Content-Type", "text/plain")
  1242  	// Add chunked transfer encoding header.
  1243  	header.Add("Transfer-Encoding", "chunked")
  1244  	reqHead := &thttp.ClientReqHeader{
  1245  		Header:  header,
  1246  		ReqBody: file, // Stream send (for chunks).
  1247  	}
  1248  
  1249  	// Enable manual body reading in order to
  1250  	// disable the framework's automatic body reading capability,
  1251  	// so that users can manually do their own client-side streaming reads.
  1252  	rspHead := &thttp.ClientRspHeader{
  1253  		ManualReadBody: true,
  1254  	}
  1255  	req := &codec.Body{}
  1256  	rsp := &codec.Body{}
  1257  	require.Nil(t,
  1258  		c.Post(context.Background(), "/", req, rsp,
  1259  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1260  			client.WithSerializationType(codec.SerializationTypeNoop),
  1261  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1262  			client.WithReqHead(reqHead),
  1263  			client.WithRspHead(rspHead),
  1264  		))
  1265  	require.Nil(t, rsp.Data)
  1266  
  1267  	// 2. Client reads chunks.
  1268  
  1269  	// Do stream reads directly from rspHead.Response.Body.
  1270  	body := rspHead.Response.Body
  1271  	defer body.Close() // Do remember to close the body.
  1272  	buf := make([]byte, 4096)
  1273  	var idx int
  1274  	for {
  1275  		n, err := body.Read(buf)
  1276  		if err == io.EOF {
  1277  			t.Logf("reached io.EOF\n")
  1278  			break
  1279  		}
  1280  		t.Logf("read chunk %d of length %d: %q\n", idx, n, buf[:n])
  1281  		idx++
  1282  	}
  1283  }
  1284  
  1285  type chunkedServer struct{}
  1286  
  1287  func (*chunkedServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1288  	// 3. Server reads chunks.
  1289  
  1290  	// io.ReadAll will read until io.EOF.
  1291  	// Go/net/http will automatically handle chunked body reads.
  1292  	bs, err := io.ReadAll(r.Body)
  1293  	if err != nil {
  1294  		w.WriteHeader(http.StatusInternalServerError)
  1295  		w.Write([]byte(fmt.Sprintf("io.ReadAll err: %+v", err)))
  1296  		return
  1297  	}
  1298  
  1299  	// 4. Server sends chunks.
  1300  
  1301  	// Send HTTP chunks using http.Flusher.
  1302  	// Reference: https://stackoverflow.com/questions/26769626/send-a-chunked-http-response-from-a-go-server.
  1303  	// The "Transfer-Encoding" header will be handled by the writer implicitly, so no need to set it.
  1304  	flusher, ok := w.(http.Flusher)
  1305  	if !ok {
  1306  		w.WriteHeader(http.StatusInternalServerError)
  1307  		w.Write([]byte("expected http.ResponseWriter to be an http.Flusher"))
  1308  		return
  1309  	}
  1310  	chunks := 10
  1311  	chunkSize := (len(bs) + chunks - 1) / chunks
  1312  	for i := 0; i < chunks; i++ {
  1313  		start := i * chunkSize
  1314  		end := (i + 1) * chunkSize
  1315  		if end > len(bs) {
  1316  			end = len(bs)
  1317  		}
  1318  		w.Write(bs[start:end])
  1319  		flusher.Flush() // Trigger "chunked" encoding and send a chunk.
  1320  		time.Sleep(500 * time.Millisecond)
  1321  	}
  1322  	return
  1323  }
  1324  
  1325  type fileServer struct{}
  1326  
  1327  func (*fileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  1328  	http.ServeFile(w, r, "./README.md")
  1329  	return
  1330  }
  1331  
  1332  func TestHTTPSendAndReceiveSSE(t *testing.T) {
  1333  	const (
  1334  		network = "tcp"
  1335  		address = "127.0.0.1:0"
  1336  	)
  1337  	ln, err := net.Listen(network, address)
  1338  	require.Nil(t, err)
  1339  	defer ln.Close()
  1340  	serviceName := "trpc.app.server.Service" + t.Name()
  1341  	service := server.New(
  1342  		server.WithServiceName(serviceName),
  1343  		server.WithNetwork(network),
  1344  		server.WithProtocol("http_no_protocol"),
  1345  		server.WithListener(ln),
  1346  	)
  1347  	pattern := "/" + t.Name()
  1348  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1349  		flusher, ok := w.(http.Flusher)
  1350  		if !ok {
  1351  			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
  1352  			return
  1353  		}
  1354  		w.Header().Set("Content-Type", "text/event-stream")
  1355  		w.Header().Set("Cache-Control", "no-cache")
  1356  		w.Header().Set("Connection", "keep-alive")
  1357  		w.Header().Set("Access-Control-Allow-Origin", "*")
  1358  		bs, err := io.ReadAll(r.Body)
  1359  		if err != nil {
  1360  			http.Error(w, err.Error(), http.StatusBadRequest)
  1361  			return
  1362  		}
  1363  		msg := string(bs)
  1364  		for i := 0; i < 3; i++ {
  1365  			msgBytes := []byte("event: message\n\ndata: " + msg + strconv.Itoa(i) + "\n\n")
  1366  			_, err = w.Write(msgBytes)
  1367  			if err != nil {
  1368  				http.Error(w, err.Error(), http.StatusInternalServerError)
  1369  				return
  1370  			}
  1371  			flusher.Flush()
  1372  			time.Sleep(500 * time.Millisecond)
  1373  		}
  1374  		return
  1375  	}))
  1376  	s := &server.Server{}
  1377  	s.AddService(serviceName, service)
  1378  	go s.Serve()
  1379  	defer s.Close(nil)
  1380  	time.Sleep(100 * time.Millisecond)
  1381  
  1382  	c := thttp.NewClientProxy(
  1383  		serviceName,
  1384  		client.WithTarget("ip://"+ln.Addr().String()),
  1385  	)
  1386  	header := http.Header{}
  1387  	header.Set("Cache-Control", "no-cache")
  1388  	header.Set("Accept", "text/event-stream")
  1389  	header.Set("Connection", "keep-alive")
  1390  	reqHeader := &thttp.ClientReqHeader{
  1391  		Header: header,
  1392  	}
  1393  	// Enable manual body reading in order to
  1394  	// disable the framework's automatic body reading capability,
  1395  	// so that users can manually do their own client-side streaming reads.
  1396  	rspHead := &thttp.ClientRspHeader{
  1397  		ManualReadBody: true,
  1398  	}
  1399  	req := &codec.Body{Data: []byte("hello")}
  1400  	rsp := &codec.Body{}
  1401  	require.Nil(t,
  1402  		c.Post(context.Background(), pattern, req, rsp,
  1403  			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
  1404  			client.WithSerializationType(codec.SerializationTypeNoop),
  1405  			client.WithCurrentCompressType(codec.CompressTypeNoop),
  1406  			client.WithReqHead(reqHeader),
  1407  			client.WithRspHead(rspHead),
  1408  		))
  1409  	body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body.
  1410  	defer body.Close()            // Do remember to close the body.
  1411  	data := make([]byte, 1024)
  1412  	for {
  1413  		n, err := body.Read(data)
  1414  		if err == io.EOF {
  1415  			break
  1416  		}
  1417  		require.Nil(t, err)
  1418  		t.Logf("Received message: \n%s\n", string(data[:n]))
  1419  	}
  1420  }
  1421  
  1422  func TestHTTPClientReqRspDifferentContentType(t *testing.T) {
  1423  	const (
  1424  		network = "tcp"
  1425  		address = "127.0.0.1:0"
  1426  	)
  1427  	ln, err := net.Listen(network, address)
  1428  	require.Nil(t, err)
  1429  	defer ln.Close()
  1430  	serviceName := "trpc.app.server.Service" + t.Name()
  1431  	service := server.New(
  1432  		server.WithServiceName(serviceName),
  1433  		server.WithNetwork(network),
  1434  		server.WithProtocol("http_no_protocol"),
  1435  		server.WithListener(ln),
  1436  	)
  1437  	const (
  1438  		hello = "hello "
  1439  		key   = "key"
  1440  	)
  1441  	pattern := "/" + t.Name()
  1442  	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1443  		bs, err := io.ReadAll(r.Body)
  1444  		if err != nil {
  1445  			w.WriteHeader(http.StatusBadRequest)
  1446  			return
  1447  		}
  1448  		req, err := url.ParseQuery(string(bs))
  1449  		if err != nil {
  1450  			w.WriteHeader(http.StatusBadRequest)
  1451  			return
  1452  		}
  1453  		rsp := &helloworld.HelloReply{Message: hello + req.Get(key)}
  1454  		bs, err = codec.Marshal(codec.SerializationTypePB, rsp)
  1455  		if err != nil {
  1456  			w.WriteHeader(http.StatusInternalServerError)
  1457  			return
  1458  		}
  1459  		w.Header().Add("Content-Type", "application/protobuf")
  1460  		w.Write(bs)
  1461  		return
  1462  	}))
  1463  	s := &server.Server{}
  1464  	s.AddService(serviceName, service)
  1465  	go s.Serve()
  1466  	defer s.Close(nil)
  1467  	time.Sleep(100 * time.Millisecond)
  1468  
  1469  	c := thttp.NewClientProxy(
  1470  		serviceName,
  1471  		client.WithTarget("ip://"+ln.Addr().String()),
  1472  	)
  1473  	req := make(url.Values)
  1474  	req.Add(key, t.Name())
  1475  	rsp := &helloworld.HelloReply{}
  1476  	require.Nil(t,
  1477  		c.Post(context.Background(), pattern, req, rsp,
  1478  			client.WithSerializationType(codec.SerializationTypeForm),
  1479  		))
  1480  	require.Equal(t, hello+t.Name(), rsp.Message)
  1481  }
  1482  
  1483  type h struct{}
  1484  
  1485  func (*h) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1486  	fmt.Println("recv http req")
  1487  	return nil, nil
  1488  }
  1489  
  1490  type testLog struct {
  1491  	log.Logger
  1492  	errorCh chan error
  1493  }
  1494  
  1495  func (ln *testLog) Errorf(format string, args ...interface{}) {
  1496  	ln.errorCh <- fmt.Errorf(format, args...)
  1497  }
  1498  
  1499  // mockService is a mock service.
  1500  type mockService struct {
  1501  	desc interface{}
  1502  }
  1503  
  1504  // Register registers route information.
  1505  func (m *mockService) Register(serviceDesc interface{}, serviceImpl interface{}) error {
  1506  	m.desc = serviceDesc
  1507  	return nil
  1508  }
  1509  
  1510  // Serve runs service.
  1511  func (m *mockService) Serve() error {
  1512  	return nil
  1513  }
  1514  
  1515  // Close closes service.
  1516  func (m *mockService) Close(chan struct{}) error {
  1517  	return nil
  1518  }
  1519  
  1520  type errHandler struct{}
  1521  
  1522  func (*errHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1523  	return nil, errors.New("mock error")
  1524  }
  1525  
  1526  type errHeaderHandler struct{}
  1527  
  1528  func (*errHeaderHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
  1529  	return nil, thttp.ErrEncodeMissingHeader
  1530  }