github.com/cloudwego/hertz@v0.9.3/pkg/network/standard/dial_test.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package standard
    18  
    19  import (
    20  	"context"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"crypto/x509/pkix"
    26  	"encoding/pem"
    27  	"math/big"
    28  	"net"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/cloudwego/hertz/pkg/common/config"
    33  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    34  )
    35  
    36  func TestDial(t *testing.T) {
    37  	const nw = "tcp"
    38  	const addr = "localhost:10104"
    39  	transporter := NewTransporter(&config.Options{
    40  		Addr:    addr,
    41  		Network: nw,
    42  	})
    43  
    44  	go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error {
    45  		return nil
    46  	})
    47  	defer transporter.Close()
    48  	time.Sleep(time.Millisecond * 100)
    49  
    50  	dial := NewDialer()
    51  	_, err := dial.DialConnection(nw, addr, time.Second, nil)
    52  	assert.Nil(t, err)
    53  
    54  	nConn, err := dial.DialTimeout(nw, addr, time.Second, nil)
    55  	assert.Nil(t, err)
    56  	defer nConn.Close()
    57  }
    58  
    59  func TestDialTLS(t *testing.T) {
    60  	const nw = "tcp"
    61  	const addr = "localhost:10105"
    62  	data := []byte("abcdefg")
    63  	listened := make(chan struct{})
    64  	go func() {
    65  		mockTLSServe(nw, addr, func(conn net.Conn) {
    66  			defer conn.Close()
    67  			_, err := conn.Write(data)
    68  			assert.Nil(t, err)
    69  		}, listened)
    70  	}()
    71  
    72  	select {
    73  	case <-listened:
    74  	case <-time.After(time.Second * 5):
    75  		t.Fatalf("timeout")
    76  	}
    77  
    78  	buf := make([]byte, len(data))
    79  
    80  	dial := NewDialer()
    81  	conn, err := dial.DialConnection(nw, addr, time.Second, &tls.Config{
    82  		InsecureSkipVerify: true,
    83  	})
    84  	assert.Nil(t, err)
    85  
    86  	_, err = conn.Read(buf)
    87  	assert.Nil(t, err)
    88  	assert.DeepEqual(t, string(data), string(buf))
    89  
    90  	conn, err = dial.DialConnection(nw, addr, time.Second, nil)
    91  	assert.Nil(t, err)
    92  	nConn, err := dial.AddTLS(conn, &tls.Config{
    93  		InsecureSkipVerify: true,
    94  	})
    95  	assert.Nil(t, err)
    96  
    97  	_, err = nConn.Read(buf)
    98  	assert.Nil(t, err)
    99  	assert.DeepEqual(t, string(data), string(buf))
   100  }
   101  
   102  func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan struct{}) (err error) {
   103  	certData, keyData, err := generateTestCertificate("")
   104  	if err != nil {
   105  		return
   106  	}
   107  
   108  	cert, err := tls.X509KeyPair(certData, keyData)
   109  	if err != nil {
   110  		return
   111  	}
   112  
   113  	tlsConfig := &tls.Config{
   114  		Certificates: []tls.Certificate{cert},
   115  	}
   116  	ln, err := tls.Listen(nw, addr, tlsConfig)
   117  	if err != nil {
   118  		return
   119  	}
   120  	defer ln.Close()
   121  
   122  	listened <- struct{}{}
   123  	for {
   124  		conn, err := ln.Accept()
   125  		if err != nil {
   126  			continue
   127  		}
   128  		go handle(conn)
   129  	}
   130  }
   131  
   132  // generateTestCertificate generates a test certificate and private key based on the given host.
   133  func generateTestCertificate(host string) ([]byte, []byte, error) {
   134  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
   135  	if err != nil {
   136  		return nil, nil, err
   137  	}
   138  
   139  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   140  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   141  	if err != nil {
   142  		return nil, nil, err
   143  	}
   144  
   145  	cert := &x509.Certificate{
   146  		SerialNumber: serialNumber,
   147  		Subject: pkix.Name{
   148  			Organization: []string{"fasthttp test"},
   149  		},
   150  		NotBefore:             time.Now(),
   151  		NotAfter:              time.Now().Add(365 * 24 * time.Hour),
   152  		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
   153  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
   154  		SignatureAlgorithm:    x509.SHA256WithRSA,
   155  		DNSNames:              []string{host},
   156  		BasicConstraintsValid: true,
   157  		IsCA:                  true,
   158  	}
   159  
   160  	certBytes, err := x509.CreateCertificate(
   161  		rand.Reader, cert, cert, &priv.PublicKey, priv,
   162  	)
   163  
   164  	p := pem.EncodeToMemory(
   165  		&pem.Block{
   166  			Type:  "PRIVATE KEY",
   167  			Bytes: x509.MarshalPKCS1PrivateKey(priv),
   168  		},
   169  	)
   170  
   171  	b := pem.EncodeToMemory(
   172  		&pem.Block{
   173  			Type:  "CERTIFICATE",
   174  			Bytes: certBytes,
   175  		},
   176  	)
   177  
   178  	return b, p, err
   179  }