github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/sniffy/tls/tls.go (about)

     1  package tls
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"golang.org/x/crypto/cryptobyte"
    13  )
    14  
    15  type Conn struct {
    16  	addr     string
    17  	helloMsg bool
    18  	conn     net.Conn
    19  	ctx      context.Context
    20  	cancel   func()
    21  
    22  	tls       bool
    23  	dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
    24  	tlsDialer func(ctx context.Context, network, addr string) (net.Conn, error)
    25  }
    26  
    27  func (c *Conn) Read(b []byte) (int, error) {
    28  	<-c.ctx.Done()
    29  
    30  	if c.conn == nil {
    31  		return 0, nil
    32  	}
    33  
    34  	return c.conn.Read(b)
    35  }
    36  
    37  func (c *Conn) Write(b []byte) (int, error) {
    38  	if !c.helloMsg {
    39  		c.tls = check(b)
    40  		c.helloMsg = true
    41  	}
    42  
    43  	if c.conn == nil {
    44  		var err error
    45  		if c.tls {
    46  			c.conn, err = c.tlsDialer(c.ctx, "tcp", c.addr)
    47  		} else {
    48  			c.conn, err = c.dialer(c.ctx, "tcp", c.addr)
    49  		}
    50  		if err != nil {
    51  			return 0, err
    52  		}
    53  		c.cancel()
    54  	}
    55  
    56  	return c.conn.Write(b)
    57  }
    58  
    59  func (c *Conn) Close() error {
    60  	c.cancel()
    61  	if c.conn == nil {
    62  		return nil
    63  	}
    64  
    65  	return c.conn.Close()
    66  }
    67  
    68  func (c *Conn) LocalAddr() net.Addr {
    69  	if c.conn == nil {
    70  		return &net.TCPAddr{
    71  			IP:   net.IPv4zero,
    72  			Port: 0,
    73  		}
    74  	}
    75  
    76  	return c.conn.LocalAddr()
    77  }
    78  
    79  func (c *Conn) RemoteAddr() net.Addr {
    80  	if c.conn == nil {
    81  		return &net.TCPAddr{
    82  			IP:   net.IPv4zero,
    83  			Port: 0,
    84  		}
    85  	}
    86  
    87  	return c.conn.RemoteAddr()
    88  }
    89  
    90  func (c *Conn) SetDeadline(t time.Time) error {
    91  	if c.conn == nil {
    92  		return nil
    93  	}
    94  	return c.conn.SetDeadline(t)
    95  }
    96  
    97  func (c *Conn) SetReadDeadline(t time.Time) error {
    98  	if c.conn == nil {
    99  		return nil
   100  	}
   101  	return c.conn.SetReadDeadline(t)
   102  }
   103  
   104  func (c *Conn) SetWriteDeadline(t time.Time) error {
   105  	if c.conn == nil {
   106  		return nil
   107  	}
   108  
   109  	return c.conn.SetWriteDeadline(t)
   110  }
   111  
   112  func check(buf []byte) bool {
   113  	n := len(buf)
   114  
   115  	if n <= 5 {
   116  		return false
   117  	}
   118  
   119  	// tls record type
   120  	if recordType(buf[0]) != recordTypeHandshake {
   121  		return false
   122  	}
   123  
   124  	// tls major version
   125  	// fmt.Println("tls version", buf[1])
   126  	// if buf[1] != 3 {
   127  	// 	log.Println("TLS version < 3 not supported")
   128  	// 	return false
   129  	// }
   130  
   131  	// payload length
   132  	//l := int(buf[3])<<16 + int(buf[4])
   133  
   134  	//log.Printf("length: %d, got: %d", l, n)
   135  
   136  	// handshake message type
   137  	if uint8(buf[5]) != typeClientHello {
   138  		return false
   139  	}
   140  
   141  	msg := &clientHelloMsg{}
   142  
   143  	// client hello message not include tls header, 5 bytes
   144  	ret := msg.unmarshal(buf[5:n])
   145  	if !ret {
   146  		return false
   147  	}
   148  
   149  	fmt.Println("server name", msg.serverName)
   150  
   151  	return true
   152  }
   153  
   154  func Get(str string) error {
   155  	hc := http.Client{
   156  		Transport: &http.Transport{
   157  			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   158  				ctx, cancel := context.WithCancel(ctx)
   159  				return &Conn{
   160  					addr:      addr,
   161  					ctx:       ctx,
   162  					cancel:    cancel,
   163  					dialer:    (&net.Dialer{}).DialContext,
   164  					tlsDialer: (&net.Dialer{}).DialContext,
   165  				}, nil
   166  			},
   167  		},
   168  	}
   169  
   170  	resp, err := hc.Get(str)
   171  	if err != nil {
   172  		return err
   173  	}
   174  	defer resp.Body.Close()
   175  
   176  	data, _ := io.ReadAll(resp.Body)
   177  	fmt.Println("data length", len(data))
   178  	return nil
   179  }
   180  
   181  // Copyright 2009 The Go Authors. All rights reserved.
   182  // Use of this source code is governed by a BSD-style
   183  // license that can be found in the LICENSE file.
   184  
// The declarations below are adapted from $GOROOT/src/crypto/tls.
   186  
   187  // TLS record types.
   188  type recordType uint8
   189  
   190  const (
   191  	recordTypeChangeCipherSpec recordType = 20
   192  	recordTypeAlert            recordType = 21
   193  	recordTypeHandshake        recordType = 22
   194  	recordTypeApplicationData  recordType = 23
   195  )
   196  
   197  // TLS handshake message types.
   198  const (
   199  	typeHelloRequest        uint8 = 0
   200  	typeClientHello         uint8 = 1
   201  	typeServerHello         uint8 = 2
   202  	typeNewSessionTicket    uint8 = 4
   203  	typeEndOfEarlyData      uint8 = 5
   204  	typeEncryptedExtensions uint8 = 8
   205  	typeCertificate         uint8 = 11
   206  	typeServerKeyExchange   uint8 = 12
   207  	typeCertificateRequest  uint8 = 13
   208  	typeServerHelloDone     uint8 = 14
   209  	typeCertificateVerify   uint8 = 15
   210  	typeClientKeyExchange   uint8 = 16
   211  	typeFinished            uint8 = 20
   212  	typeCertificateStatus   uint8 = 22
   213  	typeKeyUpdate           uint8 = 24
   214  	typeNextProtocol        uint8 = 67  // Not IANA assigned
   215  	typeMessageHash         uint8 = 254 // synthetic message
   216  )
   217  
   218  // TLS extension numbers
   219  const (
   220  	extensionServerName              uint16 = 0
   221  	extensionStatusRequest           uint16 = 5
   222  	extensionSupportedCurves         uint16 = 10 // supported_groups in TLS 1.3, see RFC 8446, Section 4.2.7
   223  	extensionSupportedPoints         uint16 = 11
   224  	extensionSignatureAlgorithms     uint16 = 13
   225  	extensionALPN                    uint16 = 16
   226  	extensionSCT                     uint16 = 18
   227  	extensionSessionTicket           uint16 = 35
   228  	extensionPreSharedKey            uint16 = 41
   229  	extensionEarlyData               uint16 = 42
   230  	extensionSupportedVersions       uint16 = 43
   231  	extensionCookie                  uint16 = 44
   232  	extensionPSKModes                uint16 = 45
   233  	extensionCertificateAuthorities  uint16 = 47
   234  	extensionSignatureAlgorithmsCert uint16 = 50
   235  	extensionKeyShare                uint16 = 51
   236  	extensionRenegotiationInfo       uint16 = 0xff01
   237  )
   238  
   239  // readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
   240  // []byte instead of a cryptobyte.String.
   241  func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
   242  	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
   243  }
   244  
   245  type clientHelloMsg struct {
   246  	vers               uint16
   247  	random             []byte
   248  	sessionId          []byte
   249  	compressionMethods []uint8
   250  	serverName         string
   251  }
   252  
   253  func (m *clientHelloMsg) unmarshal(data []byte) bool {
   254  	*m = clientHelloMsg{}
   255  	s := cryptobyte.String(data)
   256  
   257  	if !s.Skip(4) || // message type and uint24 length field
   258  		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
   259  		!readUint8LengthPrefixed(&s, &m.sessionId) {
   260  		return false
   261  	}
   262  
   263  	var cipherSuites cryptobyte.String
   264  	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
   265  		return false
   266  	}
   267  	for !cipherSuites.Empty() {
   268  		var suite uint16
   269  		if !cipherSuites.ReadUint16(&suite) {
   270  			return false
   271  		}
   272  	}
   273  
   274  	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
   275  		return false
   276  	}
   277  
   278  	if s.Empty() {
   279  		// ClientHello is optionally followed by extension data
   280  		return true
   281  	}
   282  
   283  	var extensions cryptobyte.String
   284  	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
   285  		return false
   286  	}
   287  
   288  	seenExts := make(map[uint16]bool)
   289  	for !extensions.Empty() {
   290  		var extension uint16
   291  		var extData cryptobyte.String
   292  		if !extensions.ReadUint16(&extension) ||
   293  			!extensions.ReadUint16LengthPrefixed(&extData) {
   294  			return false
   295  		}
   296  
   297  		if seenExts[extension] {
   298  			return false
   299  		}
   300  		seenExts[extension] = true
   301  
   302  		switch extension {
   303  		case extensionServerName:
   304  			// RFC 6066, Section 3
   305  			var nameList cryptobyte.String
   306  			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
   307  				return false
   308  			}
   309  			for !nameList.Empty() {
   310  				var nameType uint8
   311  				var serverName cryptobyte.String
   312  				if !nameList.ReadUint8(&nameType) ||
   313  					!nameList.ReadUint16LengthPrefixed(&serverName) ||
   314  					serverName.Empty() {
   315  					return false
   316  				}
   317  				if nameType != 0 {
   318  					continue
   319  				}
   320  				if len(m.serverName) != 0 {
   321  					// Multiple names of the same name_type are prohibited.
   322  					return false
   323  				}
   324  				m.serverName = string(serverName)
   325  				// An SNI value may not include a trailing dot.
   326  				if strings.HasSuffix(m.serverName, ".") {
   327  					return false
   328  				}
   329  			}
   330  		default:
   331  			// Ignore unknown extensions.
   332  			continue
   333  		}
   334  
   335  		if !extData.Empty() {
   336  			return false
   337  		}
   338  	}
   339  
   340  	return true
   341  }