github.com/searKing/golang/go@v1.2.117/net/mux/mux_helper_test.go (about) 1 // Copyright 2020 The searKing Author. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mux_test 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "crypto/tls" 11 "errors" 12 "fmt" 13 "go/build" 14 "io" 15 "io/ioutil" 16 "log" 17 "net" 18 "net/http" 19 "net/rpc" 20 "os" 21 "os/exec" 22 "strings" 23 "sync" 24 "testing" 25 "time" 26 27 net_ "github.com/searKing/golang/go/net" 28 "github.com/searKing/golang/go/net/mux" 29 "github.com/searKing/golang/go/sync/atomic" 30 "github.com/searKing/golang/go/testing/leakcheck" 31 "golang.org/x/net/http2" 32 "golang.org/x/net/http2/hpack" 33 ) 34 35 const ( 36 testHTTP1Resp = "http1" 37 rpcVal = 1234 38 ) 39 40 func safeServe(errCh chan<- error, muxl *mux.Server, l net.Listener) { 41 if err := muxl.Serve(l); err != nil { 42 if errors.Is(err, mux.ErrServerClosed) || errors.Is(err, mux.ErrListenerClosed) { 43 return 44 } 45 if strings.Contains(err.Error(), "use of closed") { 46 return 47 } 48 errCh <- err 49 } 50 } 51 52 func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) { 53 c, err := rpc.Dial(addr.Network(), addr.String()) 54 if err != nil { 55 t.Fatal(err) 56 } 57 return c, func() { 58 if err := c.Close(); err != nil { 59 t.Fatal(err) 60 } 61 } 62 } 63 64 type chanListener struct { 65 net.Listener 66 connCh chan net.Conn 67 inShutdown atomic.Bool 68 } 69 70 func newChanListener() *chanListener { 71 return &chanListener{connCh: make(chan net.Conn, 1)} 72 } 73 74 func (l *chanListener) Notify(conn net.Conn) { 75 if l.inShutdown.Load() { 76 return 77 } 78 l.connCh <- conn 79 } 80 81 func (l *chanListener) Accept() (net.Conn, error) { 82 if c, ok := <-l.connCh; ok { 83 return c, nil 84 } 85 return nil, errors.New("use of closed network connection") 86 } 87 88 func (l *chanListener) Close() error { 89 if l.inShutdown.Load() { 90 return nil 91 } 92 93 l.inShutdown.Store(true) 94 95 close(l.connCh) 96 97 if l.Listener == nil { 98 return nil 99 } 100 return l.Listener.Close() 101 } 102 103 func testListener(t leakcheck.Errorfer) net.Listener { 104 l, err := net_.LoopbackListener() 105 if err != nil { 106 t.Errorf(err.Error()) 107 } 108 return net_.OnceCloseListener(l) 109 } 110 111 type testHTTP1Handler struct{} 112 113 func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 114 fmt.Fprintf(w, testHTTP1Resp) 115 } 116 117 func runTestHTTPServer(errCh chan<- error, l net.Listener) { 118 var mu sync.Mutex 119 conns := make(map[net.Conn]struct{}) 120 121 defer func() { 122 mu.Lock() 123 for c := range conns { 124 if err := c.Close(); err != nil { 125 errCh <- err 126 } 127 } 128 mu.Unlock() 129 }() 130 131 s := &http.Server{ 132 Handler: &testHTTP1Handler{}, 133 ConnState: func(c net.Conn, state http.ConnState) { 134 mu.Lock() 135 switch state { 136 case http.StateNew: 137 conns[c] = struct{}{} 138 case http.StateClosed: 139 delete(conns, c) 140 } 141 mu.Unlock() 142 }, 143 } 144 if err := s.Serve(l); err != mux.ErrListenerClosed { 145 errCh <- err 146 } 147 } 148 149 func generateTLSCert(t *testing.T) { 150 err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run() 151 if err != nil { 152 t.Fatal(err) 153 } 154 } 155 156 func cleanupTLSCert(t *testing.T) { 157 err := os.Remove("cert.pem") 158 if err != nil { 159 t.Error(err) 160 } 161 err = os.Remove("key.pem") 162 if err != nil { 163 t.Error(err) 164 } 165 } 166 167 func runTestTLSServer(errCh chan<- error, l net.Listener) { 168 certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem") 169 if err != nil { 170 errCh <- err 171 log.Printf("1") 172 return 173 } 174 175 config := &tls.Config{ 176 Certificates: []tls.Certificate{certificate}, 177 Rand: rand.Reader, 178 } 179 180 tlsl := tls.NewListener(l, config) 181 runTestHTTPServer(errCh, tlsl) 182 } 183 184 func runTestHTTP1Client(t *testing.T, addr net.Addr) { 185 runTestHTTPClient(t, "http", addr) 186 } 187 188 func runTestTLSClient(t *testing.T, addr net.Addr) { 189 runTestHTTPClient(t, "https", addr) 190 } 191 192 func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) { 193 client := http.Client{ 194 Timeout: 5 * time.Second, 195 Transport: &http.Transport{ 196 TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 197 }, 198 } 199 r, err := client.Get(proto + "://" + addr.String()) 200 if err != nil { 201 t.Fatal(err) 202 } 203 204 defer func() { 205 if err = r.Body.Close(); err != nil { 206 t.Fatal(err) 207 } 208 }() 209 210 b, err := ioutil.ReadAll(r.Body) 211 if err != nil { 212 t.Fatal(err) 213 } 214 if string(b) != testHTTP1Resp { 215 t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b) 216 } 217 } 218 219 type TestRPCRcvr struct{} 220 221 func (r TestRPCRcvr) Test(i int, j *int) error { 222 *j = i 223 return nil 224 } 225 226 func runTestRPCServer(errCh chan<- error, l net.Listener) { 227 s := rpc.NewServer() 228 if err := s.Register(TestRPCRcvr{}); err != nil { 229 errCh <- err 230 } 231 for { 232 c, err := l.Accept() 233 if err != nil { 234 if err != mux.ErrListenerClosed { 235 errCh <- err 236 } 237 return 238 } 239 go s.ServeConn(c) 240 } 241 } 242 243 func runTestRPCClient(t *testing.T, addr net.Addr) { 244 c, clean := safeDial(t, addr) 245 defer clean() 246 247 var num int 248 if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil { 249 t.Fatal(err) 250 } 251 252 if num != rpcVal { 253 t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num) 254 } 255 } 256 257 func testHTTP2HeaderField( 258 t *testing.T, 259 matcherConstructor func(sendSetting bool, 260 expects ...hpack.HeaderField) mux.MatcherFunc, 261 headerValue string, 262 matchValue string, 263 notMatchValue string, 264 ) { 265 defer leakcheck.Check(t) 266 errCh := make(chan error) 267 defer func() { 268 for { 269 select { 270 case err, ok := <-errCh: 271 if !ok { 272 return 273 } 274 t.Fatal(err) 275 default: 276 close(errCh) 277 return 278 } 279 } 280 }() 281 name := "name" 282 writer, reader := net.Pipe() 283 go func() { 284 if _, err := io.WriteString(writer, http2.ClientPreface); err != nil { 285 t.Fatal(err) 286 } 287 var buf bytes.Buffer 288 enc := hpack.NewEncoder(&buf) 289 if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil { 290 t.Fatal(err) 291 } 292 framer := http2.NewFramer(writer, nil) 293 if err := framer.WriteSettingsAck(); err != nil { 294 t.Fatal(err) 295 } 296 297 if err := framer.WriteHeaders(http2.HeadersFrameParam{ 298 StreamID: 1, 299 BlockFragment: buf.Bytes(), 300 EndStream: true, 301 EndHeaders: true, 302 }); err != nil { 303 t.Fatal(err) 304 } 305 if err := writer.Close(); err != nil { 306 t.Fatal(err) 307 } 308 }() 309 310 muxer := mux.NewServeMux() 311 312 l := newChanListener() 313 l.Notify(reader) 314 // Register a bogus matcher that only reads one byte. 315 muxl := muxer.HandleListener(mux.MatcherFunc(func(w io.Writer, r io.Reader) bool { 316 var b [1]byte 317 _, _ = r.Read(b[:]) 318 return false 319 })) 320 defer muxl.Close() 321 322 // Create a matcher that cannot match the response. 323 //muxl.Match(matcherConstructor(false, hpack.HeaderField{Name: name, Value: notMatchValue})) 324 // Then match with the expected field. 325 h2l := muxer.HandleListener(matcherConstructor(false, hpack.HeaderField{Name: name, Value: matchValue})) 326 defer h2l.Close() 327 328 srv := mux.NewServer() 329 defer srv.Close() 330 srv.Handler = muxer 331 go func() { 332 safeServe(errCh, srv, l) 333 }() 334 muxedConn, err := h2l.Accept() 335 _ = l.Close() 336 if err != nil { 337 t.Fatal(err) 338 } 339 var b [len(http2.ClientPreface)]byte 340 // We have the sniffed buffer first... 341 if _, err := muxedConn.Read(b[:]); err == io.EOF { 342 t.Fatal(err) 343 } 344 if string(b[:]) != http2.ClientPreface { 345 t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface) 346 } 347 }