github.com/metacubex/mihomo@v1.18.5/transport/shadowtls/shadowtls.go (about)

     1  package shadowtls
     2  
     3  import (
     4  	"context"
     5  	"crypto/hmac"
     6  	"crypto/sha1"
     7  	"crypto/tls"
     8  	"encoding/binary"
     9  	"fmt"
    10  	"hash"
    11  	"io"
    12  	"net"
    13  
    14  	"github.com/metacubex/mihomo/common/pool"
    15  	C "github.com/metacubex/mihomo/constant"
    16  )
    17  
    18  const (
    19  	chunkSize           = 1 << 13
    20  	Mode         string = "shadow-tls"
    21  	hashLen      int    = 8
    22  	tlsHeaderLen int    = 5
    23  )
    24  
    25  var (
    26  	DefaultALPN = []string{"h2", "http/1.1"}
    27  )
    28  
    29  // ShadowTLS is shadow-tls implementation
    30  type ShadowTLS struct {
    31  	net.Conn
    32  	password     []byte
    33  	remain       int
    34  	firstRequest bool
    35  	tlsConfig    *tls.Config
    36  }
    37  
    38  type HashedConn struct {
    39  	net.Conn
    40  	hasher hash.Hash
    41  }
    42  
    43  func newHashedStream(conn net.Conn, password []byte) HashedConn {
    44  	return HashedConn{
    45  		Conn:   conn,
    46  		hasher: hmac.New(sha1.New, password),
    47  	}
    48  }
    49  
    50  func (h HashedConn) Read(b []byte) (n int, err error) {
    51  	n, err = h.Conn.Read(b)
    52  	h.hasher.Write(b[:n])
    53  	return
    54  }
    55  
    56  func (s *ShadowTLS) read(b []byte) (int, error) {
    57  	var buf [tlsHeaderLen]byte
    58  	_, err := io.ReadFull(s.Conn, buf[:])
    59  	if err != nil {
    60  		return 0, fmt.Errorf("shadowtls read failed %w", err)
    61  	}
    62  	if buf[0] != 0x17 || buf[1] != 0x3 || buf[2] != 0x3 {
    63  		return 0, fmt.Errorf("invalid shadowtls header %v", buf)
    64  	}
    65  	length := int(binary.BigEndian.Uint16(buf[3:]))
    66  
    67  	if length > len(b) {
    68  		n, err := s.Conn.Read(b)
    69  		if err != nil {
    70  			return n, err
    71  		}
    72  		s.remain = length - n
    73  		return n, nil
    74  	}
    75  
    76  	return io.ReadFull(s.Conn, b[:length])
    77  }
    78  
    79  func (s *ShadowTLS) Read(b []byte) (int, error) {
    80  	if s.remain > 0 {
    81  		length := s.remain
    82  		if length > len(b) {
    83  			length = len(b)
    84  		}
    85  
    86  		n, err := io.ReadFull(s.Conn, b[:length])
    87  		if err != nil {
    88  			return n, fmt.Errorf("shadowtls Read failed with %w", err)
    89  		}
    90  		s.remain -= n
    91  		return n, nil
    92  	}
    93  
    94  	return s.read(b)
    95  }
    96  
    97  func (s *ShadowTLS) Write(b []byte) (int, error) {
    98  	length := len(b)
    99  	for i := 0; i < length; i += chunkSize {
   100  		end := i + chunkSize
   101  		if end > length {
   102  			end = length
   103  		}
   104  
   105  		n, err := s.write(b[i:end])
   106  		if err != nil {
   107  			return n, fmt.Errorf("shadowtls Write failed with %w, i=%d, end=%d, n=%d", err, i, end, n)
   108  		}
   109  	}
   110  	return length, nil
   111  }
   112  
   113  func (s *ShadowTLS) write(b []byte) (int, error) {
   114  	var hashVal []byte
   115  	if s.firstRequest {
   116  		hashedConn := newHashedStream(s.Conn, s.password)
   117  		tlsConn := tls.Client(hashedConn, s.tlsConfig)
   118  		// fix tls handshake not timeout
   119  		ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   120  		defer cancel()
   121  		if err := tlsConn.HandshakeContext(ctx); err != nil {
   122  			return 0, fmt.Errorf("tls connect failed with %w", err)
   123  		}
   124  		hashVal = hashedConn.hasher.Sum(nil)[:hashLen]
   125  		s.firstRequest = false
   126  	}
   127  
   128  	buf := pool.GetBuffer()
   129  	defer pool.PutBuffer(buf)
   130  	buf.Write([]byte{0x17, 0x03, 0x03})
   131  	binary.Write(buf, binary.BigEndian, uint16(len(b)+len(hashVal)))
   132  	buf.Write(hashVal)
   133  	buf.Write(b)
   134  	_, err := s.Conn.Write(buf.Bytes())
   135  	if err != nil {
   136  		// return 0 because errors occur here make the
   137  		// whole situation irrecoverable
   138  		return 0, err
   139  	}
   140  	return len(b), nil
   141  }
   142  
   143  // NewShadowTLS return a ShadowTLS
   144  func NewShadowTLS(conn net.Conn, password string, tlsConfig *tls.Config) net.Conn {
   145  	return &ShadowTLS{
   146  		Conn:         conn,
   147  		password:     []byte(password),
   148  		firstRequest: true,
   149  		tlsConfig:    tlsConfig,
   150  	}
   151  }