trpc.group/trpc-go/trpc-go@v1.0.3/http/service_desc_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package http_test 15 16 import ( 17 "bytes" 18 "errors" 19 "io" 20 "log" 21 "mime/multipart" 22 "net" 23 "net/http" 24 "os" 25 "testing" 26 "time" 27 28 "github.com/stretchr/testify/require" 29 thttp "trpc.group/trpc-go/trpc-go/http" 30 "trpc.group/trpc-go/trpc-go/server" 31 ) 32 33 func TestRegisterDefaultService(t *testing.T) { 34 defer func() { 35 err := recover() 36 require.New(t).Contains(err, "duplicate method name") 37 thttp.DefaultServerCodec.AutoReadBody = true 38 thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0] 39 }() 40 s := server.New() 41 thttp.HandleFunc("/test/path", func(w http.ResponseWriter, r *http.Request) error { return nil }) 42 thttp.HandleFunc("/test/path", func(w http.ResponseWriter, r *http.Request) error { return nil }) 43 thttp.RegisterDefaultService(s) 44 } 45 46 func TestRegisterServiceMux(t *testing.T) { 47 defer func() { 48 err := recover() 49 require.New(t).Contains(err, "duplicate method name") 50 thttp.DefaultServerCodec.AutoReadBody = true 51 thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0] 52 }() 53 s := server.New() 54 thttp.RegisterServiceMux(s, nil) 55 thttp.RegisterServiceMux(s, nil) 56 } 57 58 func TestMultipartTmpFileCleaning(t *testing.T) { 59 // Setup server. 60 ln, err := net.Listen("tcp", "localhost:0") 61 require.Nil(t, err) 62 defer ln.Close() 63 serviceName := "trpc.http.server.MultipartTmpFileCleaningTest" 64 service := server.New( 65 server.WithServiceName(serviceName), 66 server.WithNetwork("tcp"), 67 server.WithProtocol("http_no_protocol"), 68 server.WithListener(ln), 69 ) 70 var tmpFiles []string 71 defer func() { 72 // Ensure that the temporary files are removed despite the test failure. 73 for i := range tmpFiles { 74 os.Remove(tmpFiles[i]) 75 } 76 }() 77 thttp.HandleFunc("/test/multipart", func(_ http.ResponseWriter, r *http.Request) error { 78 const verySmallMaximumMemory = 4 79 if err := r.ParseMultipartForm(verySmallMaximumMemory); err != nil { 80 log.Println("err: ", err) 81 return err 82 } 83 for _, files := range r.MultipartForm.File { 84 f, _ := files[0].Open() 85 if osFile, ok := f.(*os.File); ok { 86 tmpFiles = append(tmpFiles, osFile.Name()) 87 } 88 f.Close() 89 } 90 return nil 91 }) 92 defer func() { 93 // Remove the registered handle func to ensure the independency of each test. 😅 94 thttp.ServiceDesc.Methods = thttp.ServiceDesc.Methods[:0] 95 }() 96 s := &server.Server{} 97 s.AddService(serviceName, service) 98 thttp.RegisterNoProtocolService(s.Service(serviceName)) 99 go func() { 100 require.Nil(t, s.Serve()) 101 }() 102 defer s.Close(nil) 103 time.Sleep(100 * time.Millisecond) 104 105 // Setup multipart form data. 106 const fileSize = 33554432 // 32MB 107 largeFileContent := make([]byte, fileSize) 108 rd := bytes.NewReader(largeFileContent) 109 var b bytes.Buffer 110 w := multipart.NewWriter(&b) 111 fw, err := w.CreateFormFile("data", "largefile.test") 112 require.Nil(t, err) 113 _, err = io.Copy(fw, rd) 114 require.Nil(t, err) 115 require.Nil(t, w.Close()) 116 117 // Setup client. 118 req, err := http.NewRequest("POST", "http://"+ln.Addr().String()+"/test/multipart", &b) 119 require.Nil(t, err) 120 req.Header.Set("Content-Type", w.FormDataContentType()) 121 client := http.DefaultClient 122 res, err := client.Do(req) 123 require.Nil(t, err) 124 require.Equal(t, res.StatusCode, http.StatusOK) 125 126 // Check whether temporary file is removed. 127 require.Eventually(t, func() bool { 128 for i := range tmpFiles { 129 if _, err := os.Stat(tmpFiles[i]); !errors.Is(err, os.ErrNotExist) { 130 t.Logf("tmp file %s may still exist, err: %+v", tmpFiles[i], err) 131 return false 132 } 133 } 134 return true 135 }, time.Second, 10*time.Millisecond) 136 }