github.com/metacubex/mihomo@v1.18.5/transport/trojan/trojan.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"crypto/sha256"
     6  	"crypto/tls"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"sync"
    14  
    15  	N "github.com/metacubex/mihomo/common/net"
    16  	"github.com/metacubex/mihomo/common/pool"
    17  	"github.com/metacubex/mihomo/component/ca"
    18  	tlsC "github.com/metacubex/mihomo/component/tls"
    19  	C "github.com/metacubex/mihomo/constant"
    20  	"github.com/metacubex/mihomo/transport/socks5"
    21  	"github.com/metacubex/mihomo/transport/vmess"
    22  )
    23  
    24  const (
    25  	// max packet length
    26  	maxLength = 8192
    27  )
    28  
    29  var (
    30  	defaultALPN          = []string{"h2", "http/1.1"}
    31  	defaultWebsocketALPN = []string{"http/1.1"}
    32  
    33  	crlf = []byte{'\r', '\n'}
    34  )
    35  
    36  type Command = byte
    37  
    38  const (
    39  	CommandTCP byte = 1
    40  	CommandUDP byte = 3
    41  
    42  	// deprecated XTLS commands, as souvenirs
    43  	commandXRD byte = 0xf0 // XTLS direct mode
    44  	commandXRO byte = 0xf1 // XTLS origin mode
    45  )
    46  
    47  type Option struct {
    48  	Password          string
    49  	ALPN              []string
    50  	ServerName        string
    51  	SkipCertVerify    bool
    52  	Fingerprint       string
    53  	ClientFingerprint string
    54  	Reality           *tlsC.RealityConfig
    55  }
    56  
    57  type WebsocketOption struct {
    58  	Host                     string
    59  	Port                     string
    60  	Path                     string
    61  	Headers                  http.Header
    62  	V2rayHttpUpgrade         bool
    63  	V2rayHttpUpgradeFastOpen bool
    64  }
    65  
// Trojan is a Trojan protocol client; construct it with New.
type Trojan struct {
	option      *Option
	hexPassword []byte // hex-encoded SHA-224 of option.Password, computed once by New
}
    70  
    71  func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
    72  	alpn := defaultALPN
    73  	if len(t.option.ALPN) != 0 {
    74  		alpn = t.option.ALPN
    75  	}
    76  	tlsConfig := &tls.Config{
    77  		NextProtos:         alpn,
    78  		MinVersion:         tls.VersionTLS12,
    79  		InsecureSkipVerify: t.option.SkipCertVerify,
    80  		ServerName:         t.option.ServerName,
    81  	}
    82  
    83  	var err error
    84  	tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	if len(t.option.ClientFingerprint) != 0 {
    90  		if t.option.Reality == nil {
    91  			utlsConn, valid := vmess.GetUTLSConn(conn, t.option.ClientFingerprint, tlsConfig)
    92  			if valid {
    93  				ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    94  				defer cancel()
    95  
    96  				err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx)
    97  				return utlsConn, err
    98  			}
    99  		} else {
   100  			ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   101  			defer cancel()
   102  			return tlsC.GetRealityConn(ctx, conn, t.option.ClientFingerprint, tlsConfig, t.option.Reality)
   103  		}
   104  	}
   105  	if t.option.Reality != nil {
   106  		return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint")
   107  	}
   108  
   109  	tlsConn := tls.Client(conn, tlsConfig)
   110  
   111  	// fix tls handshake not timeout
   112  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
   113  	defer cancel()
   114  
   115  	err = tlsConn.HandshakeContext(ctx)
   116  	return tlsConn, err
   117  }
   118  
   119  func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
   120  	alpn := defaultWebsocketALPN
   121  	if len(t.option.ALPN) != 0 {
   122  		alpn = t.option.ALPN
   123  	}
   124  
   125  	tlsConfig := &tls.Config{
   126  		NextProtos:         alpn,
   127  		MinVersion:         tls.VersionTLS12,
   128  		InsecureSkipVerify: t.option.SkipCertVerify,
   129  		ServerName:         t.option.ServerName,
   130  	}
   131  
   132  	var err error
   133  	tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{
   139  		Host:                     wsOptions.Host,
   140  		Port:                     wsOptions.Port,
   141  		Path:                     wsOptions.Path,
   142  		Headers:                  wsOptions.Headers,
   143  		V2rayHttpUpgrade:         wsOptions.V2rayHttpUpgrade,
   144  		V2rayHttpUpgradeFastOpen: wsOptions.V2rayHttpUpgradeFastOpen,
   145  		TLS:                      true,
   146  		TLSConfig:                tlsConfig,
   147  		ClientFingerprint:        t.option.ClientFingerprint,
   148  	})
   149  }
   150  
   151  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   152  	buf := pool.GetBuffer()
   153  	defer pool.PutBuffer(buf)
   154  
   155  	buf.Write(t.hexPassword)
   156  	buf.Write(crlf)
   157  
   158  	buf.WriteByte(command)
   159  	buf.Write(socks5Addr)
   160  	buf.Write(crlf)
   161  
   162  	_, err := w.Write(buf.Bytes())
   163  	return err
   164  }
   165  
   166  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   167  	return &PacketConn{
   168  		Conn: conn,
   169  	}
   170  }
   171  
   172  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   173  	buf := pool.GetBuffer()
   174  	defer pool.PutBuffer(buf)
   175  
   176  	buf.Write(socks5Addr)
   177  	binary.Write(buf, binary.BigEndian, uint16(len(payload)))
   178  	buf.Write(crlf)
   179  	buf.Write(payload)
   180  
   181  	return w.Write(buf.Bytes())
   182  }
   183  
   184  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   185  	if len(payload) <= maxLength {
   186  		return writePacket(w, socks5Addr, payload)
   187  	}
   188  
   189  	offset := 0
   190  	total := len(payload)
   191  	for {
   192  		cursor := offset + maxLength
   193  		if cursor > total {
   194  			cursor = total
   195  		}
   196  
   197  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   198  		if err != nil {
   199  			return offset + n, err
   200  		}
   201  
   202  		offset = cursor
   203  		if offset == total {
   204  			break
   205  		}
   206  	}
   207  
   208  	return total, nil
   209  }
   210  
   211  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
   212  	addr, err := socks5.ReadAddr(r, payload)
   213  	if err != nil {
   214  		return nil, 0, 0, errors.New("read addr error")
   215  	}
   216  	uAddr := addr.UDPAddr()
   217  	if uAddr == nil {
   218  		return nil, 0, 0, errors.New("parse addr error")
   219  	}
   220  
   221  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   222  		return nil, 0, 0, errors.New("read length error")
   223  	}
   224  
   225  	total := int(binary.BigEndian.Uint16(payload[:2]))
   226  	if total > maxLength {
   227  		return nil, 0, 0, errors.New("packet invalid")
   228  	}
   229  
   230  	// read crlf
   231  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   232  		return nil, 0, 0, errors.New("read crlf error")
   233  	}
   234  
   235  	length := len(payload)
   236  	if total < length {
   237  		length = total
   238  	}
   239  
   240  	if _, err = io.ReadFull(r, payload[:length]); err != nil {
   241  		return nil, 0, 0, errors.New("read packet error")
   242  	}
   243  
   244  	return uAddr, length, total - length, nil
   245  }
   246  
   247  func New(option *Option) *Trojan {
   248  	return &Trojan{option, hexSha224([]byte(option.Password))}
   249  }
   250  
// Compile-time check that *PacketConn implements N.EnhancePacketConn.
var _ N.EnhancePacketConn = (*PacketConn)(nil)

// PacketConn exchanges datagrams over a stream connection using Trojan UDP
// framing (see WriteTo, ReadFrom and WaitReadFrom).
type PacketConn struct {
	net.Conn
	// remain counts payload bytes of the current frame that ReadFrom has not
	// yet returned, because the caller's buffer was shorter than the frame.
	remain int
	// rAddr is the source address of that partially returned frame; ReadFrom
	// clears it once remain drops to zero.
	rAddr  net.Addr
	// mux serialises ReadFrom and WaitReadFrom (both lock it).
	mux    sync.Mutex
}
   259  
   260  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   261  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   262  }
   263  
   264  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   265  	pc.mux.Lock()
   266  	defer pc.mux.Unlock()
   267  	if pc.remain != 0 {
   268  		length := len(b)
   269  		if pc.remain < length {
   270  			length = pc.remain
   271  		}
   272  
   273  		n, err := pc.Conn.Read(b[:length])
   274  		if err != nil {
   275  			return 0, nil, err
   276  		}
   277  
   278  		pc.remain -= n
   279  		addr := pc.rAddr
   280  		if pc.remain == 0 {
   281  			pc.rAddr = nil
   282  		}
   283  
   284  		return n, addr, nil
   285  	}
   286  
   287  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   288  	if err != nil {
   289  		return 0, nil, err
   290  	}
   291  
   292  	if remain != 0 {
   293  		pc.remain = remain
   294  		pc.rAddr = addr
   295  	}
   296  
   297  	return n, addr, nil
   298  }
   299  
   300  func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
   301  	pc.mux.Lock()
   302  	defer pc.mux.Unlock()
   303  
   304  	destination, err := socks5.ReadAddr0(pc.Conn)
   305  	if err != nil {
   306  		return nil, nil, nil, err
   307  	}
   308  	addr = destination.UDPAddr()
   309  
   310  	data = pool.Get(pool.UDPBufferSize)
   311  	put = func() {
   312  		_ = pool.Put(data)
   313  	}
   314  
   315  	_, err = io.ReadFull(pc.Conn, data[:2+2]) // u16be length + CR LF
   316  	if err != nil {
   317  		if put != nil {
   318  			put()
   319  		}
   320  		return nil, nil, nil, err
   321  	}
   322  	length := binary.BigEndian.Uint16(data)
   323  
   324  	if length > 0 {
   325  		data = data[:length]
   326  		_, err = io.ReadFull(pc.Conn, data)
   327  		if err != nil {
   328  			if put != nil {
   329  				put()
   330  			}
   331  			return nil, nil, nil, err
   332  		}
   333  	} else {
   334  		if put != nil {
   335  			put()
   336  		}
   337  		return nil, nil, addr, nil
   338  	}
   339  
   340  	return
   341  }
   342  
   343  func hexSha224(data []byte) []byte {
   344  	buf := make([]byte, 56)
   345  	hash := sha256.Sum224(data)
   346  	hex.Encode(buf, hash[:])
   347  	return buf
   348  }