github.com/yaling888/clash@v1.53.0/transport/trojan/trojan.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"net/http"
    14  	"sync"
    15  
    16  	"github.com/yaling888/clash/common/pool"
    17  	C "github.com/yaling888/clash/constant"
    18  	"github.com/yaling888/clash/transport/h2"
    19  	"github.com/yaling888/clash/transport/socks5"
    20  	"github.com/yaling888/clash/transport/vmess"
    21  )
    22  
    23  const (
    24  	// max packet length
    25  	maxLength = 8192
    26  )
    27  
    28  var (
    29  	defaultALPN          = []string{"h2", "http/1.1"}
    30  	defaultWebsocketALPN = []string{"http/1.1"}
    31  
    32  	crlf = []byte{'\r', '\n'}
    33  )
    34  
    35  type Command = byte
    36  
    37  const (
    38  	CommandTCP byte = 1
    39  	CommandUDP byte = 3
    40  )
    41  
// Option holds the settings a Trojan client applies to every stream it
// creates.
type Option struct {
	// Password is the plain-text secret; New stores its hex-encoded SHA-224.
	Password       string
	// ALPN, when non-empty, replaces the default ALPN list in StreamConn and
	// StreamWebsocketConn. StreamH2Conn always offers "h2" only.
	ALPN           []string
	// ServerName is copied into tls.Config.ServerName.
	ServerName     string
	// SkipCertVerify is copied into tls.Config.InsecureSkipVerify.
	SkipCertVerify bool
}
    48  
// HTTPOptions carries the HTTP/2 transport settings consumed by
// StreamH2Conn.
//
// Only Hosts, Path and Headers are forwarded to h2.Config; Host and Port are
// not read anywhere in this file.
type HTTPOptions struct {
	Host    string
	Port    int
	Hosts   []string
	Path    string
	Headers http.Header
}
    56  
// WebsocketOption carries the WebSocket transport settings consumed by
// StreamWebsocketConn. Every field is forwarded unchanged to the matching
// field of vmess.WebsocketConfig.
type WebsocketOption struct {
	Host    string
	Port    string
	Path    string
	Headers http.Header
}
    63  
// Trojan is a Trojan protocol client. Create one with New.
type Trojan struct {
	// option is the configuration passed to New.
	option      *Option
	// hexPassword is the hex-encoded SHA-224 digest of option.Password
	// (56 bytes), written verbatim at the start of every request header.
	hexPassword []byte
}
    68  
    69  func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
    70  	alpn := defaultALPN
    71  	if len(t.option.ALPN) != 0 {
    72  		alpn = t.option.ALPN
    73  	}
    74  
    75  	tlsConfig := &tls.Config{
    76  		NextProtos:         alpn,
    77  		MinVersion:         tls.VersionTLS12,
    78  		InsecureSkipVerify: t.option.SkipCertVerify,
    79  		ServerName:         t.option.ServerName,
    80  	}
    81  
    82  	tlsConn := tls.Client(conn, tlsConfig)
    83  
    84  	// fix tls handshake not timeout
    85  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    86  	defer cancel()
    87  	if err := tlsConn.HandshakeContext(ctx); err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	return tlsConn, nil
    92  }
    93  
    94  func (t *Trojan) StreamH2Conn(conn net.Conn, h2Option *HTTPOptions) (net.Conn, error) {
    95  	tlsConfig := &tls.Config{
    96  		NextProtos:         []string{"h2"},
    97  		MinVersion:         tls.VersionTLS12,
    98  		InsecureSkipVerify: t.option.SkipCertVerify,
    99  		ServerName:         t.option.ServerName,
   100  	}
   101  
   102  	tlsConn := tls.Client(conn, tlsConfig)
   103  
   104  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   105  	defer cancel()
   106  	if err := tlsConn.HandshakeContext(ctx); err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return h2.StreamH2Conn(tlsConn, &h2.Config{
   111  		Hosts:   h2Option.Hosts,
   112  		Path:    h2Option.Path,
   113  		Headers: h2Option.Headers,
   114  	})
   115  }
   116  
   117  func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
   118  	alpn := defaultWebsocketALPN
   119  	if len(t.option.ALPN) != 0 {
   120  		alpn = t.option.ALPN
   121  	}
   122  
   123  	tlsConfig := &tls.Config{
   124  		NextProtos:         alpn,
   125  		MinVersion:         tls.VersionTLS12,
   126  		InsecureSkipVerify: t.option.SkipCertVerify,
   127  		ServerName:         t.option.ServerName,
   128  	}
   129  
   130  	return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{
   131  		Host:      wsOptions.Host,
   132  		Port:      wsOptions.Port,
   133  		Path:      wsOptions.Path,
   134  		Headers:   wsOptions.Headers,
   135  		TLS:       true,
   136  		TLSConfig: tlsConfig,
   137  	})
   138  }
   139  
   140  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   141  	buf := pool.BufferWriter{}
   142  
   143  	buf.PutSlice(t.hexPassword)
   144  	buf.PutSlice(crlf)
   145  
   146  	buf.PutUint8(command)
   147  	buf.PutSlice(socks5Addr)
   148  	buf.PutSlice(crlf)
   149  
   150  	_, err := w.Write(buf.Bytes())
   151  	return err
   152  }
   153  
   154  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   155  	return &PacketConn{
   156  		Conn: conn,
   157  	}
   158  }
   159  
   160  func writePacket(w io.Writer, socks5Addr, payload []byte) (n int, err error) {
   161  	bufP := pool.GetNetBuf()
   162  	defer pool.PutNetBuf(bufP)
   163  
   164  	n = len(payload)
   165  	t := copy(*bufP, socks5Addr)
   166  	binary.BigEndian.PutUint16((*bufP)[t:], uint16(n))
   167  	t += 2
   168  	t += copy((*bufP)[t:], crlf)
   169  	t += copy((*bufP)[t:], payload)
   170  
   171  	delta := t - n
   172  	n, err = w.Write((*bufP)[:t])
   173  	if n < t && err == nil {
   174  		err = io.ErrShortWrite
   175  	}
   176  	n = max(n-delta, 0)
   177  	return
   178  }
   179  
   180  func WritePacket(w io.Writer, socks5Addr, payload []byte) (n int, err error) {
   181  	total := len(payload)
   182  	if total <= maxLength {
   183  		return writePacket(w, socks5Addr, payload)
   184  	}
   185  
   186  	offset := 0
   187  	cursor := 0
   188  	for {
   189  		cursor = min(offset+maxLength, total)
   190  
   191  		n, err = writePacket(w, socks5Addr, payload[offset:cursor])
   192  
   193  		offset = min(offset+n, total)
   194  		if err != nil || offset == total {
   195  			n = offset
   196  			return
   197  		}
   198  	}
   199  }
   200  
   201  func ReadPacket(r io.Reader, payload []byte) (addr *net.UDPAddr, n int, remain int, err error) {
   202  	var socAddr socks5.Addr
   203  	socAddr, err = socks5.ReadAddr(r, payload)
   204  	if err != nil {
   205  		if err != io.EOF {
   206  			err = fmt.Errorf("read addr error, %w", err)
   207  		}
   208  		return
   209  	}
   210  	addr = socAddr.UDPAddr()
   211  	if addr == nil {
   212  		err = errors.New("parse addr error")
   213  		return
   214  	}
   215  
   216  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   217  		if err != io.EOF {
   218  			err = fmt.Errorf("read length error, %w", err)
   219  		}
   220  		return
   221  	}
   222  
   223  	total := int(binary.BigEndian.Uint16(payload[:2]))
   224  	if total > maxLength {
   225  		err = errors.New("invalid packet")
   226  		return
   227  	}
   228  
   229  	// read crlf
   230  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   231  		if err != io.EOF {
   232  			err = fmt.Errorf("read crlf error, %w", err)
   233  		}
   234  		return
   235  	}
   236  
   237  	length := min(len(payload), total)
   238  	if length, err = io.ReadFull(r, payload[:length]); err != nil && err != io.EOF {
   239  		err = fmt.Errorf("read packet error, %w", err)
   240  	}
   241  
   242  	return addr, length, total - length, err
   243  }
   244  
   245  func New(option *Option) *Trojan {
   246  	return &Trojan{option, hexSha224([]byte(option.Password))}
   247  }
   248  
// PacketConn exposes a Trojan stream through the ReadFrom and WriteTo
// methods of net.PacketConn; all other methods come from the embedded
// net.Conn.
//
// When the buffer passed to ReadFrom is smaller than the incoming frame, the
// rest of that frame is handed out by subsequent ReadFrom calls; remain and
// rAddr carry that state between calls.
type PacketConn struct {
	net.Conn
	// remain is the number of payload bytes of the current frame that
	// ReadFrom has not returned yet.
	remain int
	// rAddr is the source address of the partially delivered frame; it is
	// reset to nil once the frame has been fully read.
	rAddr  net.Addr
	// mux serializes ReadFrom calls so remain and rAddr stay consistent.
	mux    sync.Mutex
}
   255  
   256  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   257  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   258  }
   259  
   260  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   261  	pc.mux.Lock()
   262  	defer pc.mux.Unlock()
   263  	if pc.remain != 0 {
   264  		length := min(len(b), pc.remain)
   265  
   266  		n, err := pc.Conn.Read(b[:length])
   267  
   268  		pc.remain -= n
   269  		addr := pc.rAddr
   270  		if pc.remain == 0 {
   271  			pc.rAddr = nil
   272  		}
   273  
   274  		return n, addr, err
   275  	}
   276  
   277  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   278  	if err == nil && remain > 0 {
   279  		pc.remain = remain
   280  		pc.rAddr = addr
   281  	}
   282  
   283  	return n, addr, err
   284  }
   285  
   286  func hexSha224(data []byte) []byte {
   287  	buf := make([]byte, 56)
   288  	hash := sha256.New224()
   289  	hash.Write(data)
   290  	hex.Encode(buf, hash.Sum(nil))
   291  	return buf
   292  }