github.com/igoogolx/clash@v1.19.8/transport/trojan/trojan.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"sync"
    14  
    15  	C "github.com/igoogolx/clash/constant"
    16  	"github.com/igoogolx/clash/transport/socks5"
    17  	"github.com/igoogolx/clash/transport/vmess"
    18  
    19  	"github.com/Dreamacro/protobytes"
    20  )
    21  
    22  const (
    23  	// max packet length
    24  	maxLength = 8192
    25  )
    26  
    27  var (
    28  	defaultALPN          = []string{"h2", "http/1.1"}
    29  	defaultWebsocketALPN = []string{"http/1.1"}
    30  
    31  	crlf = []byte{'\r', '\n'}
    32  )
    33  
    34  type Command = byte
    35  
    36  var (
    37  	CommandTCP byte = 1
    38  	CommandUDP byte = 3
    39  )
    40  
    41  type Option struct {
    42  	Password       string
    43  	ALPN           []string
    44  	ServerName     string
    45  	SkipCertVerify bool
    46  }
    47  
    48  type WebsocketOption struct {
    49  	Host    string
    50  	Port    string
    51  	Path    string
    52  	Headers http.Header
    53  }
    54  
    55  type Trojan struct {
    56  	option      *Option
    57  	hexPassword []byte
    58  }
    59  
    60  func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
    61  	alpn := defaultALPN
    62  	if len(t.option.ALPN) != 0 {
    63  		alpn = t.option.ALPN
    64  	}
    65  
    66  	tlsConfig := &tls.Config{
    67  		NextProtos:         alpn,
    68  		MinVersion:         tls.VersionTLS12,
    69  		InsecureSkipVerify: t.option.SkipCertVerify,
    70  		ServerName:         t.option.ServerName,
    71  	}
    72  
    73  	tlsConn := tls.Client(conn, tlsConfig)
    74  
    75  	// fix tls handshake not timeout
    76  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    77  	defer cancel()
    78  	if err := tlsConn.HandshakeContext(ctx); err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	return tlsConn, nil
    83  }
    84  
    85  func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
    86  	alpn := defaultWebsocketALPN
    87  	if len(t.option.ALPN) != 0 {
    88  		alpn = t.option.ALPN
    89  	}
    90  
    91  	tlsConfig := &tls.Config{
    92  		NextProtos:         alpn,
    93  		MinVersion:         tls.VersionTLS12,
    94  		InsecureSkipVerify: t.option.SkipCertVerify,
    95  		ServerName:         t.option.ServerName,
    96  	}
    97  
    98  	return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{
    99  		Host:      wsOptions.Host,
   100  		Port:      wsOptions.Port,
   101  		Path:      wsOptions.Path,
   102  		Headers:   wsOptions.Headers,
   103  		TLS:       true,
   104  		TLSConfig: tlsConfig,
   105  	})
   106  }
   107  
   108  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   109  	buf := protobytes.BytesWriter{}
   110  	buf.PutSlice(t.hexPassword)
   111  	buf.PutSlice(crlf)
   112  
   113  	buf.PutUint8(command)
   114  	buf.PutSlice(socks5Addr)
   115  	buf.PutSlice(crlf)
   116  
   117  	_, err := w.Write(buf.Bytes())
   118  	return err
   119  }
   120  
   121  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   122  	return &PacketConn{
   123  		Conn: conn,
   124  	}
   125  }
   126  
   127  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   128  	buf := protobytes.BytesWriter{}
   129  	buf.PutSlice(socks5Addr)
   130  	buf.PutUint16be(uint16(len(payload)))
   131  	buf.PutSlice(crlf)
   132  	buf.PutSlice(payload)
   133  	return w.Write(buf.Bytes())
   134  }
   135  
   136  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   137  	if len(payload) <= maxLength {
   138  		return writePacket(w, socks5Addr, payload)
   139  	}
   140  
   141  	offset := 0
   142  	total := len(payload)
   143  	for {
   144  		cursor := offset + maxLength
   145  		if cursor > total {
   146  			cursor = total
   147  		}
   148  
   149  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   150  		if err != nil {
   151  			return offset + n, err
   152  		}
   153  
   154  		offset = cursor
   155  		if offset == total {
   156  			break
   157  		}
   158  	}
   159  
   160  	return total, nil
   161  }
   162  
   163  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
   164  	addr, err := socks5.ReadAddr(r, payload)
   165  	if err != nil {
   166  		return nil, 0, 0, errors.New("read addr error")
   167  	}
   168  	uAddr := addr.UDPAddr()
   169  	if uAddr == nil {
   170  		return nil, 0, 0, errors.New("parse addr error")
   171  	}
   172  
   173  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   174  		return nil, 0, 0, errors.New("read length error")
   175  	}
   176  
   177  	total := int(binary.BigEndian.Uint16(payload[:2]))
   178  	if total > maxLength {
   179  		return nil, 0, 0, errors.New("packet invalid")
   180  	}
   181  
   182  	// read crlf
   183  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   184  		return nil, 0, 0, errors.New("read crlf error")
   185  	}
   186  
   187  	length := len(payload)
   188  	if total < length {
   189  		length = total
   190  	}
   191  
   192  	if _, err = io.ReadFull(r, payload[:length]); err != nil {
   193  		return nil, 0, 0, errors.New("read packet error")
   194  	}
   195  
   196  	return uAddr, length, total - length, nil
   197  }
   198  
   199  func New(option *Option) *Trojan {
   200  	return &Trojan{option, hexSha224([]byte(option.Password))}
   201  }
   202  
   203  type PacketConn struct {
   204  	net.Conn
   205  	remain int
   206  	rAddr  net.Addr
   207  	mux    sync.Mutex
   208  }
   209  
   210  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   211  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   212  }
   213  
   214  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   215  	pc.mux.Lock()
   216  	defer pc.mux.Unlock()
   217  	if pc.remain != 0 {
   218  		length := len(b)
   219  		if pc.remain < length {
   220  			length = pc.remain
   221  		}
   222  
   223  		n, err := pc.Conn.Read(b[:length])
   224  		if err != nil {
   225  			return 0, nil, err
   226  		}
   227  
   228  		pc.remain -= n
   229  		addr := pc.rAddr
   230  		if pc.remain == 0 {
   231  			pc.rAddr = nil
   232  		}
   233  
   234  		return n, addr, nil
   235  	}
   236  
   237  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   238  	if err != nil {
   239  		return 0, nil, err
   240  	}
   241  
   242  	if remain != 0 {
   243  		pc.remain = remain
   244  		pc.rAddr = addr
   245  	}
   246  
   247  	return n, addr, nil
   248  }
   249  
   250  func hexSha224(data []byte) []byte {
   251  	buf := make([]byte, 56)
   252  	hash := sha256.New224()
   253  	hash.Write(data)
   254  	hex.Encode(buf, hash.Sum(nil))
   255  	return buf
   256  }