github.com/searKing/golang/go@v1.2.117/net/mux/mux_helper_test.go (about)

     1  // Copyright 2020 The searKing Author. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mux_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"crypto/tls"
    11  	"errors"
    12  	"fmt"
    13  	"go/build"
    14  	"io"
    15  	"io/ioutil"
    16  	"log"
    17  	"net"
    18  	"net/http"
    19  	"net/rpc"
    20  	"os"
    21  	"os/exec"
    22  	"strings"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	net_ "github.com/searKing/golang/go/net"
    28  	"github.com/searKing/golang/go/net/mux"
    29  	"github.com/searKing/golang/go/sync/atomic"
    30  	"github.com/searKing/golang/go/testing/leakcheck"
    31  	"golang.org/x/net/http2"
    32  	"golang.org/x/net/http2/hpack"
    33  )
    34  
    35  const (
    36  	testHTTP1Resp = "http1"
    37  	rpcVal        = 1234
    38  )
    39  
    40  func safeServe(errCh chan<- error, muxl *mux.Server, l net.Listener) {
    41  	if err := muxl.Serve(l); err != nil {
    42  		if errors.Is(err, mux.ErrServerClosed) || errors.Is(err, mux.ErrListenerClosed) {
    43  			return
    44  		}
    45  		if strings.Contains(err.Error(), "use of closed") {
    46  			return
    47  		}
    48  		errCh <- err
    49  	}
    50  }
    51  
    52  func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
    53  	c, err := rpc.Dial(addr.Network(), addr.String())
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  	return c, func() {
    58  		if err := c.Close(); err != nil {
    59  			t.Fatal(err)
    60  		}
    61  	}
    62  }
    63  
    64  type chanListener struct {
    65  	net.Listener
    66  	connCh     chan net.Conn
    67  	inShutdown atomic.Bool
    68  }
    69  
    70  func newChanListener() *chanListener {
    71  	return &chanListener{connCh: make(chan net.Conn, 1)}
    72  }
    73  
    74  func (l *chanListener) Notify(conn net.Conn) {
    75  	if l.inShutdown.Load() {
    76  		return
    77  	}
    78  	l.connCh <- conn
    79  }
    80  
    81  func (l *chanListener) Accept() (net.Conn, error) {
    82  	if c, ok := <-l.connCh; ok {
    83  		return c, nil
    84  	}
    85  	return nil, errors.New("use of closed network connection")
    86  }
    87  
    88  func (l *chanListener) Close() error {
    89  	if l.inShutdown.Load() {
    90  		return nil
    91  	}
    92  
    93  	l.inShutdown.Store(true)
    94  
    95  	close(l.connCh)
    96  
    97  	if l.Listener == nil {
    98  		return nil
    99  	}
   100  	return l.Listener.Close()
   101  }
   102  
   103  func testListener(t leakcheck.Errorfer) net.Listener {
   104  	l, err := net_.LoopbackListener()
   105  	if err != nil {
   106  		t.Errorf(err.Error())
   107  	}
   108  	return net_.OnceCloseListener(l)
   109  }
   110  
   111  type testHTTP1Handler struct{}
   112  
   113  func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   114  	fmt.Fprintf(w, testHTTP1Resp)
   115  }
   116  
   117  func runTestHTTPServer(errCh chan<- error, l net.Listener) {
   118  	var mu sync.Mutex
   119  	conns := make(map[net.Conn]struct{})
   120  
   121  	defer func() {
   122  		mu.Lock()
   123  		for c := range conns {
   124  			if err := c.Close(); err != nil {
   125  				errCh <- err
   126  			}
   127  		}
   128  		mu.Unlock()
   129  	}()
   130  
   131  	s := &http.Server{
   132  		Handler: &testHTTP1Handler{},
   133  		ConnState: func(c net.Conn, state http.ConnState) {
   134  			mu.Lock()
   135  			switch state {
   136  			case http.StateNew:
   137  				conns[c] = struct{}{}
   138  			case http.StateClosed:
   139  				delete(conns, c)
   140  			}
   141  			mu.Unlock()
   142  		},
   143  	}
   144  	if err := s.Serve(l); err != mux.ErrListenerClosed {
   145  		errCh <- err
   146  	}
   147  }
   148  
   149  func generateTLSCert(t *testing.T) {
   150  	err := exec.Command("go", "run", build.Default.GOROOT+"/src/crypto/tls/generate_cert.go", "--host", "*").Run()
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  }
   155  
   156  func cleanupTLSCert(t *testing.T) {
   157  	err := os.Remove("cert.pem")
   158  	if err != nil {
   159  		t.Error(err)
   160  	}
   161  	err = os.Remove("key.pem")
   162  	if err != nil {
   163  		t.Error(err)
   164  	}
   165  }
   166  
   167  func runTestTLSServer(errCh chan<- error, l net.Listener) {
   168  	certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
   169  	if err != nil {
   170  		errCh <- err
   171  		log.Printf("1")
   172  		return
   173  	}
   174  
   175  	config := &tls.Config{
   176  		Certificates: []tls.Certificate{certificate},
   177  		Rand:         rand.Reader,
   178  	}
   179  
   180  	tlsl := tls.NewListener(l, config)
   181  	runTestHTTPServer(errCh, tlsl)
   182  }
   183  
   184  func runTestHTTP1Client(t *testing.T, addr net.Addr) {
   185  	runTestHTTPClient(t, "http", addr)
   186  }
   187  
   188  func runTestTLSClient(t *testing.T, addr net.Addr) {
   189  	runTestHTTPClient(t, "https", addr)
   190  }
   191  
   192  func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
   193  	client := http.Client{
   194  		Timeout: 5 * time.Second,
   195  		Transport: &http.Transport{
   196  			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
   197  		},
   198  	}
   199  	r, err := client.Get(proto + "://" + addr.String())
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  
   204  	defer func() {
   205  		if err = r.Body.Close(); err != nil {
   206  			t.Fatal(err)
   207  		}
   208  	}()
   209  
   210  	b, err := ioutil.ReadAll(r.Body)
   211  	if err != nil {
   212  		t.Fatal(err)
   213  	}
   214  	if string(b) != testHTTP1Resp {
   215  		t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b)
   216  	}
   217  }
   218  
   219  type TestRPCRcvr struct{}
   220  
   221  func (r TestRPCRcvr) Test(i int, j *int) error {
   222  	*j = i
   223  	return nil
   224  }
   225  
   226  func runTestRPCServer(errCh chan<- error, l net.Listener) {
   227  	s := rpc.NewServer()
   228  	if err := s.Register(TestRPCRcvr{}); err != nil {
   229  		errCh <- err
   230  	}
   231  	for {
   232  		c, err := l.Accept()
   233  		if err != nil {
   234  			if err != mux.ErrListenerClosed {
   235  				errCh <- err
   236  			}
   237  			return
   238  		}
   239  		go s.ServeConn(c)
   240  	}
   241  }
   242  
   243  func runTestRPCClient(t *testing.T, addr net.Addr) {
   244  	c, clean := safeDial(t, addr)
   245  	defer clean()
   246  
   247  	var num int
   248  	if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
   249  		t.Fatal(err)
   250  	}
   251  
   252  	if num != rpcVal {
   253  		t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
   254  	}
   255  }
   256  
   257  func testHTTP2HeaderField(
   258  	t *testing.T,
   259  	matcherConstructor func(sendSetting bool,
   260  		expects ...hpack.HeaderField) mux.MatcherFunc,
   261  	headerValue string,
   262  	matchValue string,
   263  	notMatchValue string,
   264  ) {
   265  	defer leakcheck.Check(t)
   266  	errCh := make(chan error)
   267  	defer func() {
   268  		for {
   269  			select {
   270  			case err, ok := <-errCh:
   271  				if !ok {
   272  					return
   273  				}
   274  				t.Fatal(err)
   275  			default:
   276  				close(errCh)
   277  				return
   278  			}
   279  		}
   280  	}()
   281  	name := "name"
   282  	writer, reader := net.Pipe()
   283  	go func() {
   284  		if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
   285  			t.Fatal(err)
   286  		}
   287  		var buf bytes.Buffer
   288  		enc := hpack.NewEncoder(&buf)
   289  		if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil {
   290  			t.Fatal(err)
   291  		}
   292  		framer := http2.NewFramer(writer, nil)
   293  		if err := framer.WriteSettingsAck(); err != nil {
   294  			t.Fatal(err)
   295  		}
   296  
   297  		if err := framer.WriteHeaders(http2.HeadersFrameParam{
   298  			StreamID:      1,
   299  			BlockFragment: buf.Bytes(),
   300  			EndStream:     true,
   301  			EndHeaders:    true,
   302  		}); err != nil {
   303  			t.Fatal(err)
   304  		}
   305  		if err := writer.Close(); err != nil {
   306  			t.Fatal(err)
   307  		}
   308  	}()
   309  
   310  	muxer := mux.NewServeMux()
   311  
   312  	l := newChanListener()
   313  	l.Notify(reader)
   314  	// Register a bogus matcher that only reads one byte.
   315  	muxl := muxer.HandleListener(mux.MatcherFunc(func(w io.Writer, r io.Reader) bool {
   316  		var b [1]byte
   317  		_, _ = r.Read(b[:])
   318  		return false
   319  	}))
   320  	defer muxl.Close()
   321  
   322  	// Create a matcher that cannot match the response.
   323  	//muxl.Match(matcherConstructor(false, hpack.HeaderField{Name: name, Value: notMatchValue}))
   324  	// Then match with the expected field.
   325  	h2l := muxer.HandleListener(matcherConstructor(false, hpack.HeaderField{Name: name, Value: matchValue}))
   326  	defer h2l.Close()
   327  
   328  	srv := mux.NewServer()
   329  	defer srv.Close()
   330  	srv.Handler = muxer
   331  	go func() {
   332  		safeServe(errCh, srv, l)
   333  	}()
   334  	muxedConn, err := h2l.Accept()
   335  	_ = l.Close()
   336  	if err != nil {
   337  		t.Fatal(err)
   338  	}
   339  	var b [len(http2.ClientPreface)]byte
   340  	// We have the sniffed buffer first...
   341  	if _, err := muxedConn.Read(b[:]); err == io.EOF {
   342  		t.Fatal(err)
   343  	}
   344  	if string(b[:]) != http2.ClientPreface {
   345  		t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
   346  	}
   347  }