github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/warp/tls.go (about)

     1  package warp
     2  
import (
	"errors"
	"fmt"
	"io"
	"math/big"
	"math/rand"
	"net"
	"net/netip"
	"sync"
	"time"

	tls "github.com/refraction-networking/utls"
)
    15  
    16  // Dialer is a struct that holds various options for custom dialing.
    17  type Dialer struct {
    18  }
    19  
    20  const (
    21  	extensionServerName   uint16 = 0x0
    22  	utlsExtensionSNICurve uint16 = 0x15
    23  )
    24  
    25  func hostnameInSNI(name string) string {
    26  	return name
    27  }
    28  
    29  // SNIExtension implements server_name (0)
    30  type SNIExtension struct {
    31  	*tls.GenericExtension
    32  	ServerName string // not an array because go crypto/tls doesn't support multiple SNIs
    33  }
    34  
    35  // Len returns the length of the SNIExtension.
    36  func (e *SNIExtension) Len() int {
    37  	// Literal IP addresses, absolute FQDNs, and empty strings are not permitted as SNI values.
    38  	// See RFC 6066, Section 3.
    39  	hostName := hostnameInSNI(e.ServerName)
    40  	if len(hostName) == 0 {
    41  		return 0
    42  	}
    43  	return 4 + 2 + 1 + 2 + len(hostName)
    44  }
    45  
    46  // Read reads the SNIExtension.
    47  func (e *SNIExtension) Read(b []byte) (int, error) {
    48  	// Literal IP addresses, absolute FQDNs, and empty strings are not permitted as SNI values.
    49  	// See RFC 6066, Section 3.
    50  	hostName := hostnameInSNI(e.ServerName)
    51  	if len(hostName) == 0 {
    52  		return 0, io.EOF
    53  	}
    54  	if len(b) < e.Len() {
    55  		return 0, io.ErrShortBuffer
    56  	}
    57  	// RFC 3546, section 3.1
    58  	b[0] = byte(extensionServerName >> 8)
    59  	b[1] = byte(extensionServerName)
    60  	b[2] = byte((len(hostName) + 5) >> 8)
    61  	b[3] = byte(len(hostName) + 5)
    62  	b[4] = byte((len(hostName) + 3) >> 8)
    63  	b[5] = byte(len(hostName) + 3)
    64  	// b[6] Server Name Type: host_name (0)
    65  	b[7] = byte(len(hostName) >> 8)
    66  	b[8] = byte(len(hostName))
    67  	copy(b[9:], hostName)
    68  	return e.Len(), io.EOF
    69  }
    70  
    71  // SNICurveExtension implements SNICurve (0x15) extension
    72  type SNICurveExtension struct {
    73  	*tls.GenericExtension
    74  	SNICurveLen int
    75  	WillPad     bool // set false to disable extension
    76  }
    77  
    78  // Len returns the length of the SNICurveExtension.
    79  func (e *SNICurveExtension) Len() int {
    80  	if e.WillPad {
    81  		return 4 + e.SNICurveLen
    82  	}
    83  	return 0
    84  }
    85  
    86  // Read reads the SNICurveExtension.
    87  func (e *SNICurveExtension) Read(b []byte) (n int, err error) {
    88  	if !e.WillPad {
    89  		return 0, io.EOF
    90  	}
    91  	if len(b) < e.Len() {
    92  		return 0, io.ErrShortBuffer
    93  	}
    94  	// https://tools.ietf.org/html/rfc7627
    95  	b[0] = byte(utlsExtensionSNICurve >> 8)
    96  	b[1] = byte(utlsExtensionSNICurve)
    97  	b[2] = byte(e.SNICurveLen >> 8)
    98  	b[3] = byte(e.SNICurveLen)
    99  	y := make([]byte, 1200)
   100  	copy(b[4:], y)
   101  	return e.Len(), io.EOF
   102  }
   103  
   104  // makeTLSHelloPacketWithSNICurve creates a TLS hello packet with SNICurve.
   105  func (d *Dialer) makeTLSHelloPacketWithSNICurve(plainConn net.Conn, config *tls.Config, sni string) (*tls.UConn, error) {
   106  	SNICurveSize := 1200
   107  
   108  	utlsConn := tls.UClient(plainConn, config, tls.HelloCustom)
   109  	spec := tls.ClientHelloSpec{
   110  		TLSVersMax: tls.VersionTLS12,
   111  		TLSVersMin: tls.VersionTLS12,
   112  		CipherSuites: []uint16{
   113  			tls.GREASE_PLACEHOLDER,
   114  			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
   115  			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
   116  			tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
   117  			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
   118  			tls.TLS_AES_128_GCM_SHA256, // tls 1.3
   119  			tls.FAKE_TLS_DHE_RSA_WITH_AES_256_CBC_SHA,
   120  			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   121  			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
   122  		},
   123  		Extensions: []tls.TLSExtension{
   124  			&SNICurveExtension{
   125  				SNICurveLen: SNICurveSize,
   126  				WillPad:     true,
   127  			},
   128  			&tls.SupportedCurvesExtension{Curves: []tls.CurveID{tls.X25519, tls.CurveP256}},
   129  			&tls.SupportedPointsExtension{SupportedPoints: []byte{0}}, // uncompressed
   130  			&tls.SessionTicketExtension{},
   131  			&tls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
   132  			&tls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []tls.SignatureScheme{
   133  				tls.ECDSAWithP256AndSHA256,
   134  				tls.ECDSAWithP384AndSHA384,
   135  				tls.ECDSAWithP521AndSHA512,
   136  				tls.PSSWithSHA256,
   137  				tls.PSSWithSHA384,
   138  				tls.PSSWithSHA512,
   139  				tls.PKCS1WithSHA256,
   140  				tls.PKCS1WithSHA384,
   141  				tls.PKCS1WithSHA512,
   142  				tls.ECDSAWithSHA1,
   143  				tls.PKCS1WithSHA1}},
   144  			&tls.KeyShareExtension{KeyShares: []tls.KeyShare{
   145  				{Group: tls.CurveID(tls.GREASE_PLACEHOLDER), Data: []byte{0}},
   146  				{Group: tls.X25519},
   147  			}},
   148  			&tls.PSKKeyExchangeModesExtension{Modes: []uint8{1}}, // pskModeDHE
   149  			&SNIExtension{
   150  				ServerName: sni,
   151  			},
   152  		},
   153  		GetSessionID: nil,
   154  	}
   155  	err := utlsConn.ApplyPreset(&spec)
   156  
   157  	if err != nil {
   158  		return nil, fmt.Errorf("uTlsConn.Handshake() error: %+v", err)
   159  	}
   160  
   161  	err = utlsConn.Handshake()
   162  
   163  	if err != nil {
   164  		return nil, fmt.Errorf("uTlsConn.Handshake() error: %+v", err)
   165  	}
   166  
   167  	return utlsConn, nil
   168  }
   169  
   170  // RandomIPFromPrefix returns a random IP from the provided CIDR prefix.
   171  // Supports IPv4 and IPv6. Does not support mapped inputs.
   172  func RandomIPFromPrefix(cidr netip.Prefix) (netip.Addr, error) {
   173  	startingAddress := cidr.Masked().Addr()
   174  	if startingAddress.Is4In6() {
   175  		return netip.Addr{}, errors.New("mapped v4 addresses not supported")
   176  	}
   177  
   178  	prefixLen := cidr.Bits()
   179  	if prefixLen == -1 {
   180  		return netip.Addr{}, fmt.Errorf("invalid cidr: %s", cidr)
   181  	}
   182  
   183  	// Initialise rand number generator
   184  	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
   185  
   186  	// Find the bit length of the Host portion of the provided CIDR
   187  	// prefix
   188  	hostLen := big.NewInt(int64(startingAddress.BitLen() - prefixLen))
   189  
   190  	// Find the max value for our random number
   191  	max := new(big.Int).Exp(big.NewInt(2), hostLen, nil)
   192  
   193  	// Generate the random number
   194  	randInt := new(big.Int).Rand(rng, max)
   195  
   196  	// Get the first address in the CIDR prefix in 16-bytes form
   197  	startingAddress16 := startingAddress.As16()
   198  
   199  	// Convert the first address into a decimal number
   200  	startingAddressInt := new(big.Int).SetBytes(startingAddress16[:])
   201  
   202  	// Add the random number to the decimal form of the starting address
   203  	// to get a random address in the desired range
   204  	randomAddressInt := new(big.Int).Add(startingAddressInt, randInt)
   205  
   206  	// Convert the random address from decimal form back into netip.Addr
   207  	randomAddress, ok := netip.AddrFromSlice(randomAddressInt.FillBytes(make([]byte, 16)))
   208  	if !ok {
   209  		return netip.Addr{}, fmt.Errorf("failed to generate random IP from CIDR: %s", cidr)
   210  	}
   211  
   212  	// Unmap any mapped v4 addresses before return
   213  	return randomAddress.Unmap(), nil
   214  }
   215  
   216  // TLSDial dials a TLS connection.
   217  func (d *Dialer) TLSDial(plainDialer *net.Dialer, network, addr string) (net.Conn, error) {
   218  	sni, _, err := net.SplitHostPort(addr)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	ip, err := RandomIPFromPrefix(netip.MustParsePrefix("141.101.113.0/24"))
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	plainConn, err := plainDialer.Dial(network, ip.String()+":443")
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	config := tls.Config{
   232  		ServerName:         sni,
   233  		InsecureSkipVerify: true,
   234  		NextProtos:         nil,
   235  		MinVersion:         tls.VersionTLS10,
   236  	}
   237  
   238  	utlsConn, handshakeErr := d.makeTLSHelloPacketWithSNICurve(plainConn, &config, sni)
   239  	if handshakeErr != nil {
   240  		_ = plainConn.Close()
   241  		fmt.Println(handshakeErr)
   242  		return nil, handshakeErr
   243  	}
   244  	return utlsConn, nil
   245  }