github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/wireguard/bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"net/netip"
     8  	"strconv"
     9  	"sync"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    13  	"github.com/Asutorufa/yuhaiin/pkg/utils/yerror"
    14  	"github.com/tailscale/wireguard-go/conn"
    15  )
    16  
    17  var _ conn.Endpoint = (*Endpoint)(nil)
    18  
    19  type Endpoint netip.AddrPort
    20  
    21  func (e Endpoint) ClearSrc()           {}
    22  func (e Endpoint) SrcToString() string { return "" }
    23  func (e Endpoint) DstToString() string { return (netip.AddrPort)(e).String() }
    24  func (e Endpoint) DstToBytes() []byte  { return yerror.Ignore((netip.AddrPort)(e).MarshalBinary()) }
    25  func (e Endpoint) DstIP() netip.Addr   { return (netip.AddrPort)(e).Addr() }
    26  func (e Endpoint) SrcIP() netip.Addr   { return netip.Addr{} }
    27  
    28  type netBindClient struct {
    29  	mu       sync.Mutex
    30  	conn     net.PacketConn
    31  	reserved []byte
    32  }
    33  
    34  func newNetBindClient(reserved []byte) *netBindClient { return &netBindClient{reserved: reserved} }
    35  
    36  func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
    37  	addrPort, err := netip.ParseAddrPort(s)
    38  	if err == nil {
    39  		return Endpoint(addrPort), nil
    40  	}
    41  
    42  	ipStr, port, err := net.SplitHostPort(s)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	portNum, err := strconv.ParseUint(port, 10, 16)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	ips, err := netapi.Bootstrap.LookupIP(context.TODO(), ipStr)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	ip, ok := netip.AddrFromSlice(ips[0])
    58  	if !ok {
    59  		return nil, errors.New("failed to parse ip: " + ipStr)
    60  	}
    61  
    62  	return Endpoint(netip.AddrPortFrom(ip.Unmap(), uint16(portNum))), nil
    63  }
    64  
    65  func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
    66  	return []conn.ReceiveFunc{bind.receive}, uport, nil
    67  }
    68  
    69  func (bind *netBindClient) Close() error {
    70  	if bind.conn != nil {
    71  		return bind.conn.Close()
    72  	}
    73  	return nil
    74  }
    75  
    76  func (bind *netBindClient) connect() (net.PacketConn, error) {
    77  	conn := bind.conn
    78  	if conn != nil {
    79  		return conn, nil
    80  	}
    81  
    82  	bind.mu.Lock()
    83  	defer bind.mu.Unlock()
    84  
    85  	if bind.conn != nil {
    86  		return bind.conn, nil
    87  	}
    88  
    89  	conn, err := dialer.ListenPacket("udp", "")
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	bind.conn = conn
    95  
    96  	return conn, nil
    97  }
    98  
    99  func (bind *netBindClient) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
   100  	conn, err := bind.connect()
   101  	if err != nil {
   102  		return 0, err
   103  	}
   104  
   105  	n, addr, err := conn.ReadFrom(packets[0])
   106  	if err != nil {
   107  		return 0, err
   108  	}
   109  
   110  	var addrPort netip.AddrPort
   111  	uaddr, ok := addr.(*net.UDPAddr)
   112  	if ok {
   113  		addrPort = uaddr.AddrPort()
   114  	} else {
   115  		naddr, err := netapi.ParseSysAddr(addr)
   116  		if err != nil {
   117  			return 0, err
   118  		}
   119  
   120  		ar := naddr.AddrPort(context.Background())
   121  		if ar.Err != nil {
   122  			return 0, ar.Err
   123  		}
   124  
   125  		addrPort = ar.V
   126  	}
   127  
   128  	eps[0] = Endpoint(addrPort)
   129  	if n > 3 {
   130  		copy(packets[0][1:4], []byte{0, 0, 0})
   131  	}
   132  	sizes[0] = n
   133  
   134  	return 1, nil
   135  }
   136  
   137  func (bind *netBindClient) Send(buffs [][]byte, endpoint conn.Endpoint) error {
   138  	ep, ok := endpoint.(Endpoint)
   139  	if !ok {
   140  		return conn.ErrWrongEndpointType
   141  	}
   142  
   143  	addr := netip.AddrPort(ep)
   144  
   145  	conn, err := bind.connect()
   146  	if err != nil {
   147  		return err
   148  	}
   149  
   150  	for _, buff := range buffs {
   151  		if len(buff) > 3 && len(bind.reserved) == 3 {
   152  			copy(buff[1:], bind.reserved)
   153  		}
   154  
   155  		_, err = conn.WriteTo(buff, net.UDPAddrFromAddrPort(addr))
   156  		if err != nil {
   157  			return err
   158  		}
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  func (bind *netBindClient) SetMark(mark uint32) error { return nil }
   165  func (bind *netBindClient) BatchSize() int            { return 1 }