github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/trojan/client.go (about)

     1  package trojan
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/sha256"
     7  	"encoding/binary"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/socks5/tools"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    17  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    18  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    19  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    20  )
    21  
    22  const (
    23  	MaxPacketSize = 1024 * 8
    24  )
    25  
    26  type Command byte
    27  
    28  const (
    29  	Connect   Command = 1 // TCP
    30  	Associate Command = 3 // UDP
    31  	Mux       Command = 0x7f
    32  )
    33  
    34  var crlf = []byte{'\r', '\n'}
    35  
    36  func (c *Client) WriteHeader(conn net.Conn, cmd Command, addr netapi.Address) (err error) {
    37  	buf := pool.GetBytesWriter(pool.DefaultSize)
    38  	defer buf.Free()
    39  
    40  	_, _ = buf.Write(c.password)
    41  	_, _ = buf.Write(crlf)
    42  	buf.WriteByte(byte(cmd))
    43  	tools.EncodeAddr(addr, buf)
    44  	_, _ = buf.Write(crlf)
    45  
    46  	_, err = conn.Write(buf.Bytes())
    47  	return
    48  }
    49  
// Client speaks the trojan protocol over connections obtained from an
// underlying netapi.Proxy. Conn sends a Connect header and PacketConn sends
// an Associate header before handing the connection back.
//
// modified from https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/client.go
type Client struct {
	// proxy supplies the stream connection; Conn and PacketConn call
	// proxy.Conn with the final target address.
	proxy netapi.Proxy
	// NOTE(review): presumably provides a default (no-op) Dispatch so Client
	// satisfies netapi.Proxy — confirm in the netapi package.
	netapi.EmptyDispatch
	// password is the hex SHA-224 of the configured password (set by
	// NewClient via hexSha224); WriteHeader sends it verbatim.
	password []byte
}
    56  
// init registers NewClient with the point package.
// NOTE(review): presumably this lets point build a trojan proxy from a
// *protocol.Protocol_Trojan config — confirm in point.RegisterProtocol.
func init() {
	point.RegisterProtocol(NewClient)
}
    60  
    61  func NewClient(config *protocol.Protocol_Trojan) point.WrapProxy {
    62  	return func(dialer netapi.Proxy) (netapi.Proxy, error) {
    63  		return &Client{
    64  			password: hexSha224([]byte(config.Trojan.Password)),
    65  			proxy:    dialer,
    66  		}, nil
    67  	}
    68  }
    69  
    70  func (c *Client) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
    71  	conn, err := c.proxy.Conn(ctx, addr)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	if err = c.WriteHeader(conn, Connect, addr); err != nil {
    77  		conn.Close()
    78  		return nil, fmt.Errorf("write header failed: %w", err)
    79  	}
    80  	return conn, nil
    81  }
    82  
    83  func (c *Client) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
    84  	conn, err := c.proxy.Conn(ctx, addr)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	if err = c.WriteHeader(conn, Associate, addr); err != nil {
    89  		conn.Close()
    90  		return nil, fmt.Errorf("write header failed: %w", err)
    91  	}
    92  	return &PacketConn{Conn: conn}, nil
    93  }
    94  
// PacketConn is the net.PacketConn returned by Client.PacketConn. It carries
// UDP datagrams as trojan frames (address, 2-byte length, CRLF, payload) over
// the embedded stream connection.
type PacketConn struct {
	net.Conn

	// remain is the number of payload bytes of the frame currently being
	// read that ReadFrom has not yet returned.
	remain int
	// addr is the source address of the frame currently being read.
	addr   netapi.Address
	// mux serializes ReadFrom; WriteTo does not take it.
	mux    sync.Mutex
}
   102  
   103  func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) {
   104  	taddr, err := netapi.ParseSysAddr(addr)
   105  	if err != nil {
   106  		return 0, fmt.Errorf("failed to parse addr: %w", err)
   107  	}
   108  
   109  	w := pool.GetBuffer()
   110  	defer pool.PutBuffer(w)
   111  
   112  	tools.EncodeAddr(taddr, w)
   113  	addrSize := w.Len()
   114  
   115  	b := bytes.NewBuffer(payload)
   116  
   117  	for b.Len() > 0 {
   118  		data := b.Next(MaxPacketSize)
   119  
   120  		w.Truncate(addrSize)
   121  
   122  		binary.Write(w, binary.BigEndian, uint16(len(data)))
   123  
   124  		w.Write(crlf) // crlf
   125  
   126  		w.Write(data)
   127  
   128  		_, err = c.Conn.Write(w.Bytes())
   129  		if err != nil {
   130  			return len(payload) - b.Len() + len(data), fmt.Errorf("write to %v failed: %w", addr, err)
   131  		}
   132  	}
   133  
   134  	return len(payload), nil
   135  }
   136  
   137  func (c *PacketConn) ReadFrom(payload []byte) (n int, _ net.Addr, err error) {
   138  	c.mux.Lock()
   139  	defer c.mux.Unlock()
   140  
   141  	if c.remain > 0 {
   142  		z := min(len(payload), c.remain)
   143  
   144  		n, err := c.Conn.Read(payload[:z])
   145  		if err != nil {
   146  			return 0, c.addr, err
   147  		}
   148  
   149  		c.remain -= n
   150  		return n, c.addr, err
   151  	}
   152  
   153  	addr, err := tools.ResolveAddr(c.Conn)
   154  	if err != nil {
   155  		return 0, nil, fmt.Errorf("failed to resolve udp packet addr: %w", err)
   156  	}
   157  
   158  	c.addr = addr.Address(statistic.Type_udp)
   159  
   160  	var length uint16
   161  	if err = binary.Read(c.Conn, binary.BigEndian, &length); err != nil {
   162  		return 0, nil, fmt.Errorf("read length failed: %w", err)
   163  	}
   164  	if length > MaxPacketSize {
   165  		return 0, nil, fmt.Errorf("invalid packet size")
   166  	}
   167  
   168  	crlf := [2]byte{}
   169  	if _, err := io.ReadFull(c.Conn, crlf[:]); err != nil {
   170  		return 0, nil, fmt.Errorf("read crlf failed: %w", err)
   171  	}
   172  
   173  	plen := min(int(length), len(payload))
   174  	c.remain = int(length) - plen
   175  
   176  	n, err = io.ReadFull(c.Conn, payload[:plen])
   177  	return n, c.addr, err
   178  }
   179  
   180  func hexSha224(data []byte) []byte {
   181  	buf := make([]byte, 56)
   182  	hash := sha256.New224()
   183  	hash.Write(data)
   184  	hex.Encode(buf, hash.Sum(nil))
   185  	return buf
   186  }