github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/transport/trojan/trojan.go (about)

     1  package trojan
     2  
import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sync"

	C "github.com/laof/lite-speed-test/constant"
	"github.com/laof/lite-speed-test/transport/socks5"
	"github.com/laof/lite-speed-test/transport/vmess"
)
    20  
    21  const (
    22  	// max packet length
    23  	maxLength = 8192
    24  )
    25  
    26  var (
    27  	defaultALPN          = []string{"h2", "http/1.1"}
    28  	defaultWebsocketALPN = []string{"http/1.1"}
    29  	crlf                 = []byte{'\r', '\n'}
    30  
    31  	bufPool = sync.Pool{New: func() interface{} { return &bytes.Buffer{} }}
    32  )
    33  
    34  type Command = byte
    35  
    36  var (
    37  	CommandTCP byte = 1
    38  	CommandUDP byte = 3
    39  )
    40  
    41  type Option struct {
    42  	Password           string
    43  	ALPN               []string
    44  	ServerName         string
    45  	SkipCertVerify     bool
    46  	ClientSessionCache tls.ClientSessionCache
    47  }
    48  
    49  type WebsocketOption struct {
    50  	Host    string
    51  	Port    string
    52  	Path    string
    53  	Headers http.Header
    54  }
    55  
    56  type Trojan struct {
    57  	option      *Option
    58  	hexPassword []byte
    59  }
    60  
    61  func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
    62  	alpn := defaultALPN
    63  	if len(t.option.ALPN) != 0 {
    64  		alpn = t.option.ALPN
    65  	}
    66  
    67  	tlsConfig := &tls.Config{
    68  		NextProtos:         alpn,
    69  		MinVersion:         tls.VersionTLS12,
    70  		InsecureSkipVerify: t.option.SkipCertVerify,
    71  		ServerName:         t.option.ServerName,
    72  		ClientSessionCache: t.option.ClientSessionCache,
    73  	}
    74  
    75  	tlsConn := tls.Client(conn, tlsConfig)
    76  	ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
    77  	defer cancel()
    78  	if err := tlsConn.HandshakeContext(ctx); err != nil {
    79  		return nil, err
    80  	}
    81  
    82  	return tlsConn, nil
    83  }
    84  
    85  func (t *Trojan) StreamWebsocketConn(conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
    86  	alpn := defaultWebsocketALPN
    87  	if len(t.option.ALPN) != 0 {
    88  		alpn = t.option.ALPN
    89  	}
    90  
    91  	tlsConfig := &tls.Config{
    92  		NextProtos:         alpn,
    93  		MinVersion:         tls.VersionTLS12,
    94  		InsecureSkipVerify: t.option.SkipCertVerify,
    95  		ServerName:         t.option.ServerName,
    96  	}
    97  
    98  	return vmess.StreamWebsocketConn(conn, &vmess.WebsocketConfig{
    99  		Host:      wsOptions.Host,
   100  		Port:      wsOptions.Port,
   101  		Path:      wsOptions.Path,
   102  		Headers:   wsOptions.Headers,
   103  		TLS:       true,
   104  		TLSConfig: tlsConfig,
   105  	})
   106  }
   107  
   108  func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
   109  	buf := bufPool.Get().(*bytes.Buffer)
   110  	defer bufPool.Put(buf)
   111  	defer buf.Reset()
   112  
   113  	buf.Write(t.hexPassword)
   114  	buf.Write(crlf)
   115  
   116  	buf.WriteByte(command)
   117  	buf.Write(socks5Addr)
   118  	buf.Write(crlf)
   119  
   120  	_, err := w.Write(buf.Bytes())
   121  	return err
   122  }
   123  
   124  func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
   125  	return &PacketConn{
   126  		Conn: conn,
   127  	}
   128  }
   129  
   130  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   131  	buf := bufPool.Get().(*bytes.Buffer)
   132  	defer bufPool.Put(buf)
   133  	defer buf.Reset()
   134  
   135  	buf.Write(socks5Addr)
   136  	binary.Write(buf, binary.BigEndian, uint16(len(payload)))
   137  	buf.Write(crlf)
   138  	buf.Write(payload)
   139  
   140  	return w.Write(buf.Bytes())
   141  }
   142  
   143  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   144  	if len(payload) <= maxLength {
   145  		return writePacket(w, socks5Addr, payload)
   146  	}
   147  
   148  	offset := 0
   149  	total := len(payload)
   150  	for {
   151  		cursor := offset + maxLength
   152  		if cursor > total {
   153  			cursor = total
   154  		}
   155  
   156  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   157  		if err != nil {
   158  			return offset + n, err
   159  		}
   160  
   161  		offset = cursor
   162  		if offset == total {
   163  			break
   164  		}
   165  	}
   166  
   167  	return total, nil
   168  }
   169  
   170  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
   171  	addr, err := socks5.ReadAddr(r, payload)
   172  	if err != nil {
   173  		return nil, 0, 0, errors.New("read addr error")
   174  	}
   175  	uAddr := addr.UDPAddr()
   176  
   177  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   178  		return nil, 0, 0, errors.New("read length error")
   179  	}
   180  
   181  	total := int(binary.BigEndian.Uint16(payload[:2]))
   182  	if total > maxLength {
   183  		return nil, 0, 0, errors.New("packet invalid")
   184  	}
   185  
   186  	// read crlf
   187  	if _, err = io.ReadFull(r, payload[:2]); err != nil {
   188  		return nil, 0, 0, errors.New("read crlf error")
   189  	}
   190  
   191  	length := len(payload)
   192  	if total < length {
   193  		length = total
   194  	}
   195  
   196  	if _, err = io.ReadFull(r, payload[:length]); err != nil {
   197  		return nil, 0, 0, errors.New("read packet error")
   198  	}
   199  
   200  	return uAddr, length, total - length, nil
   201  }
   202  
   203  func New(option *Option) *Trojan {
   204  	return &Trojan{option, hexSha224([]byte(option.Password))}
   205  }
   206  
   207  type PacketConn struct {
   208  	net.Conn
   209  	remain int
   210  	rAddr  net.Addr
   211  	mux    sync.Mutex
   212  }
   213  
   214  func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   215  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   216  }
   217  
   218  func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   219  	pc.mux.Lock()
   220  	defer pc.mux.Unlock()
   221  	if pc.remain != 0 {
   222  		length := len(b)
   223  		if pc.remain < length {
   224  			length = pc.remain
   225  		}
   226  
   227  		n, err := pc.Conn.Read(b[:length])
   228  		if err != nil {
   229  			return 0, nil, err
   230  		}
   231  
   232  		pc.remain -= n
   233  		addr := pc.rAddr
   234  		if pc.remain == 0 {
   235  			pc.rAddr = nil
   236  		}
   237  
   238  		return n, addr, nil
   239  	}
   240  
   241  	addr, n, remain, err := ReadPacket(pc.Conn, b)
   242  	if err != nil {
   243  		return 0, nil, err
   244  	}
   245  
   246  	if remain != 0 {
   247  		pc.remain = remain
   248  		pc.rAddr = addr
   249  	}
   250  
   251  	return n, addr, nil
   252  }
   253  
   254  func hexSha224(data []byte) []byte {
   255  	buf := make([]byte, 56)
   256  	hash := sha256.New224()
   257  	hash.Write(data)
   258  	hex.Encode(buf, hash.Sum(nil))
   259  	return buf
   260  }