github.com/cloudwego/hertz@v0.9.3/pkg/network/standard/dial_test.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package standard 18 19 import ( 20 "context" 21 "crypto/rand" 22 "crypto/rsa" 23 "crypto/tls" 24 "crypto/x509" 25 "crypto/x509/pkix" 26 "encoding/pem" 27 "math/big" 28 "net" 29 "testing" 30 "time" 31 32 "github.com/cloudwego/hertz/pkg/common/config" 33 "github.com/cloudwego/hertz/pkg/common/test/assert" 34 ) 35 36 func TestDial(t *testing.T) { 37 const nw = "tcp" 38 const addr = "localhost:10104" 39 transporter := NewTransporter(&config.Options{ 40 Addr: addr, 41 Network: nw, 42 }) 43 44 go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { 45 return nil 46 }) 47 defer transporter.Close() 48 time.Sleep(time.Millisecond * 100) 49 50 dial := NewDialer() 51 _, err := dial.DialConnection(nw, addr, time.Second, nil) 52 assert.Nil(t, err) 53 54 nConn, err := dial.DialTimeout(nw, addr, time.Second, nil) 55 assert.Nil(t, err) 56 defer nConn.Close() 57 } 58 59 func TestDialTLS(t *testing.T) { 60 const nw = "tcp" 61 const addr = "localhost:10105" 62 data := []byte("abcdefg") 63 listened := make(chan struct{}) 64 go func() { 65 mockTLSServe(nw, addr, func(conn net.Conn) { 66 defer conn.Close() 67 _, err := conn.Write(data) 68 assert.Nil(t, err) 69 }, listened) 70 }() 71 72 select { 73 case <-listened: 74 case <-time.After(time.Second * 5): 75 
t.Fatalf("timeout") 76 } 77 78 buf := make([]byte, len(data)) 79 80 dial := NewDialer() 81 conn, err := dial.DialConnection(nw, addr, time.Second, &tls.Config{ 82 InsecureSkipVerify: true, 83 }) 84 assert.Nil(t, err) 85 86 _, err = conn.Read(buf) 87 assert.Nil(t, err) 88 assert.DeepEqual(t, string(data), string(buf)) 89 90 conn, err = dial.DialConnection(nw, addr, time.Second, nil) 91 assert.Nil(t, err) 92 nConn, err := dial.AddTLS(conn, &tls.Config{ 93 InsecureSkipVerify: true, 94 }) 95 assert.Nil(t, err) 96 97 _, err = nConn.Read(buf) 98 assert.Nil(t, err) 99 assert.DeepEqual(t, string(data), string(buf)) 100 } 101 102 func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan struct{}) (err error) { 103 certData, keyData, err := generateTestCertificate("") 104 if err != nil { 105 return 106 } 107 108 cert, err := tls.X509KeyPair(certData, keyData) 109 if err != nil { 110 return 111 } 112 113 tlsConfig := &tls.Config{ 114 Certificates: []tls.Certificate{cert}, 115 } 116 ln, err := tls.Listen(nw, addr, tlsConfig) 117 if err != nil { 118 return 119 } 120 defer ln.Close() 121 122 listened <- struct{}{} 123 for { 124 conn, err := ln.Accept() 125 if err != nil { 126 continue 127 } 128 go handle(conn) 129 } 130 } 131 132 // generateTestCertificate generates a test certificate and private key based on the given host. 
// The certificate is self-signed with a freshly generated 2048-bit RSA key and
// is valid for 365 days from the time of the call. host is added as a DNS
// subject alternative name unless it is empty.
//
// The first result is the PEM-encoded certificate and the second is the
// PEM-encoded PKCS #1 private key. If key generation, serial number
// generation, or certificate creation fails, both byte slices are nil and the
// error is returned.
func generateTestCertificate(host string) ([]byte, []byte, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}

	// Pick a random serial number below 2^128.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return nil, nil, err
	}

	now := time.Now()
	cert := &x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"fasthttp test"},
		},
		NotBefore:             now,
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		SignatureAlgorithm:    x509.SHA256WithRSA,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	// Only add a SAN when a host was given; an empty host would otherwise
	// produce an empty dNSName entry.
	if host != "" {
		cert.DNSNames = []string{host}
	}

	// The template doubles as the parent, which makes the certificate
	// self-signed.
	certBytes, err := x509.CreateCertificate(
		rand.Reader, cert, cert, &priv.PublicKey, priv,
	)
	if err != nil {
		// Do not PEM-encode anything when certificate creation failed.
		return nil, nil, err
	}

	// MarshalPKCS1PrivateKey emits PKCS #1 DER, whose PEM label is
	// "RSA PRIVATE KEY"; plain "PRIVATE KEY" denotes PKCS #8.
	keyPEM := pem.EncodeToMemory(
		&pem.Block{
			Type:  "RSA PRIVATE KEY",
			Bytes: x509.MarshalPKCS1PrivateKey(priv),
		},
	)

	certPEM := pem.EncodeToMemory(
		&pem.Block{
			Type:  "CERTIFICATE",
			Bytes: certBytes,
		},
	)

	return certPEM, keyPEM, nil
}