github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/yuubinsya/client.go (about)

     1  package yuubinsya
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"sync"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/socks5/tools"
    16  	websocket "github.com/Asutorufa/yuhaiin/pkg/net/proxy/websocket/x"
    17  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/yuubinsya/types"
    18  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    19  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    20  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    21  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    22  )
    23  
    24  type client struct {
    25  	netapi.Proxy
    26  
    27  	overTCP bool
    28  
    29  	handshaker types.Handshaker
    30  	packetAuth types.Auth
    31  }
    32  
    33  func init() {
    34  	point.RegisterProtocol(NewClient)
    35  }
    36  
    37  func NewClient(config *protocol.Protocol_Yuubinsya) point.WrapProxy {
    38  	return func(dialer netapi.Proxy) (netapi.Proxy, error) {
    39  		auth, err := NewAuth(config.Yuubinsya.GetUdpEncrypt(), []byte(config.Yuubinsya.Password))
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  
    44  		c := &client{
    45  			dialer,
    46  			config.Yuubinsya.UdpOverStream,
    47  			NewHandshaker(
    48  				false,
    49  				config.Yuubinsya.GetTcpEncrypt(),
    50  				[]byte(config.Yuubinsya.Password),
    51  			),
    52  			auth,
    53  		}
    54  
    55  		return c, nil
    56  	}
    57  }
    58  
    59  func (c *client) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
    60  	conn, err := c.Proxy.Conn(ctx, addr)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	hconn, err := c.handshaker.Handshake(conn)
    66  	if err != nil {
    67  		conn.Close()
    68  		return nil, err
    69  	}
    70  
    71  	return newConn(hconn, addr, c.handshaker), nil
    72  }
    73  
    74  func (c *client) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
    75  	if !c.overTCP {
    76  		packet, err := c.Proxy.PacketConn(ctx, addr)
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  
    81  		return NewAuthPacketConn(packet).WithTarget(addr).WithAuth(c.packetAuth).WithPrefix(true), nil
    82  	}
    83  
    84  	conn, err := c.Proxy.Conn(ctx, addr)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	hconn, err := c.handshaker.Handshake(conn)
    90  	if err != nil {
    91  		conn.Close()
    92  		return nil, err
    93  	}
    94  	return newPacketConn(hconn, c.handshaker, false), nil
    95  }
    96  
    97  type PacketConn struct {
    98  	headerWrote bool
    99  	remain      int
   100  
   101  	net.Conn
   102  
   103  	handshaker types.Handshaker
   104  	addr       netapi.Address
   105  
   106  	hmux sync.Mutex
   107  	rmux sync.Mutex
   108  
   109  	r *bufio.Reader
   110  }
   111  
   112  func newPacketConn(conn net.Conn, handshaker types.Handshaker, server bool) *PacketConn {
   113  	return &PacketConn{
   114  		Conn:        conn,
   115  		handshaker:  handshaker,
   116  		headerWrote: server,
   117  		r:           websocket.NewBufioReader(conn),
   118  	}
   119  }
   120  
   121  func (c *PacketConn) WriteTo(payload []byte, addr net.Addr) (int, error) {
   122  	taddr, err := netapi.ParseSysAddr(addr)
   123  	if err != nil {
   124  		return 0, fmt.Errorf("failed to parse addr: %w", err)
   125  	}
   126  
   127  	s5Addr := tools.ParseAddr(taddr)
   128  	defer s5Addr.Free()
   129  
   130  	w := pool.GetBuffer()
   131  	defer pool.PutBuffer(w)
   132  
   133  	if !c.headerWrote {
   134  		c.hmux.Lock()
   135  		if !c.headerWrote {
   136  			c.handshaker.EncodeHeader(types.UDP, w, netapi.EmptyAddr)
   137  			defer func() {
   138  				c.headerWrote = true
   139  				c.hmux.Unlock()
   140  			}()
   141  		} else {
   142  			c.hmux.Unlock()
   143  		}
   144  	}
   145  
   146  	b := bytes.NewBuffer(payload)
   147  
   148  	for b.Len() > 0 {
   149  		data := b.Next(nat.MaxSegmentSize)
   150  		w.Write(s5Addr.Bytes.Bytes())
   151  		_ = binary.Write(w, binary.BigEndian, uint16(len(data)))
   152  		w.Write(data)
   153  
   154  		n, err := c.Conn.Write(w.Bytes())
   155  
   156  		w.Reset()
   157  
   158  		if err != nil {
   159  			return len(payload) - b.Len() + len(data) - n, fmt.Errorf("write to %v failed: %w", addr, err)
   160  		}
   161  	}
   162  
   163  	return len(payload), nil
   164  }
   165  
   166  func (c *PacketConn) ReadFrom(payload []byte) (n int, _ net.Addr, err error) {
   167  	c.rmux.Lock()
   168  	defer c.rmux.Unlock()
   169  
   170  	if c.remain > 0 {
   171  		n, err := c.r.Read(payload[:min(len(payload), c.remain)])
   172  		c.remain -= n
   173  		return n, c.addr, err
   174  	}
   175  
   176  	addr, err := tools.ResolveAddr(c.r)
   177  	if err != nil {
   178  		return 0, nil, fmt.Errorf("failed to resolve udp packet addr: %w", err)
   179  	}
   180  
   181  	c.addr = addr.Address(statistic.Type_udp)
   182  
   183  	lengthBytes, err := c.r.Peek(2)
   184  	if err != nil {
   185  		return 0, nil, fmt.Errorf("read length failed: %w", err)
   186  	}
   187  
   188  	_, _ = c.r.Discard(2)
   189  
   190  	length := binary.BigEndian.Uint16(lengthBytes)
   191  
   192  	readlen := min(len(payload), int(length))
   193  	c.remain = int(length) - readlen
   194  
   195  	n, err = io.ReadFull(c.r, payload[:readlen])
   196  	return n, c.addr, err
   197  }
   198  
   199  type Conn struct {
   200  	headerWrote bool
   201  
   202  	net.Conn
   203  
   204  	addr       netapi.Address
   205  	handshaker types.Handshaker
   206  }
   207  
   208  func newConn(con net.Conn, addr netapi.Address, handshaker types.Handshaker) net.Conn {
   209  	return &Conn{
   210  		Conn:       con,
   211  		addr:       addr,
   212  		handshaker: handshaker,
   213  	}
   214  }
   215  
   216  func (c *Conn) Write(b []byte) (int, error) {
   217  	if c.headerWrote {
   218  		return c.Conn.Write(b)
   219  	}
   220  
   221  	c.headerWrote = true
   222  
   223  	buf := pool.GetBytesWriter(pool.DefaultSize + len(b))
   224  	defer buf.Free()
   225  
   226  	c.handshaker.EncodeHeader(types.TCP, buf, c.addr)
   227  	_, _ = buf.Write(b)
   228  
   229  	if n, err := c.Conn.Write(buf.Bytes()); err != nil {
   230  		return n, err
   231  	}
   232  
   233  	return len(b), nil
   234  }