github.com/chwjbn/xclash@v0.2.0/transport/trojan/trojan.go (about)

     1  package trojan
     2  
import (
	"crypto/sha256"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync"

	"github.com/chwjbn/xclash/common/pool"
	"github.com/chwjbn/xclash/transport/socks5"
	"github.com/chwjbn/xclash/transport/vmess"
)
    18  
    19  const (
    20  	// max packet length
    21  	maxLength = 8192
    22  )
    23  
    24  var (
    25  	defaultALPN          = []string{"h2", "http/1.1"}
    26  	defaultWebsocketALPN = []string{"http/1.1"}
    27  
    28  	crlf = []byte{'\r', '\n'}
    29  )
    30  
    31  type Command = byte
    32  
    33  var (
    34  	CommandTCP byte = 1
    35  	CommandUDP byte = 3
    36  )
    37  
    38  type Option struct {
    39  	Password       string
    40  	ALPN           []string
    41  	ServerName     string
    42  	SkipCertVerify bool
    43  }
    44  
    45  type WebsocketOption struct {
    46  	Host    string
    47  	Port    string
    48  	Path    string
    49  	Headers http.Header
    50  }
    51  
    52  type Trojan struct {
    53  	option      *Option
    54  	hexPassword []byte
    55  }
    56  
    57  func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
    58  	alpn := defaultALPN
    59  	if len(t.option.ALPN) != 0 {
    60  		alpn = t.option.ALPN
    61  	}
    62  
    63  	tlsConfig := &tls.Config{
    64  		NextProtos:         alpn,
    65  		MinVersion:         tls.VersionTLS12,
    66  		InsecureSkipVerify: t.option.SkipCertVerify,
    67  		ServerName:         t.option.ServerName,
    68  	}
    69  
    70  	tlsConn := tls.Client(conn, tlsConfig)
    71  	if err := tlsConn.Handshake(); err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	return tlsConn, nil
    76  }
    77  
    78  func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
    79  	alpn := defaultWebsocketALPN
    80  	if len(t.option.ALPN) != 0 {
    81  		alpn = t.option.ALPN
    82  	}
    83  
    84  	tlsConfig := &tls.Config{
    85  		NextProtos:         alpn,
    86  		MinVersion:         tls.VersionTLS12,
    87  		InsecureSkipVerify: t.option.SkipCertVerify,
    88  		ServerName:         t.option.ServerName,
    89  	}
    90  
    91  	return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{
    92  		Host:      wsOptions.Host,
    93  		Port:      wsOptions.Port,
    94  		Path:      wsOptions.Path,
    95  		Headers:   wsOptions.Headers,
    96  		TLS:       true,
    97  		TLSConfig: tlsConfig,
    98  	})
    99  }
   100  
   101  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   102  	buf := pool.GetBuffer()
   103  	defer pool.PutBuffer(buf)
   104  
   105  	buf.Write(t.hexPassword)
   106  	buf.Write(crlf)
   107  
   108  	buf.WriteByte(command)
   109  	buf.Write(socks5Addr)
   110  	buf.Write(crlf)
   111  
   112  	_, err := w.Write(buf.Bytes())
   113  	return err
   114  }
   115  
   116  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   117  	return &PacketConn{
   118  		Conn: conn,
   119  	}
   120  }
   121  
   122  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   123  	buf := pool.GetBuffer()
   124  	defer pool.PutBuffer(buf)
   125  
   126  	buf.Write(socks5Addr)
   127  	binary.Write(buf, binary.BigEndian, uint16(len(payload)))
   128  	buf.Write(crlf)
   129  	buf.Write(payload)
   130  
   131  	return w.Write(buf.Bytes())
   132  }
   133  
   134  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   135  	if len(payload) <= maxLength {
   136  		return writePacket(w, socks5Addr, payload)
   137  	}
   138  
   139  	offset := 0
   140  	total := len(payload)
   141  	for {
   142  		cursor := offset + maxLength
   143  		if cursor > total {
   144  			cursor = total
   145  		}
   146  
   147  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   148  		if err != nil {
   149  			return offset + n, err
   150  		}
   151  
   152  		offset = cursor
   153  		if offset == total {
   154  			break
   155  		}
   156  	}
   157  
   158  	return total, nil
   159  }
   160  
   161  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
   162  	addr, err := socks5.ReadAddr(r, payload)
   163  	if err != nil {
   164  		return nil, 0, 0, errors.New("read addr error")
   165  	}
   166  	uAddr := addr.UDPAddr()
   167  
   168  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   169  		return nil, 0, 0, errors.New("read length error")
   170  	}
   171  
   172  	total := int(binary.BigEndian.Uint16(payload[:2]))
   173  	if total > maxLength {
   174  		return nil, 0, 0, errors.New("packet invalid")
   175  	}
   176  
   177  	// read crlf
   178  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   179  		return nil, 0, 0, errors.New("read crlf error")
   180  	}
   181  
   182  	length := len(payload)
   183  	if total < length {
   184  		length = total
   185  	}
   186  
   187  	if _, err = io.ReadFull(r, payload[:length]); err != nil {
   188  		return nil, 0, 0, errors.New("read packet error")
   189  	}
   190  
   191  	return uAddr, length, total - length, nil
   192  }
   193  
   194  func New(option *Option) *Trojan {
   195  	return &Trojan{option, hexSha224([]byte(option.Password))}
   196  }
   197  
   198  type PacketConn struct {
   199  	net.Conn
   200  	remain int
   201  	rAddr  net.Addr
   202  	mux    sync.Mutex
   203  }
   204  
   205  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   206  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   207  }
   208  
   209  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   210  	pc.mux.Lock()
   211  	defer pc.mux.Unlock()
   212  	if pc.remain != 0 {
   213  		length := len(b)
   214  		if pc.remain < length {
   215  			length = pc.remain
   216  		}
   217  
   218  		n, err := pc.Conn.Read(b[:length])
   219  		if err != nil {
   220  			return 0, nil, err
   221  		}
   222  
   223  		pc.remain -= n
   224  		addr := pc.rAddr
   225  		if pc.remain == 0 {
   226  			pc.rAddr = nil
   227  		}
   228  
   229  		return n, addr, nil
   230  	}
   231  
   232  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   233  	if err != nil {
   234  		return 0, nil, err
   235  	}
   236  
   237  	if remain != 0 {
   238  		pc.remain = remain
   239  		pc.rAddr = addr
   240  	}
   241  
   242  	return n, addr, nil
   243  }
   244  
   245  func hexSha224(data []byte) []byte {
   246  	buf := make([]byte, 56)
   247  	hash := sha256.New224()
   248  	hash.Write(data)
   249  	hex.Encode(buf, hash.Sum(nil))
   250  	return buf
   251  }