gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/gmtls/conn_test.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  gmtls是基于`golang/go`的`tls`包实现的国密改造版本。
    11  对应版权声明: thrid_licenses/github.com/golang/go/LICENSE
    12  */
    13  
    14  package gmtls
    15  
    16  import (
    17  	"bytes"
    18  	"io"
    19  	"net"
    20  	"testing"
    21  )
    22  
    23  func TestRoundUp(t *testing.T) {
    24  	if roundUp(0, 16) != 0 ||
    25  		roundUp(1, 16) != 16 ||
    26  		roundUp(15, 16) != 16 ||
    27  		roundUp(16, 16) != 16 ||
    28  		roundUp(17, 16) != 32 {
    29  		t.Error("roundUp broken")
    30  	}
    31  }
    32  
    33  // will be initialized with {0, 255, 255, ..., 255}
    34  var padding255Bad = [256]byte{}
    35  
    36  // will be initialized with {255, 255, 255, ..., 255}
    37  var padding255Good = [256]byte{255}
    38  
    39  var paddingTests = []struct {
    40  	in          []byte
    41  	good        bool
    42  	expectedLen int
    43  }{
    44  	{[]byte{1, 2, 3, 4, 0}, true, 4},
    45  	{[]byte{1, 2, 3, 4, 0, 1}, false, 0},
    46  	{[]byte{1, 2, 3, 4, 99, 99}, false, 0},
    47  	{[]byte{1, 2, 3, 4, 1, 1}, true, 4},
    48  	{[]byte{1, 2, 3, 2, 2, 2}, true, 3},
    49  	{[]byte{1, 2, 3, 3, 3, 3}, true, 2},
    50  	{[]byte{1, 2, 3, 4, 3, 3}, false, 0},
    51  	{[]byte{1, 4, 4, 4, 4, 4}, true, 1},
    52  	{[]byte{5, 5, 5, 5, 5, 5}, true, 0},
    53  	{[]byte{6, 6, 6, 6, 6, 6}, false, 0},
    54  	{padding255Bad[:], false, 0},
    55  	{padding255Good[:], true, 0},
    56  }
    57  
    58  func TestRemovePadding(t *testing.T) {
    59  	for i := 1; i < len(padding255Bad); i++ {
    60  		padding255Bad[i] = 255
    61  		padding255Good[i] = 255
    62  	}
    63  	for i, test := range paddingTests {
    64  		paddingLen, good := extractPadding(test.in)
    65  		expectedGood := byte(255)
    66  		if !test.good {
    67  			expectedGood = 0
    68  		}
    69  		if good != expectedGood {
    70  			t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good)
    71  		}
    72  		if good == 255 && len(test.in)-paddingLen != test.expectedLen {
    73  			t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen)
    74  		}
    75  	}
    76  }
    77  
    78  var certExampleCom = `308201713082011ba003020102021005a75ddf21014d5f417083b7a010ba2e300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343135335a170d3137303831373231343135335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b37f0fdd67e715bf532046ac34acbd8fdc4dabe2b598588f3f58b1f12e6219a16cbfe54d2b4b665396013589262360b6721efa27d546854f17cc9aeec6751db10203010001a34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300d06092a864886f70d01010b050003410059fc487866d3d855503c8e064ca32aac5e9babcece89ec597f8b2b24c17867f4a5d3b4ece06e795bfc5448ccbd2ffca1b3433171ebf3557a4737b020565350a0`
    79  
    80  var certWildcardExampleCom = `308201743082011ea003020102021100a7aa6297c9416a4633af8bec2958c607300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343231395a170d3137303831373231343231395a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b105afc859a711ee864114e7d2d46c2dcbe392d3506249f6c2285b0eb342cc4bf2d803677c61c0abde443f084745c1a6d62080e5664ef2cc8f50ad8a0ab8870b0203010001a34f304d300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030180603551d110411300f820d2a2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100af26088584d266e3f6566360cf862c7fecc441484b098b107439543144a2b93f20781988281e108c6d7656934e56950e1e5f2bcf38796b814ccb729445856c34`
    81  
    82  var certFooExampleCom = `308201753082011fa00302010202101bbdb6070b0aeffc49008cde74deef29300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343234345a170d3137303831373231343234345a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100f00ac69d8ca2829f26216c7b50f1d4bbabad58d447706476cd89a2f3e1859943748aa42c15eedc93ac7c49e40d3b05ed645cb6b81c4efba60d961f44211a54eb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f666f6f2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100a0957fca6d1e0f1ef4b247348c7a8ca092c29c9c0ecc1898ea6b8065d23af6d922a410dd2335a0ea15edd1394cef9f62c9e876a21e35250a0b4fe1ddceba0f36`
    83  
    84  func TestCertificateSelection(t *testing.T) {
    85  	config := Config{
    86  		Certificates: []Certificate{
    87  			{
    88  				Certificate: [][]byte{fromHex(certExampleCom)},
    89  			},
    90  			{
    91  				Certificate: [][]byte{fromHex(certWildcardExampleCom)},
    92  			},
    93  			{
    94  				Certificate: [][]byte{fromHex(certFooExampleCom)},
    95  			},
    96  		},
    97  	}
    98  
    99  	config.BuildNameToCertificate()
   100  
   101  	pointerToIndex := func(c *Certificate) int {
   102  		for i := range config.Certificates {
   103  			if c == &config.Certificates[i] {
   104  				return i
   105  			}
   106  		}
   107  		return -1
   108  	}
   109  
   110  	certificateForName := func(name string) *Certificate {
   111  		clientHello := &ClientHelloInfo{
   112  			ServerName: name,
   113  		}
   114  		if cert, err := config.getCertificate(clientHello); err != nil {
   115  			t.Errorf("unable to get certificate for name '%s': %s", name, err)
   116  			return nil
   117  		} else {
   118  			return cert
   119  		}
   120  	}
   121  
   122  	if n := pointerToIndex(certificateForName("example.com")); n != 0 {
   123  		t.Errorf("example.com returned certificate %d, not 0", n)
   124  	}
   125  	if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
   126  		t.Errorf("bar.example.com returned certificate %d, not 1", n)
   127  	}
   128  	if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
   129  		t.Errorf("foo.example.com returned certificate %d, not 2", n)
   130  	}
   131  	if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 0 {
   132  		t.Errorf("foo.bar.example.com returned certificate %d, not 0", n)
   133  	}
   134  }
   135  
   136  // Run with multiple crypto configs to test the logic for computing TLS record overheads.
   137  func runDynamicRecordSizingTest(t *testing.T, config *Config) {
   138  	clientConn, serverConn := localPipe(t)
   139  
   140  	serverConfig := config.Clone()
   141  	serverConfig.DynamicRecordSizingDisabled = false
   142  	tlsConn := Server(serverConn, serverConfig)
   143  
   144  	handshakeDone := make(chan struct{})
   145  	recordSizesChan := make(chan []int, 1)
   146  	defer func() { <-recordSizesChan }() // wait for the goroutine to exit
   147  	go func() {
   148  		// This goroutine performs a TLS handshake over clientConn and
   149  		// then reads TLS records until EOF. It writes a slice that
   150  		// contains all the record sizes to recordSizesChan.
   151  		defer close(recordSizesChan)
   152  		defer func(clientConn net.Conn) {
   153  			err := clientConn.Close()
   154  			if err != nil {
   155  				panic(err)
   156  			}
   157  		}(clientConn)
   158  
   159  		tlsConn := Client(clientConn, config)
   160  		if err := tlsConn.Handshake(); err != nil {
   161  			t.Errorf("Error from client handshake: %v", err)
   162  			return
   163  		}
   164  		close(handshakeDone)
   165  
   166  		var recordHeader [recordHeaderLen]byte
   167  		var record []byte
   168  		var recordSizes []int
   169  
   170  		for {
   171  			n, err := io.ReadFull(clientConn, recordHeader[:])
   172  			if err == io.EOF {
   173  				break
   174  			}
   175  			if err != nil || n != len(recordHeader) {
   176  				t.Errorf("io.ReadFull = %d, %v", n, err)
   177  				return
   178  			}
   179  
   180  			length := int(recordHeader[3])<<8 | int(recordHeader[4])
   181  			if len(record) < length {
   182  				record = make([]byte, length)
   183  			}
   184  
   185  			n, err = io.ReadFull(clientConn, record[:length])
   186  			if err != nil || n != length {
   187  				t.Errorf("io.ReadFull = %d, %v", n, err)
   188  				return
   189  			}
   190  
   191  			recordSizes = append(recordSizes, recordHeaderLen+length)
   192  		}
   193  
   194  		recordSizesChan <- recordSizes
   195  	}()
   196  
   197  	if err := tlsConn.Handshake(); err != nil {
   198  		t.Fatalf("Error from server handshake: %s", err)
   199  	}
   200  	<-handshakeDone
   201  
   202  	// The server writes these plaintexts in order.
   203  	plaintext := bytes.Join([][]byte{
   204  		bytes.Repeat([]byte("x"), recordSizeBoostThreshold),
   205  		bytes.Repeat([]byte("y"), maxPlaintext*2),
   206  		bytes.Repeat([]byte("z"), maxPlaintext),
   207  	}, nil)
   208  
   209  	if _, err := tlsConn.Write(plaintext); err != nil {
   210  		t.Fatalf("Error from server write: %s", err)
   211  	}
   212  	if err := tlsConn.Close(); err != nil {
   213  		t.Fatalf("Error from server close: %s", err)
   214  	}
   215  
   216  	recordSizes := <-recordSizesChan
   217  	if recordSizes == nil {
   218  		t.Fatalf("Client encountered an error")
   219  	}
   220  
   221  	// Drop the size of the second to last record, which is likely to be
   222  	// truncated, and the last record, which is a close_notify alert.
   223  	recordSizes = recordSizes[:len(recordSizes)-2]
   224  
   225  	// recordSizes should contain a series of records smaller than
   226  	// tcpMSSEstimate followed by some larger than maxPlaintext.
   227  	seenLargeRecord := false
   228  	for i, size := range recordSizes {
   229  		if !seenLargeRecord {
   230  			if size > (i+1)*tcpMSSEstimate {
   231  				t.Fatalf("Record #%d has size %d, which is too large too soon", i, size)
   232  			}
   233  			if size >= maxPlaintext {
   234  				seenLargeRecord = true
   235  			}
   236  		} else if size <= maxPlaintext {
   237  			t.Fatalf("Record #%d has size %d but should be full sized", i, size)
   238  		}
   239  	}
   240  
   241  	if !seenLargeRecord {
   242  		t.Fatalf("No large records observed")
   243  	}
   244  }
   245  
   246  func TestDynamicRecordSizingWithStreamCipher(t *testing.T) {
   247  	config := testConfig.Clone()
   248  	config.MaxVersion = VersionTLS12
   249  	config.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA}
   250  	runDynamicRecordSizingTest(t, config)
   251  }
   252  
   253  func TestDynamicRecordSizingWithCBC(t *testing.T) {
   254  	config := testConfig.Clone()
   255  	config.MaxVersion = VersionTLS12
   256  	config.CipherSuites = []uint16{TLS_RSA_WITH_AES_256_CBC_SHA}
   257  	runDynamicRecordSizingTest(t, config)
   258  }
   259  
   260  func TestDynamicRecordSizingWithAEAD(t *testing.T) {
   261  	config := testConfig.Clone()
   262  	config.MaxVersion = VersionTLS12
   263  	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
   264  	runDynamicRecordSizingTest(t, config)
   265  }
   266  
   267  func TestDynamicRecordSizingWithTLSv13(t *testing.T) {
   268  	config := testConfig.Clone()
   269  	runDynamicRecordSizingTest(t, config)
   270  }
   271  
   272  // hairpinConn is a net.Conn that makes a “hairpin” call when closed, back into
   273  // the tls.Conn which is calling it.
   274  type hairpinConn struct {
   275  	net.Conn
   276  	tlsConn *Conn
   277  }
   278  
   279  func (conn *hairpinConn) Close() error {
   280  	conn.tlsConn.ConnectionState()
   281  	return nil
   282  }
   283  
   284  func TestHairpinInClose(t *testing.T) {
   285  	// This tests that the underlying net.Conn can call back into the
   286  	// tls.Conn when being closed without deadlocking.
   287  	client, server := localPipe(t)
   288  	defer func(server net.Conn) {
   289  		err := server.Close()
   290  		if err != nil {
   291  			panic(err)
   292  		}
   293  	}(server)
   294  	defer func(client net.Conn) {
   295  		err := client.Close()
   296  		if err != nil {
   297  			panic(err)
   298  		}
   299  	}(client)
   300  
   301  	conn := &hairpinConn{client, nil}
   302  	tlsConn := Server(conn, &Config{
   303  		GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
   304  			panic("unreachable")
   305  		},
   306  	})
   307  	conn.tlsConn = tlsConn
   308  
   309  	// This call should not deadlock.
   310  	err := tlsConn.Close()
   311  	if err != nil {
   312  		t.Fatal(err)
   313  	}
   314  }