trpc.group/trpc-go/trpc-go@v1.0.3/http/restful_server_transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http_test
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"crypto/tls"
    20  	"crypto/x509"
    21  	"encoding/base64"
    22  	"encoding/json"
    23  	"errors"
    24  	"io"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"strings"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/require"
    33  	"github.com/valyala/fasthttp"
    34  
    35  	trpc "trpc.group/trpc-go/trpc-go"
    36  	"trpc.group/trpc-go/trpc-go/codec"
    37  	thttp "trpc.group/trpc-go/trpc-go/http"
    38  	itls "trpc.group/trpc-go/trpc-go/internal/tls"
    39  	"trpc.group/trpc-go/trpc-go/restful"
    40  	"trpc.group/trpc-go/trpc-go/server"
    41  	"trpc.group/trpc-go/trpc-go/testdata/restful/helloworld"
    42  	"trpc.group/trpc-go/trpc-go/transport"
    43  )
    44  
    45  func TestCompatibility(t *testing.T) {
    46  	// Registers service.
    47  	serviceName := "trpc.test.server.Greeter" + t.Name()
    48  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    49  	require.Nil(t, err)
    50  	defer ln.Close()
    51  	url := "http://" + ln.Addr().String()
    52  	s := &server.Server{}
    53  	service := server.New(
    54  		server.WithListener(ln),
    55  		server.WithServiceName(serviceName),
    56  		server.WithProtocol("restful"),
    57  	)
    58  	s.AddService(serviceName, service)
    59  	helloworld.RegisterGreeterService(s, &greeterServerImpl{})
    60  
    61  	go func() { require.Nil(t, s.Serve()) }()
    62  	defer s.Close(nil)
    63  
    64  	time.Sleep(100 * time.Millisecond)
    65  
    66  	// Removes compatibility setting.
    67  	restful.SetCtxForCompatibility(
    68  		func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context {
    69  			return ctx
    70  		},
    71  	)
    72  
    73  	// Sends restful request.
    74  	req1, err := http.NewRequest("POST", url+"/v1/foobar",
    75  		bytes.NewBuffer([]byte(`{"name": "xyz"}`)))
    76  	require.Nil(t, err)
    77  	cli := http.Client{}
    78  	resp1, err := cli.Do(req1)
    79  	require.Nil(t, err)
    80  	defer resp1.Body.Close()
    81  	require.Equal(t, resp1.StatusCode, http.StatusInternalServerError)
    82  
    83  	// Adds compatibility setting.
    84  	restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter,
    85  		r *http.Request) context.Context {
    86  		return thttp.WithHeader(ctx, &thttp.Header{Response: w, Request: r})
    87  	})
    88  
    89  	// Sends restful request.
    90  	req2, err := http.NewRequest("POST", url+"/v1/foobar",
    91  		bytes.NewBuffer([]byte(`{"name": "xyz"}`)))
    92  	require.Nil(t, err)
    93  	resp2, err := cli.Do(req2)
    94  	require.Nil(t, err)
    95  	defer resp2.Body.Close()
    96  	require.Equal(t, resp2.StatusCode, http.StatusOK)
    97  }
    98  
    99  func TestEnableTLS(t *testing.T) {
   100  	// Registers service.
   101  	s := &server.Server{}
   102  	conf, err := itls.GetServerConfig("../testdata/ca.pem", "../testdata/server.crt", "../testdata/server.key")
   103  	require.Nil(t, err, "%+v", err)
   104  	ln, err := tls.Listen("tcp", "127.0.0.1:0", conf)
   105  	require.Nil(t, err)
   106  	defer ln.Close()
   107  	addr := strings.Split(ln.Addr().String(), ":")
   108  	require.Equal(t, 2, len(addr))
   109  	port := addr[1]
   110  	// Must use localhost to replace 127.0.0.1, or else the following error will occur:
   111  	// tls: failed to verify certificate: x509: cannot validate certificate for 127.0.0.1 because it doesn't contain any IP SANs.
   112  	url := "https://localhost:" + port
   113  	service := server.New(
   114  		server.WithListener(ln),
   115  		server.WithServiceName("trpc.test.helloworld.Greeter"),
   116  		server.WithProtocol("restful"),
   117  	)
   118  	s.AddService("trpc.test.helloworld.Greeter", service)
   119  	helloworld.RegisterGreeterService(s, &greeterServerImpl{})
   120  
   121  	go func() { require.Nil(t, s.Serve()) }()
   122  	defer s.Close(nil)
   123  
   124  	time.Sleep(100 * time.Millisecond)
   125  
   126  	// Sends https request.
   127  	pool := x509.NewCertPool()
   128  	ca, err := os.ReadFile("../testdata/ca.pem")
   129  	require.Nil(t, err)
   130  	pool.AppendCertsFromPEM(ca)
   131  	cert, err := tls.LoadX509KeyPair("../testdata/client.crt", "../testdata/client.key")
   132  	require.Nil(t, err)
   133  
   134  	cli := &http.Client{
   135  		Transport: &http.Transport{
   136  			TLSClientConfig: &tls.Config{
   137  				RootCAs:      pool,
   138  				Certificates: []tls.Certificate{cert},
   139  			},
   140  		},
   141  	}
   142  
   143  	req, err := http.NewRequest("POST", url+"/v1/foobar",
   144  		bytes.NewBuffer([]byte(`{"name": "xyz"}`)))
   145  	require.Nil(t, err)
   146  
   147  	resp, err := cli.Do(req)
   148  	require.Nil(t, err, "%+v", err)
   149  	defer resp.Body.Close()
   150  	require.Equal(t, resp.StatusCode, http.StatusOK)
   151  
   152  	bodyBytes, err := io.ReadAll(resp.Body)
   153  	require.Nil(t, err)
   154  	type responseBody struct {
   155  		Message string `json:"message"`
   156  	}
   157  	respBody := &responseBody{}
   158  	json.Unmarshal(bodyBytes, respBody)
   159  	require.Equal(t, respBody.Message, "test restful server transport")
   160  }
   161  
   162  func TestReplaceRouter(t *testing.T) {
   163  	st := thttp.NewRESTServerTransport(true, transport.WithReusePort(true))
   164  	restful.RegisterRouter("replacing", http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
   165  	restful.RegisterRouter("no_replacing", restful.NewRouter())
   166  	err := st.ListenAndServe(context.Background(), transport.WithServiceName("replacing"))
   167  	require.NotNil(t, err)
   168  	err = st.ListenAndServe(context.Background(), transport.WithServiceName("no_replacing"))
   169  	require.Nil(t, err)
   170  }
   171  
   172  var (
   173  	headerMatcherTransInfo, _ = json.Marshal(map[string]string{
   174  		"kfuin": base64.StdEncoding.EncodeToString([]byte("3009025887")),
   175  	})
   176  )
   177  
   178  func TestDefaultRESTHeaderMatcher(t *testing.T) {
   179  	bgctx := trpc.BackgroundContext()
   180  	req := http.Request{Header: make(http.Header)}
   181  	req.Header.Set(thttp.TrpcCaller, "TestDefaultHeaderMatcher")
   182  	req.Header.Set(thttp.TrpcTransInfo, string(headerMatcherTransInfo))
   183  	req.Header.Set(thttp.TrpcTimeout, "2000")
   184  	req.Header.Set(thttp.TrpcMessageType, "1")
   185  	ctx, err := thttp.DefaultRESTHeaderMatcher(bgctx, nil, &req, "UTService", "UTMethod")
   186  	require.Nil(t, err)
   187  	msg := codec.Message(ctx)
   188  	require.Equal(t, "UTService", msg.CalleeServiceName())
   189  	require.Equal(t, "UTMethod", msg.ServerRPCName())
   190  	require.Equal(t, "TestDefaultHeaderMatcher", msg.CallerServiceName())
   191  	require.Equal(t, time.Duration(2000*time.Millisecond), msg.RequestTimeout())
   192  	require.Equal(t, "3009025887", string(trpc.GetMetaData(ctx, "kfuin")))
   193  	require.Equal(t, true, msg.Dyeing())
   194  
   195  	req.Header.Set(thttp.TrpcTransInfo, "")
   196  	req.Header.Set(thttp.TrpcMessageType, "0")
   197  	ctx, err = thttp.DefaultRESTHeaderMatcher(bgctx, nil, &req, "UTService", "UTMethod")
   198  	require.Nil(t, err)
   199  	msg = codec.Message(ctx)
   200  	require.Equal(t, "", string(trpc.GetMetaData(ctx, "kfuin")))
   201  	require.Equal(t, false, msg.Dyeing())
   202  }
   203  
   204  func TestDefaultRESTFastHTTPHeaderMatcher(t *testing.T) {
   205  	bgctx := trpc.BackgroundContext()
   206  	req := fasthttp.RequestCtx{}
   207  	req.Request.Header.Set(thttp.TrpcCaller, "TestDefaultHeaderMatcher")
   208  	req.Request.Header.Set(thttp.TrpcTransInfo, string(headerMatcherTransInfo))
   209  	req.Request.Header.Set(thttp.TrpcTimeout, "2000")
   210  	req.Request.Header.Set(thttp.TrpcMessageType, "1")
   211  	ctx, err := thttp.DefaultRESTFastHTTPHeaderMatcher(bgctx, &req, "UTService", "UTMethod")
   212  	require.Nil(t, err)
   213  	msg := codec.Message(ctx)
   214  	require.Equal(t, "UTService", msg.CalleeServiceName())
   215  	require.Equal(t, "UTMethod", msg.ServerRPCName())
   216  	require.Equal(t, "TestDefaultHeaderMatcher", msg.CallerServiceName())
   217  	require.Equal(t, time.Duration(2000*time.Millisecond), msg.RequestTimeout())
   218  	require.Equal(t, "3009025887", string(trpc.GetMetaData(ctx, "kfuin")))
   219  	require.Equal(t, true, msg.Dyeing())
   220  
   221  	req = fasthttp.RequestCtx{}
   222  	req.Request.Header.Set(thttp.TrpcTransInfo, "xyz")
   223  	_, err = thttp.DefaultRESTFastHTTPHeaderMatcher(bgctx, &req, "UTService", "UTMethod")
   224  	require.NotNil(t, err)
   225  }
   226  
   227  func TestPassListenerUseTLS(t *testing.T) {
   228  	// Registers service.
   229  	serviceName := "trpc.test.helloworld.Greeter" + t.Name()
   230  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   231  	require.Nil(t, err)
   232  	addr := strings.Split(ln.Addr().String(), ":")
   233  	require.Equal(t, 2, len(addr))
   234  	port := addr[1]
   235  	// Must use localhost to replace 127.0.0.1, or else the following error will occur:
   236  	// tls: failed to verify certificate: x509: cannot validate certificate for 127.0.0.1 because it doesn't contain any IP SANs.
   237  	url := "https://localhost:" + port
   238  	s := &server.Server{}
   239  	service := server.New(
   240  		server.WithListener(ln),
   241  		server.WithServiceName(serviceName),
   242  		server.WithProtocol("restful"),
   243  		server.WithTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   244  	)
   245  	s.AddService(serviceName, service)
   246  	helloworld.RegisterGreeterService(s, &greeterServerImpl{})
   247  
   248  	go func() {
   249  		err := s.Serve()
   250  		require.Nil(t, err)
   251  	}()
   252  	defer s.Close(nil)
   253  
   254  	time.Sleep(100 * time.Millisecond)
   255  
   256  	// Sends https request.
   257  	pool := x509.NewCertPool()
   258  	ca, err := os.ReadFile("../testdata/ca.pem")
   259  	require.Nil(t, err)
   260  	pool.AppendCertsFromPEM(ca)
   261  	cert, err := tls.LoadX509KeyPair("../testdata/client.crt", "../testdata/client.key")
   262  	require.Nil(t, err)
   263  
   264  	cli := &http.Client{
   265  		Transport: &http.Transport{
   266  			TLSClientConfig: &tls.Config{
   267  				RootCAs:      pool,
   268  				Certificates: []tls.Certificate{cert},
   269  			},
   270  		},
   271  	}
   272  
   273  	req, err := http.NewRequest("POST", url+"/v1/foobar",
   274  		bytes.NewBuffer([]byte(`{"name": "xyz"}`)))
   275  	require.Nil(t, err)
   276  
   277  	resp, err := cli.Do(req)
   278  	require.Nil(t, err, "err: %+v", err)
   279  	defer resp.Body.Close()
   280  	require.Equal(t, resp.StatusCode, http.StatusOK)
   281  
   282  	bodyBytes, err := io.ReadAll(resp.Body)
   283  	require.Nil(t, err)
   284  	type responseBody struct {
   285  		Message string `json:"message"`
   286  	}
   287  	respBody := &responseBody{}
   288  	json.Unmarshal(bodyBytes, respBody)
   289  	require.Equal(t, respBody.Message, "test restful server transport")
   290  }
   291  
   292  func TestListenAndServeInvalidAddrErr(t *testing.T) {
   293  	serviceName := "trpc.test.helloworld.Greeter" + t.Name()
   294  	s := &server.Server{}
   295  	invalidAddr := "888.888.888.888:88888"
   296  	service := server.New(
   297  		server.WithAddress(invalidAddr),
   298  		server.WithServiceName(serviceName),
   299  		server.WithProtocol("restful"),
   300  	)
   301  	s.AddService(serviceName, service)
   302  	require.NotNil(t, s.Serve())
   303  }
   304  
   305  type greeterServerImpl struct{}
   306  
   307  func (s *greeterServerImpl) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
   308  	rsp := &helloworld.HelloReply{}
   309  	if thttp.Head(ctx) == nil {
   310  		return nil, errors.New("test error")
   311  	}
   312  	rsp.Message = "test restful server transport"
   313  	return rsp, nil
   314  }