trpc.group/trpc-go/trpc-go@v1.0.3/http/service_desc_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http_test
    15  
import (
	"bytes"
	"errors"
	"io"
	"log"
	"mime/multipart"
	"net"
	"net/http"
	"os"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	thttp "trpc.group/trpc-go/trpc-go/http"
	"trpc.group/trpc-go/trpc-go/server"
)
    32  
    33  func TestRegisterDefaultService(t *testing.T) {
    34  	defer func() {
    35  		err := recover()
    36  		require.New(t).Contains(err, "duplicate method name")
    37  		thttp.DefaultServerCodec.AutoReadBody = true
    38  		thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0]
    39  	}()
    40  	s := server.New()
    41  	thttp.HandleFunc("/test/path", func(w http.ResponseWriter, r *http.Request) error { return nil })
    42  	thttp.HandleFunc("/test/path", func(w http.ResponseWriter, r *http.Request) error { return nil })
    43  	thttp.RegisterDefaultService(s)
    44  }
    45  
    46  func TestRegisterServiceMux(t *testing.T) {
    47  	defer func() {
    48  		err := recover()
    49  		require.New(t).Contains(err, "duplicate method name")
    50  		thttp.DefaultServerCodec.AutoReadBody = true
    51  		thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0]
    52  	}()
    53  	s := server.New()
    54  	thttp.RegisterServiceMux(s, nil)
    55  	thttp.RegisterServiceMux(s, nil)
    56  }
    57  
    58  func TestMultipartTmpFileCleaning(t *testing.T) {
    59  	// Setup server.
    60  	ln, err := net.Listen("tcp", "localhost:0")
    61  	require.Nil(t, err)
    62  	defer ln.Close()
    63  	serviceName := "trpc.http.server.MultipartTmpFileCleaningTest"
    64  	service := server.New(
    65  		server.WithServiceName(serviceName),
    66  		server.WithNetwork("tcp"),
    67  		server.WithProtocol("http_no_protocol"),
    68  		server.WithListener(ln),
    69  	)
    70  	var tmpFiles []string
    71  	defer func() {
    72  		// Ensure that the temporary files are removed despite the test failure.
    73  		for i := range tmpFiles {
    74  			os.Remove(tmpFiles[i])
    75  		}
    76  	}()
    77  	thttp.HandleFunc("/test/multipart", func(_ http.ResponseWriter, r *http.Request) error {
    78  		const verySmallMaximumMemory = 4
    79  		if err := r.ParseMultipartForm(verySmallMaximumMemory); err != nil {
    80  			log.Println("err: ", err)
    81  			return err
    82  		}
    83  		for _, files := range r.MultipartForm.File {
    84  			f, _ := files[0].Open()
    85  			if osFile, ok := f.(*os.File); ok {
    86  				tmpFiles = append(tmpFiles, osFile.Name())
    87  			}
    88  			f.Close()
    89  		}
    90  		return nil
    91  	})
    92  	defer func() {
    93  		// Remove the registered handle func to ensure the independency of each test. 😅
    94  		thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0]
    95  	}()
    96  	s := &server.Server{}
    97  	s.AddService(serviceName, service)
    98  	thttp.RegisterNoProtocolService(s.Service(serviceName))
    99  	go func() {
   100  		require.Nil(t, s.Serve())
   101  	}()
   102  	defer s.Close(nil)
   103  	time.Sleep(100 * time.Millisecond)
   104  
   105  	// Setup multipart form data.
   106  	const fileSize = 33554432 // 32MB
   107  	largeFileContent := make([]byte, fileSize)
   108  	rd := bytes.NewReader(largeFileContent)
   109  	var b bytes.Buffer
   110  	w := multipart.NewWriter(&b)
   111  	fw, err := w.CreateFormFile("data", "largefile.test")
   112  	require.Nil(t, err)
   113  	_, err = io.Copy(fw, rd)
   114  	require.Nil(t, err)
   115  	require.Nil(t, w.Close())
   116  
   117  	// Setup client.
   118  	req, err := http.NewRequest("POST", "http://"+ln.Addr().String()+"/test/multipart", &b)
   119  	require.Nil(t, err)
   120  	req.Header.Set("Content-Type", w.FormDataContentType())
   121  	client := http.DefaultClient
   122  	res, err := client.Do(req)
   123  	require.Nil(t, err)
   124  	require.Equal(t, res.StatusCode, http.StatusOK)
   125  
   126  	// Check whether temporary file is removed.
   127  	require.Eventually(t, func() bool {
   128  		for i := range tmpFiles {
   129  			if _, err := os.Stat(tmpFiles[i]); !errors.Is(err, os.ErrNotExist) {
   130  				t.Logf("tmp file %s may still exist, err: %+v", tmpFiles[i], err)
   131  				return false
   132  			}
   133  		}
   134  		return true
   135  	}, time.Second, 10*time.Millisecond)
   136  }