go.etcd.io/etcd@v3.3.27+incompatible/pkg/transport/listener_test.go (about)

     1  // Copyright 2015 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"errors"
    21  	"io/ioutil"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
    30  	return createSelfCertEx("127.0.0.1")
    31  }
    32  
    33  func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
    34  	d, terr := ioutil.TempDir("", "etcd-test-tls-")
    35  	if terr != nil {
    36  		return nil, nil, terr
    37  	}
    38  	info, err := SelfCert(d, []string{host + ":0"}, additionalUsages...)
    39  	if err != nil {
    40  		return nil, nil, err
    41  	}
    42  	return &info, func() { os.RemoveAll(d) }, nil
    43  }
    44  
    45  func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
    46  	return func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
    47  		return cert, err
    48  	}
    49  }
    50  
    51  // TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
    52  // a TLS listener that accepts TLS connections.
    53  func TestNewListenerTLSInfo(t *testing.T) {
    54  	tlsInfo, del, err := createSelfCert()
    55  	if err != nil {
    56  		t.Fatalf("unable to create cert: %v", err)
    57  	}
    58  	defer del()
    59  	testNewListenerTLSInfoAccept(t, *tlsInfo)
    60  }
    61  
    62  func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
    63  	ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
    64  	if err != nil {
    65  		t.Fatalf("unexpected NewListener error: %v", err)
    66  	}
    67  	defer ln.Close()
    68  
    69  	tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
    70  	cli := &http.Client{Transport: tr}
    71  	go cli.Get("https://" + ln.Addr().String())
    72  
    73  	conn, err := ln.Accept()
    74  	if err != nil {
    75  		t.Fatalf("unexpected Accept error: %v", err)
    76  	}
    77  	defer conn.Close()
    78  	if _, ok := conn.(*tls.Conn); !ok {
    79  		t.Error("failed to accept *tls.Conn")
    80  	}
    81  }
    82  
    83  // TestNewListenerTLSInfoSkipClientSANVerify tests that if client IP address mismatches
    84  // with specified address in its certificate the connection is still accepted
    85  // if the flag SkipClientSANVerify is set (i.e. checkSAN() is disabled for the client side)
    86  func TestNewListenerTLSInfoSkipClientSANVerify(t *testing.T) {
    87  	tests := []struct {
    88  		skipClientSANVerify bool
    89  		goodClientHost      bool
    90  		acceptExpected      bool
    91  	}{
    92  		{false, true, true},
    93  		{false, false, false},
    94  		{true, true, true},
    95  		{true, false, true},
    96  	}
    97  	for _, test := range tests {
    98  		testNewListenerTLSInfoClientCheck(t, test.skipClientSANVerify, test.goodClientHost, test.acceptExpected)
    99  	}
   100  }
   101  
   102  func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientSANVerify, goodClientHost, acceptExpected bool) {
   103  	tlsInfo, del, err := createSelfCert()
   104  	if err != nil {
   105  		t.Fatalf("unable to create cert: %v", err)
   106  	}
   107  	defer del()
   108  
   109  	host := "127.0.0.222"
   110  	if goodClientHost {
   111  		host = "127.0.0.1"
   112  	}
   113  	clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
   114  	if err != nil {
   115  		t.Fatalf("unable to create cert: %v", err)
   116  	}
   117  	defer del2()
   118  
   119  	tlsInfo.SkipClientSANVerify = skipClientSANVerify
   120  	tlsInfo.CAFile = clientTLSInfo.CertFile
   121  
   122  	rootCAs := x509.NewCertPool()
   123  	loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
   124  	if err != nil {
   125  		t.Fatalf("unexpected missing certfile: %v", err)
   126  	}
   127  	rootCAs.AppendCertsFromPEM(loaded)
   128  
   129  	clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
   130  	if err != nil {
   131  		t.Fatalf("unable to create peer cert: %v", err)
   132  	}
   133  
   134  	tlsConfig := &tls.Config{}
   135  	tlsConfig.InsecureSkipVerify = false
   136  	tlsConfig.Certificates = []tls.Certificate{clientCert}
   137  	tlsConfig.RootCAs = rootCAs
   138  
   139  	ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
   140  	if err != nil {
   141  		t.Fatalf("unexpected NewListener error: %v", err)
   142  	}
   143  	defer ln.Close()
   144  
   145  	tr := &http.Transport{TLSClientConfig: tlsConfig}
   146  	cli := &http.Client{Transport: tr}
   147  	chClientErr := make(chan error)
   148  	go func() {
   149  		_, err := cli.Get("https://" + ln.Addr().String())
   150  		chClientErr <- err
   151  	}()
   152  
   153  	chAcceptErr := make(chan error)
   154  	chAcceptConn := make(chan net.Conn)
   155  	go func() {
   156  		conn, err := ln.Accept()
   157  		if err != nil {
   158  			chAcceptErr <- err
   159  		} else {
   160  			chAcceptConn <- conn
   161  		}
   162  	}()
   163  
   164  	select {
   165  	case <-chClientErr:
   166  		if acceptExpected {
   167  			t.Errorf("accepted for good client address: skipClientSANVerify=%v, goodClientHost=%v", skipClientSANVerify, goodClientHost)
   168  		}
   169  	case acceptErr := <-chAcceptErr:
   170  		t.Fatalf("unexpected Accept error: %v", acceptErr)
   171  	case conn := <-chAcceptConn:
   172  		defer conn.Close()
   173  		if _, ok := conn.(*tls.Conn); !ok {
   174  			t.Errorf("failed to accept *tls.Conn")
   175  		}
   176  		if !acceptExpected {
   177  			t.Errorf("accepted for bad client address: skipClientSANVerify=%v, goodClientHost=%v", skipClientSANVerify, goodClientHost)
   178  		}
   179  	}
   180  }
   181  func TestNewListenerTLSEmptyInfo(t *testing.T) {
   182  	_, err := NewListener("127.0.0.1:0", "https", nil)
   183  	if err == nil {
   184  		t.Errorf("err = nil, want not presented error")
   185  	}
   186  }
   187  
   188  func TestNewTransportTLSInfo(t *testing.T) {
   189  	tlsinfo, del, err := createSelfCert()
   190  	if err != nil {
   191  		t.Fatalf("unable to create cert: %v", err)
   192  	}
   193  	defer del()
   194  
   195  	tests := []TLSInfo{
   196  		{},
   197  		{
   198  			CertFile: tlsinfo.CertFile,
   199  			KeyFile:  tlsinfo.KeyFile,
   200  		},
   201  		{
   202  			CertFile: tlsinfo.CertFile,
   203  			KeyFile:  tlsinfo.KeyFile,
   204  			CAFile:   tlsinfo.CAFile,
   205  		},
   206  		{
   207  			CAFile: tlsinfo.CAFile,
   208  		},
   209  	}
   210  
   211  	for i, tt := range tests {
   212  		tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
   213  		trans, err := NewTransport(tt, time.Second)
   214  		if err != nil {
   215  			t.Fatalf("Received unexpected error from NewTransport: %v", err)
   216  		}
   217  
   218  		if trans.TLSClientConfig == nil {
   219  			t.Fatalf("#%d: want non-nil TLSClientConfig", i)
   220  		}
   221  	}
   222  }
   223  
   224  func TestTLSInfoNonexist(t *testing.T) {
   225  	tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
   226  	_, err := tlsInfo.ServerConfig()
   227  	werr := &os.PathError{
   228  		Op:   "open",
   229  		Path: "@badname",
   230  		Err:  errors.New("no such file or directory"),
   231  	}
   232  	if err.Error() != werr.Error() {
   233  		t.Errorf("err = %v, want %v", err, werr)
   234  	}
   235  }
   236  
   237  func TestTLSInfoEmpty(t *testing.T) {
   238  	tests := []struct {
   239  		info TLSInfo
   240  		want bool
   241  	}{
   242  		{TLSInfo{}, true},
   243  		{TLSInfo{CAFile: "baz"}, true},
   244  		{TLSInfo{CertFile: "foo"}, false},
   245  		{TLSInfo{KeyFile: "bar"}, false},
   246  		{TLSInfo{CertFile: "foo", KeyFile: "bar"}, false},
   247  		{TLSInfo{CertFile: "foo", CAFile: "baz"}, false},
   248  		{TLSInfo{KeyFile: "bar", CAFile: "baz"}, false},
   249  		{TLSInfo{CertFile: "foo", KeyFile: "bar", CAFile: "baz"}, false},
   250  	}
   251  
   252  	for i, tt := range tests {
   253  		got := tt.info.Empty()
   254  		if tt.want != got {
   255  			t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got)
   256  		}
   257  	}
   258  }
   259  
   260  func TestTLSInfoMissingFields(t *testing.T) {
   261  	tlsinfo, del, err := createSelfCert()
   262  	if err != nil {
   263  		t.Fatalf("unable to create cert: %v", err)
   264  	}
   265  	defer del()
   266  
   267  	tests := []TLSInfo{
   268  		{CertFile: tlsinfo.CertFile},
   269  		{KeyFile: tlsinfo.KeyFile},
   270  		{CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile},
   271  		{KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile},
   272  	}
   273  
   274  	for i, info := range tests {
   275  		if _, err = info.ServerConfig(); err == nil {
   276  			t.Errorf("#%d: expected non-nil error from ServerConfig()", i)
   277  		}
   278  
   279  		if _, err = info.ClientConfig(); err == nil {
   280  			t.Errorf("#%d: expected non-nil error from ClientConfig()", i)
   281  		}
   282  	}
   283  }
   284  
   285  func TestTLSInfoParseFuncError(t *testing.T) {
   286  	tlsinfo, del, err := createSelfCert()
   287  	if err != nil {
   288  		t.Fatalf("unable to create cert: %v", err)
   289  	}
   290  	defer del()
   291  
   292  	tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
   293  
   294  	if _, err = tlsinfo.ServerConfig(); err == nil {
   295  		t.Errorf("expected non-nil error from ServerConfig()")
   296  	}
   297  
   298  	if _, err = tlsinfo.ClientConfig(); err == nil {
   299  		t.Errorf("expected non-nil error from ClientConfig()")
   300  	}
   301  }
   302  
   303  func TestTLSInfoConfigFuncs(t *testing.T) {
   304  	tlsinfo, del, err := createSelfCert()
   305  	if err != nil {
   306  		t.Fatalf("unable to create cert: %v", err)
   307  	}
   308  	defer del()
   309  
   310  	tests := []struct {
   311  		info       TLSInfo
   312  		clientAuth tls.ClientAuthType
   313  		wantCAs    bool
   314  	}{
   315  		{
   316  			info:       TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile},
   317  			clientAuth: tls.NoClientCert,
   318  			wantCAs:    false,
   319  		},
   320  
   321  		{
   322  			info:       TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile},
   323  			clientAuth: tls.RequireAndVerifyClientCert,
   324  			wantCAs:    true,
   325  		},
   326  	}
   327  
   328  	for i, tt := range tests {
   329  		tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
   330  
   331  		sCfg, err := tt.info.ServerConfig()
   332  		if err != nil {
   333  			t.Errorf("#%d: expected nil error from ServerConfig(), got non-nil: %v", i, err)
   334  		}
   335  
   336  		if tt.wantCAs != (sCfg.ClientCAs != nil) {
   337  			t.Errorf("#%d: wantCAs=%t but ClientCAs=%v", i, tt.wantCAs, sCfg.ClientCAs)
   338  		}
   339  
   340  		cCfg, err := tt.info.ClientConfig()
   341  		if err != nil {
   342  			t.Errorf("#%d: expected nil error from ClientConfig(), got non-nil: %v", i, err)
   343  		}
   344  
   345  		if tt.wantCAs != (cCfg.RootCAs != nil) {
   346  			t.Errorf("#%d: wantCAs=%t but RootCAs=%v", i, tt.wantCAs, sCfg.RootCAs)
   347  		}
   348  	}
   349  }
   350  
   351  func TestNewListenerUnixSocket(t *testing.T) {
   352  	l, err := NewListener("testsocket", "unix", nil)
   353  	if err != nil {
   354  		t.Errorf("error listening on unix socket (%v)", err)
   355  	}
   356  	l.Close()
   357  }
   358  
   359  // TestNewListenerTLSInfoSelfCert tests that a new certificate accepts connections.
   360  func TestNewListenerTLSInfoSelfCert(t *testing.T) {
   361  	tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir")
   362  	if err != nil {
   363  		t.Fatal(err)
   364  	}
   365  	defer os.RemoveAll(tmpdir)
   366  	tlsinfo, err := SelfCert(tmpdir, []string{"127.0.0.1"})
   367  	if err != nil {
   368  		t.Fatal(err)
   369  	}
   370  	if tlsinfo.Empty() {
   371  		t.Fatalf("tlsinfo should have certs (%+v)", tlsinfo)
   372  	}
   373  	testNewListenerTLSInfoAccept(t, tlsinfo)
   374  }
   375  
   376  func TestIsClosedConnError(t *testing.T) {
   377  	l, err := NewListener("testsocket", "unix", nil)
   378  	if err != nil {
   379  		t.Errorf("error listening on unix socket (%v)", err)
   380  	}
   381  	l.Close()
   382  	_, err = l.Accept()
   383  	if !IsClosedConnError(err) {
   384  		t.Fatalf("expect true, got false (%v)", err)
   385  	}
   386  }