go.etcd.io/etcd@v3.3.27+incompatible/pkg/proxy/server_test.go (about)

     1  // Copyright 2018 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"math/rand"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/coreos/etcd/pkg/transport"
    32  
    33  	"go.uber.org/zap"
    34  )
    35  
    36  // enable DebugLevel
    37  var testLogger = zap.NewExample()
    38  
    39  var testTLSInfo = transport.TLSInfo{
    40  	KeyFile:        "./fixtures/server.key.insecure",
    41  	CertFile:       "./fixtures/server.crt",
    42  	TrustedCAFile:  "./fixtures/ca.crt",
    43  	ClientCertAuth: true,
    44  }
    45  
    46  func TestServer_Unix_Insecure(t *testing.T)         { testServer(t, "unix", false, false) }
    47  func TestServer_TCP_Insecure(t *testing.T)          { testServer(t, "tcp", false, false) }
    48  func TestServer_Unix_Secure(t *testing.T)           { testServer(t, "unix", true, false) }
    49  func TestServer_TCP_Secure(t *testing.T)            { testServer(t, "tcp", true, false) }
    50  func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) }
    51  func TestServer_TCP_Insecure_DelayTx(t *testing.T)  { testServer(t, "tcp", false, true) }
    52  func TestServer_Unix_Secure_DelayTx(t *testing.T)   { testServer(t, "unix", true, true) }
    53  func TestServer_TCP_Secure_DelayTx(t *testing.T)    { testServer(t, "tcp", true, true) }
    54  func testServer(t *testing.T, scheme string, secure bool, delayTx bool) {
    55  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
    56  	if scheme == "tcp" {
    57  		ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
    58  		srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String()
    59  		ln1.Close()
    60  		ln2.Close()
    61  	} else {
    62  		defer func() {
    63  			os.RemoveAll(srcAddr)
    64  			os.RemoveAll(dstAddr)
    65  		}()
    66  	}
    67  	tlsInfo := testTLSInfo
    68  	if !secure {
    69  		tlsInfo = transport.TLSInfo{}
    70  	}
    71  	ln := listen(t, scheme, dstAddr, tlsInfo)
    72  	defer ln.Close()
    73  
    74  	cfg := ServerConfig{
    75  		Logger: testLogger,
    76  		From:   url.URL{Scheme: scheme, Host: srcAddr},
    77  		To:     url.URL{Scheme: scheme, Host: dstAddr},
    78  	}
    79  	if secure {
    80  		cfg.TLSInfo = testTLSInfo
    81  	}
    82  	p := NewServer(cfg)
    83  	<-p.Ready()
    84  	defer p.Close()
    85  
    86  	data1 := []byte("Hello World!")
    87  	donec, writec := make(chan struct{}), make(chan []byte)
    88  
    89  	go func() {
    90  		defer close(donec)
    91  		for data := range writec {
    92  			send(t, data, scheme, srcAddr, tlsInfo)
    93  		}
    94  	}()
    95  
    96  	recvc := make(chan []byte)
    97  	go func() {
    98  		for i := 0; i < 2; i++ {
    99  			recvc <- receive(t, ln)
   100  		}
   101  	}()
   102  
   103  	writec <- data1
   104  	now := time.Now()
   105  	if d := <-recvc; !bytes.Equal(data1, d) {
   106  		t.Fatalf("expected %q, got %q", string(data1), string(d))
   107  	}
   108  	took1 := time.Since(now)
   109  	t.Logf("took %v with no latency", took1)
   110  
   111  	lat, rv := 50*time.Millisecond, 5*time.Millisecond
   112  	if delayTx {
   113  		p.DelayTx(lat, rv)
   114  	}
   115  
   116  	data2 := []byte("new data")
   117  	writec <- data2
   118  	now = time.Now()
   119  	if d := <-recvc; !bytes.Equal(data2, d) {
   120  		t.Fatalf("expected %q, got %q", string(data2), string(d))
   121  	}
   122  	took2 := time.Since(now)
   123  	if delayTx {
   124  		t.Logf("took %v with latency %v±%v", took2, lat, rv)
   125  	} else {
   126  		t.Logf("took %v with no latency", took2)
   127  	}
   128  
   129  	if delayTx {
   130  		p.UndelayTx()
   131  		if took1 >= took2 {
   132  			t.Fatalf("expected took1 %v < took2 %v (with latency)", took1, took2)
   133  		}
   134  	}
   135  
   136  	close(writec)
   137  	select {
   138  	case <-donec:
   139  	case <-time.After(3 * time.Second):
   140  		t.Fatal("took too long to write")
   141  	}
   142  
   143  	select {
   144  	case <-p.Done():
   145  		t.Fatal("unexpected done")
   146  	case err := <-p.Error():
   147  		t.Fatal(err)
   148  	default:
   149  	}
   150  
   151  	if err := p.Close(); err != nil {
   152  		t.Fatal(err)
   153  	}
   154  
   155  	select {
   156  	case <-p.Done():
   157  	case err := <-p.Error():
   158  		if !strings.HasPrefix(err.Error(), "accept ") &&
   159  			!strings.HasSuffix(err.Error(), "use of closed network connection") {
   160  			t.Fatal(err)
   161  		}
   162  	case <-time.After(3 * time.Second):
   163  		t.Fatal("took too long to close")
   164  	}
   165  }
   166  
   167  func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) }
   168  func TestServer_Unix_Secure_DelayAccept(t *testing.T)   { testServerDelayAccept(t, true) }
   169  func testServerDelayAccept(t *testing.T, secure bool) {
   170  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   171  	defer func() {
   172  		os.RemoveAll(srcAddr)
   173  		os.RemoveAll(dstAddr)
   174  	}()
   175  	tlsInfo := testTLSInfo
   176  	if !secure {
   177  		tlsInfo = transport.TLSInfo{}
   178  	}
   179  	scheme := "unix"
   180  	ln := listen(t, scheme, dstAddr, tlsInfo)
   181  	defer ln.Close()
   182  
   183  	cfg := ServerConfig{
   184  		Logger: testLogger,
   185  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   186  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   187  	}
   188  	if secure {
   189  		cfg.TLSInfo = testTLSInfo
   190  	}
   191  	p := NewServer(cfg)
   192  	<-p.Ready()
   193  	defer p.Close()
   194  
   195  	data := []byte("Hello World!")
   196  
   197  	now := time.Now()
   198  	send(t, data, scheme, srcAddr, tlsInfo)
   199  	if d := receive(t, ln); !bytes.Equal(data, d) {
   200  		t.Fatalf("expected %q, got %q", string(data), string(d))
   201  	}
   202  	took1 := time.Since(now)
   203  	t.Logf("took %v with no latency", took1)
   204  
   205  	lat, rv := 700*time.Millisecond, 10*time.Millisecond
   206  	p.DelayAccept(lat, rv)
   207  	defer p.UndelayAccept()
   208  	if err := p.ResetListener(); err != nil {
   209  		t.Fatal(err)
   210  	}
   211  	time.Sleep(200 * time.Millisecond)
   212  
   213  	now = time.Now()
   214  	send(t, data, scheme, srcAddr, tlsInfo)
   215  	if d := receive(t, ln); !bytes.Equal(data, d) {
   216  		t.Fatalf("expected %q, got %q", string(data), string(d))
   217  	}
   218  	took2 := time.Since(now)
   219  	t.Logf("took %v with latency %v±%v", took2, lat, rv)
   220  
   221  	if took1 >= took2 {
   222  		t.Fatalf("expected took1 %v < took2 %v", took1, took2)
   223  	}
   224  }
   225  
   226  func TestServer_PauseTx(t *testing.T) {
   227  	scheme := "unix"
   228  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   229  	defer func() {
   230  		os.RemoveAll(srcAddr)
   231  		os.RemoveAll(dstAddr)
   232  	}()
   233  	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
   234  	defer ln.Close()
   235  
   236  	p := NewServer(ServerConfig{
   237  		Logger: testLogger,
   238  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   239  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   240  	})
   241  	<-p.Ready()
   242  	defer p.Close()
   243  
   244  	p.PauseTx()
   245  
   246  	data := []byte("Hello World!")
   247  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   248  
   249  	recvc := make(chan []byte)
   250  	go func() {
   251  		recvc <- receive(t, ln)
   252  	}()
   253  
   254  	select {
   255  	case d := <-recvc:
   256  		t.Fatalf("received unexpected data %q during pause", string(d))
   257  	case <-time.After(200 * time.Millisecond):
   258  	}
   259  
   260  	p.UnpauseTx()
   261  
   262  	select {
   263  	case d := <-recvc:
   264  		if !bytes.Equal(data, d) {
   265  			t.Fatalf("expected %q, got %q", string(data), string(d))
   266  		}
   267  	case <-time.After(2 * time.Second):
   268  		t.Fatal("took too long to receive after unpause")
   269  	}
   270  }
   271  
   272  func TestServer_BlackholeTx(t *testing.T) {
   273  	scheme := "unix"
   274  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   275  	defer func() {
   276  		os.RemoveAll(srcAddr)
   277  		os.RemoveAll(dstAddr)
   278  	}()
   279  	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
   280  	defer ln.Close()
   281  
   282  	p := NewServer(ServerConfig{
   283  		Logger: testLogger,
   284  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   285  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   286  	})
   287  	<-p.Ready()
   288  	defer p.Close()
   289  
   290  	p.BlackholeTx()
   291  
   292  	data := []byte("Hello World!")
   293  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   294  
   295  	recvc := make(chan []byte)
   296  	go func() {
   297  		recvc <- receive(t, ln)
   298  	}()
   299  
   300  	select {
   301  	case d := <-recvc:
   302  		t.Fatalf("unexpected data receive %q during blackhole", string(d))
   303  	case <-time.After(200 * time.Millisecond):
   304  	}
   305  
   306  	p.UnblackholeTx()
   307  
   308  	// expect different data, old data dropped
   309  	data[0]++
   310  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   311  
   312  	select {
   313  	case d := <-recvc:
   314  		if !bytes.Equal(data, d) {
   315  			t.Fatalf("expected %q, got %q", string(data), string(d))
   316  		}
   317  	case <-time.After(2 * time.Second):
   318  		t.Fatal("took too long to receive after unblackhole")
   319  	}
   320  }
   321  
   322  func TestServer_CorruptTx(t *testing.T) {
   323  	scheme := "unix"
   324  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   325  	defer func() {
   326  		os.RemoveAll(srcAddr)
   327  		os.RemoveAll(dstAddr)
   328  	}()
   329  	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
   330  	defer ln.Close()
   331  
   332  	p := NewServer(ServerConfig{
   333  		Logger: testLogger,
   334  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   335  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   336  	})
   337  	<-p.Ready()
   338  	defer p.Close()
   339  
   340  	p.CorruptTx(func(d []byte) []byte {
   341  		d[len(d)/2]++
   342  		return d
   343  	})
   344  	data := []byte("Hello World!")
   345  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   346  	if d := receive(t, ln); bytes.Equal(d, data) {
   347  		t.Fatalf("expected corrupted data, got %q", string(d))
   348  	}
   349  
   350  	p.UncorruptTx()
   351  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   352  	if d := receive(t, ln); !bytes.Equal(d, data) {
   353  		t.Fatalf("expected uncorrupted data, got %q", string(d))
   354  	}
   355  }
   356  
   357  func TestServer_Shutdown(t *testing.T) {
   358  	scheme := "unix"
   359  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   360  	defer func() {
   361  		os.RemoveAll(srcAddr)
   362  		os.RemoveAll(dstAddr)
   363  	}()
   364  	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
   365  	defer ln.Close()
   366  
   367  	p := NewServer(ServerConfig{
   368  		Logger: testLogger,
   369  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   370  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   371  	})
   372  	<-p.Ready()
   373  	defer p.Close()
   374  
   375  	px, _ := p.(*proxyServer)
   376  	px.listener.Close()
   377  	time.Sleep(200 * time.Millisecond)
   378  
   379  	data := []byte("Hello World!")
   380  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   381  	if d := receive(t, ln); !bytes.Equal(d, data) {
   382  		t.Fatalf("expected %q, got %q", string(data), string(d))
   383  	}
   384  }
   385  
   386  func TestServer_ShutdownListener(t *testing.T) {
   387  	scheme := "unix"
   388  	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
   389  	defer func() {
   390  		os.RemoveAll(srcAddr)
   391  		os.RemoveAll(dstAddr)
   392  	}()
   393  
   394  	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
   395  	defer ln.Close()
   396  
   397  	p := NewServer(ServerConfig{
   398  		Logger: testLogger,
   399  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   400  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   401  	})
   402  	<-p.Ready()
   403  	defer p.Close()
   404  
   405  	// shut down destination
   406  	ln.Close()
   407  	time.Sleep(200 * time.Millisecond)
   408  
   409  	ln = listen(t, scheme, dstAddr, transport.TLSInfo{})
   410  	defer ln.Close()
   411  
   412  	data := []byte("Hello World!")
   413  	send(t, data, scheme, srcAddr, transport.TLSInfo{})
   414  	if d := receive(t, ln); !bytes.Equal(d, data) {
   415  		t.Fatalf("expected %q, got %q", string(data), string(d))
   416  	}
   417  }
   418  
   419  func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) }
   420  func TestServerHTTP_Secure_DelayTx(t *testing.T)   { testServerHTTP(t, true, true) }
   421  func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) }
   422  func TestServerHTTP_Secure_DelayRx(t *testing.T)   { testServerHTTP(t, true, false) }
   423  func testServerHTTP(t *testing.T, secure, delayTx bool) {
   424  	scheme := "tcp"
   425  	ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{})
   426  	srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
   427  	ln1.Close()
   428  	ln2.Close()
   429  
   430  	mux := http.NewServeMux()
   431  	mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
   432  		d, err := ioutil.ReadAll(req.Body)
   433  		if err != nil {
   434  			t.Fatal(err)
   435  		}
   436  		if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil {
   437  			t.Fatal(err)
   438  		}
   439  	})
   440  	var tlsConfig *tls.Config
   441  	var err error
   442  	if secure {
   443  		tlsConfig, err = testTLSInfo.ServerConfig()
   444  		if err != nil {
   445  			t.Fatal(err)
   446  		}
   447  	}
   448  	srv := &http.Server{
   449  		Addr:      dstAddr,
   450  		Handler:   mux,
   451  		TLSConfig: tlsConfig,
   452  	}
   453  
   454  	donec := make(chan struct{})
   455  	defer func() {
   456  		srv.Close()
   457  		<-donec
   458  	}()
   459  	go func() {
   460  		defer close(donec)
   461  		if !secure {
   462  			srv.ListenAndServe()
   463  		} else {
   464  			srv.ListenAndServeTLS(testTLSInfo.CertFile, testTLSInfo.KeyFile)
   465  		}
   466  	}()
   467  	time.Sleep(200 * time.Millisecond)
   468  
   469  	cfg := ServerConfig{
   470  		Logger: testLogger,
   471  		From:   url.URL{Scheme: scheme, Host: srcAddr},
   472  		To:     url.URL{Scheme: scheme, Host: dstAddr},
   473  	}
   474  	if secure {
   475  		cfg.TLSInfo = testTLSInfo
   476  	}
   477  	p := NewServer(cfg)
   478  	<-p.Ready()
   479  	defer p.Close()
   480  
   481  	data := "Hello World!"
   482  
   483  	now := time.Now()
   484  	var resp *http.Response
   485  	if secure {
   486  		tp, terr := transport.NewTransport(testTLSInfo, 3*time.Second)
   487  		if terr != nil {
   488  			t.Fatal(terr)
   489  		}
   490  		cli := &http.Client{Transport: tp}
   491  		resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
   492  	} else {
   493  		resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
   494  	}
   495  	if err != nil {
   496  		t.Fatal(err)
   497  	}
   498  	d, err := ioutil.ReadAll(resp.Body)
   499  	if err != nil {
   500  		t.Fatal(err)
   501  	}
   502  	took1 := time.Since(now)
   503  	t.Logf("took %v with no latency", took1)
   504  
   505  	rs1 := string(d)
   506  	exp := fmt.Sprintf("%q(confirmed)", data)
   507  	if rs1 != exp {
   508  		t.Fatalf("got %q, expected %q", rs1, exp)
   509  	}
   510  
   511  	lat, rv := 100*time.Millisecond, 10*time.Millisecond
   512  	if delayTx {
   513  		p.DelayTx(lat, rv)
   514  		defer p.UndelayTx()
   515  	} else {
   516  		p.DelayRx(lat, rv)
   517  		defer p.UndelayRx()
   518  	}
   519  
   520  	now = time.Now()
   521  	if secure {
   522  		tp, terr := transport.NewTransport(testTLSInfo, 3*time.Second)
   523  		if terr != nil {
   524  			t.Fatal(terr)
   525  		}
   526  		cli := &http.Client{Transport: tp}
   527  		resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
   528  	} else {
   529  		resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
   530  	}
   531  	if err != nil {
   532  		t.Fatal(err)
   533  	}
   534  	d, err = ioutil.ReadAll(resp.Body)
   535  	if err != nil {
   536  		t.Fatal(err)
   537  	}
   538  	took2 := time.Since(now)
   539  	t.Logf("took %v with latency %v±%v", took2, lat, rv)
   540  
   541  	rs2 := string(d)
   542  	if rs2 != exp {
   543  		t.Fatalf("got %q, expected %q", rs2, exp)
   544  	}
   545  	if took1 > took2 {
   546  		t.Fatalf("expected took1 %v < took2 %v", took1, took2)
   547  	}
   548  }
   549  
   550  func newUnixAddr() string {
   551  	now := time.Now().UnixNano()
   552  	rand.Seed(now)
   553  	addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000))
   554  	os.RemoveAll(addr)
   555  	return addr
   556  }
   557  
   558  func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
   559  	var err error
   560  	if !tlsInfo.Empty() {
   561  		ln, err = transport.NewListener(addr, scheme, &tlsInfo)
   562  	} else {
   563  		ln, err = net.Listen(scheme, addr)
   564  	}
   565  	if err != nil {
   566  		t.Fatal(err)
   567  	}
   568  	return ln
   569  }
   570  
   571  func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) {
   572  	var out net.Conn
   573  	var err error
   574  	if !tlsInfo.Empty() {
   575  		tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
   576  		if terr != nil {
   577  			t.Fatal(terr)
   578  		}
   579  		out, err = tp.Dial(scheme, addr)
   580  	} else {
   581  		out, err = net.Dial(scheme, addr)
   582  	}
   583  	if err != nil {
   584  		t.Fatal(err)
   585  	}
   586  	if _, err = out.Write(data); err != nil {
   587  		t.Fatal(err)
   588  	}
   589  	if err = out.Close(); err != nil {
   590  		t.Fatal(err)
   591  	}
   592  }
   593  
   594  func receive(t *testing.T, ln net.Listener) (data []byte) {
   595  	buf := bytes.NewBuffer(make([]byte, 0, 1024))
   596  	for {
   597  		in, err := ln.Accept()
   598  		if err != nil {
   599  			t.Fatal(err)
   600  		}
   601  		var n int64
   602  		n, err = buf.ReadFrom(in)
   603  		if err != nil {
   604  			t.Fatal(err)
   605  		}
   606  		if n > 0 {
   607  			break
   608  		}
   609  	}
   610  	return buf.Bytes()
   611  }