github.com/kelleygo/clashcore@v1.0.2/transport/trojan/trojan.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"sync"
    14  
    15  	N "github.com/kelleygo/clashcore/common/net"
    16  	"github.com/kelleygo/clashcore/common/pool"
    17  	"github.com/kelleygo/clashcore/component/ca"
    18  	tlsC "github.com/kelleygo/clashcore/component/tls"
    19  	C "github.com/kelleygo/clashcore/constant"
    20  	"github.com/kelleygo/clashcore/transport/socks5"
    21  	"github.com/kelleygo/clashcore/transport/vmess"
    22  )
    23  
    24  const (
    25  	// max packet length
    26  	maxLength = 8192
    27  )
    28  
    29  var (
    30  	defaultALPN          = []string{"h2", "http/1.1"}
    31  	defaultWebsocketALPN = []string{"http/1.1"}
    32  
    33  	crlf = []byte{'\r', '\n'}
    34  )
    35  
    36  type Command = byte
    37  
    38  const (
    39  	CommandTCP byte = 1
    40  	CommandUDP byte = 3
    41  
    42  	// deprecated XTLS commands, as souvenirs
    43  	commandXRD byte = 0xf0 // XTLS direct mode
    44  	commandXRO byte = 0xf1 // XTLS origin mode
    45  )
    46  
// Option carries the client-side settings used to build a Trojan instance.
type Option struct {
	// Password is hashed by New (hex-encoded SHA-224) and sent by WriteHeader.
	Password string
	// ALPN, when non-empty, replaces the package's default ALPN list.
	ALPN []string
	// ServerName is copied into tls.Config.ServerName.
	ServerName string
	// SkipCertVerify is copied into tls.Config.InsecureSkipVerify.
	SkipCertVerify bool
	// Fingerprint is handed to ca.GetSpecifiedFingerprintTLSConfig by StreamConn.
	Fingerprint string
	// ClientFingerprint, when non-empty, makes StreamConn try a uTLS (or
	// REALITY) handshake instead of crypto/tls.
	ClientFingerprint string
	// Reality, when non-nil, selects the REALITY handshake; StreamConn requires
	// ClientFingerprint to be set in that case.
	Reality *tlsC.RealityConfig
}
    56  
// WebsocketOption holds the websocket settings that StreamWebsocketConn copies
// field-for-field into a vmess.WebsocketConfig.
type WebsocketOption struct {
	Host                     string
	Port                     string
	Path                     string
	Headers                  http.Header
	V2rayHttpUpgrade         bool
	V2rayHttpUpgradeFastOpen bool
}
    65  
// Trojan is a client for the Trojan protocol. Create it with New.
type Trojan struct {
	// option is the configuration passed to New.
	option *Option
	// hexPassword is the hex-encoded SHA-224 of option.Password, as written
	// at the start of every request header.
	hexPassword []byte
}
    70  
    71  func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
    72  	alpn := defaultALPN
    73  	if len(t.option.ALPN) != 0 {
    74  		alpn = t.option.ALPN
    75  	}
    76  	tlsConfig := &tls.Config{
    77  		NextProtos:         alpn,
    78  		MinVersion:         tls.VersionTLS12,
    79  		InsecureSkipVerify: t.option.SkipCertVerify,
    80  		ServerName:         t.option.ServerName,
    81  	}
    82  
    83  	var err error
    84  	tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	if len(t.option.ClientFingerprint) != 0 {
    90  		if t.option.Reality == nil {
    91  			utlsConn, valid := vmess.GetUTLSConn(conn, t.option.ClientFingerprint, tlsConfig)
    92  			if valid {
    93  				ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    94  				defer cancel()
    95  
    96  				err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx)
    97  				return utlsConn, err
    98  			}
    99  		} else {
   100  			ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   101  			defer cancel()
   102  			return tlsC.GetRealityConn(ctx, conn, t.option.ClientFingerprint, tlsConfig, t.option.Reality)
   103  		}
   104  	}
   105  	if t.option.Reality != nil {
   106  		return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint")
   107  	}
   108  
   109  	tlsConn := tls.Client(conn, tlsConfig)
   110  
   111  	// fix tls handshake not timeout
   112  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   113  	defer cancel()
   114  
   115  	err = tlsConn.HandshakeContext(ctx)
   116  	return tlsConn, err
   117  }
   118  
   119  func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
   120  	alpn := defaultWebsocketALPN
   121  	if len(t.option.ALPN) != 0 {
   122  		alpn = t.option.ALPN
   123  	}
   124  
   125  	tlsConfig := &tls.Config{
   126  		NextProtos:         alpn,
   127  		MinVersion:         tls.VersionTLS12,
   128  		InsecureSkipVerify: t.option.SkipCertVerify,
   129  		ServerName:         t.option.ServerName,
   130  	}
   131  
   132  	return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{
   133  		Host:                     wsOptions.Host,
   134  		Port:                     wsOptions.Port,
   135  		Path:                     wsOptions.Path,
   136  		Headers:                  wsOptions.Headers,
   137  		V2rayHttpUpgrade:         wsOptions.V2rayHttpUpgrade,
   138  		V2rayHttpUpgradeFastOpen: wsOptions.V2rayHttpUpgradeFastOpen,
   139  		TLS:                      true,
   140  		TLSConfig:                tlsConfig,
   141  		ClientFingerprint:        t.option.ClientFingerprint,
   142  	})
   143  }
   144  
   145  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   146  	buf := pool.GetBuffer()
   147  	defer pool.PutBuffer(buf)
   148  
   149  	buf.Write(t.hexPassword)
   150  	buf.Write(crlf)
   151  
   152  	buf.WriteByte(command)
   153  	buf.Write(socks5Addr)
   154  	buf.Write(crlf)
   155  
   156  	_, err := w.Write(buf.Bytes())
   157  	return err
   158  }
   159  
   160  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   161  	return &PacketConn{
   162  		Conn: conn,
   163  	}
   164  }
   165  
   166  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   167  	buf := pool.GetBuffer()
   168  	defer pool.PutBuffer(buf)
   169  
   170  	buf.Write(socks5Addr)
   171  	binary.Write(buf, binary.BigEndian, uint16(len(payload)))
   172  	buf.Write(crlf)
   173  	buf.Write(payload)
   174  
   175  	return w.Write(buf.Bytes())
   176  }
   177  
   178  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   179  	if len(payload) <= maxLength {
   180  		return writePacket(w, socks5Addr, payload)
   181  	}
   182  
   183  	offset := 0
   184  	total := len(payload)
   185  	for {
   186  		cursor := offset + maxLength
   187  		if cursor > total {
   188  			cursor = total
   189  		}
   190  
   191  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   192  		if err != nil {
   193  			return offset + n, err
   194  		}
   195  
   196  		offset = cursor
   197  		if offset == total {
   198  			break
   199  		}
   200  	}
   201  
   202  	return total, nil
   203  }
   204  
   205  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
   206  	addr, err := socks5.ReadAddr(r, payload)
   207  	if err != nil {
   208  		return nil, 0, 0, errors.New("read addr error")
   209  	}
   210  	uAddr := addr.UDPAddr()
   211  	if uAddr == nil {
   212  		return nil, 0, 0, errors.New("parse addr error")
   213  	}
   214  
   215  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   216  		return nil, 0, 0, errors.New("read length error")
   217  	}
   218  
   219  	total := int(binary.BigEndian.Uint16(payload[:2]))
   220  	if total > maxLength {
   221  		return nil, 0, 0, errors.New("packet invalid")
   222  	}
   223  
   224  	// read crlf
   225  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   226  		return nil, 0, 0, errors.New("read crlf error")
   227  	}
   228  
   229  	length := len(payload)
   230  	if total < length {
   231  		length = total
   232  	}
   233  
   234  	if _, err = io.ReadFull(r, payload[:length]); err != nil {
   235  		return nil, 0, 0, errors.New("read packet error")
   236  	}
   237  
   238  	return uAddr, length, total - length, nil
   239  }
   240  
   241  func New(option *Option) *Trojan {
   242  	return &Trojan{option, hexSha224([]byte(option.Password))}
   243  }
   244  
// Compile-time check that *PacketConn implements N.EnhancePacketConn.
var _ N.EnhancePacketConn = (*PacketConn)(nil)

// PacketConn exchanges UDP packets over an underlying stream connection using
// this package's packet framing.
type PacketConn struct {
	net.Conn
	// remain is the number of payload bytes of the current frame that
	// ReadFrom has not yet delivered because the caller's buffer was too small.
	remain int
	// rAddr is the source address of that partially delivered frame; it is
	// cleared once remain drops to zero.
	rAddr net.Addr
	// mux serializes ReadFrom and WaitReadFrom.
	mux sync.Mutex
}
   253  
   254  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   255  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   256  }
   257  
   258  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   259  	pc.mux.Lock()
   260  	defer pc.mux.Unlock()
   261  	if pc.remain != 0 {
   262  		length := len(b)
   263  		if pc.remain < length {
   264  			length = pc.remain
   265  		}
   266  
   267  		n, err := pc.Conn.Read(b[:length])
   268  		if err != nil {
   269  			return 0, nil, err
   270  		}
   271  
   272  		pc.remain -= n
   273  		addr := pc.rAddr
   274  		if pc.remain == 0 {
   275  			pc.rAddr = nil
   276  		}
   277  
   278  		return n, addr, nil
   279  	}
   280  
   281  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   282  	if err != nil {
   283  		return 0, nil, err
   284  	}
   285  
   286  	if remain != 0 {
   287  		pc.remain = remain
   288  		pc.rAddr = addr
   289  	}
   290  
   291  	return n, addr, nil
   292  }
   293  
   294  func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
   295  	pc.mux.Lock()
   296  	defer pc.mux.Unlock()
   297  
   298  	destination, err := socks5.ReadAddr0(pc.Conn)
   299  	if err != nil {
   300  		return nil, nil, nil, err
   301  	}
   302  	addr = destination.UDPAddr()
   303  
   304  	data = pool.Get(pool.UDPBufferSize)
   305  	put = func() {
   306  		_ = pool.Put(data)
   307  	}
   308  
   309  	_, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF
   310  	if err != nil {
   311  		if put != nil {
   312  			put()
   313  		}
   314  		return nil, nil, nil, err
   315  	}
   316  	length := binary.BigEndian.Uint16(data)
   317  
   318  	if length > 0 {
   319  		data = data[:length]
   320  		_, err = io.ReadFull(pc.Conn, data)
   321  		if err != nil {
   322  			if put != nil {
   323  				put()
   324  			}
   325  			return nil, nil, nil, err
   326  		}
   327  	} else {
   328  		if put != nil {
   329  			put()
   330  		}
   331  		return nil, nil, addr, nil
   332  	}
   333  
   334  	return
   335  }
   336  
   337  func hexSha224(data []byte) []byte {
   338  	buf := make([]byte, 56)
   339  	hash := sha256.Sum224(data)
   340  	hex.Encode(buf, hash[:])
   341  	return buf
   342  }