github.com/moqsien/xraycore@v1.8.5/proxy/wireguard/bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"strconv"
    10  	"sync"
    11  
    12  	"github.com/sagernet/wireguard-go/conn"
    13  	xnet "github.com/moqsien/xraycore/common/net"
    14  	"github.com/moqsien/xraycore/features/dns"
    15  	"github.com/moqsien/xraycore/transport/internet"
    16  )
    17  
// netReadInfo is one read request handed from a ReceiveFunc (see Open) to a
// connection reader goroutine (see connectTo) over netBindClient.readQueue.
// The requester fills buff and waits on waiter; the reader writes the result
// fields and then calls waiter.Done, after which the requester may read them.
type netReadInfo struct {
	// status
	// waiter is released by the reader goroutine once the result fields
	// below have been written.
	waiter sync.WaitGroup
	// param
	// buff holds the caller's destination buffers; only buff[0] is read into.
	buff [][]byte
	// result
	// bytes is the count returned by the connection's Read into buff[0].
	bytes    int
	// endpoint is the endpoint whose connection produced the data.
	endpoint conn.Endpoint
	// err is the error returned by the connection's Read, if any.
	err      error
}
    28  
// netBindClient carries WireGuard packets over connections obtained from an
// Xray internet.Dialer, rather than binding a local socket itself. Its method
// set (Open, Close, Send, ParseEndpoint, SetMark, BatchSize) mirrors
// wireguard-go's conn.Bind. NOTE(review): presumably it is used as one —
// confirm at the construction site.
type netBindClient struct {
	// workers is the number of ReceiveFuncs returned by Open; values <= 0
	// are treated as 1.
	workers   int
	// dialer opens the connection for an endpoint on its first Send
	// (see connectTo).
	dialer    internet.Dialer
	// dns and dnsOption resolve domain-name endpoints in ParseEndpoint.
	dns       dns.Client
	dnsOption dns.IPOption
	// reserved, when exactly 3 bytes long, is copied by Send into bytes
	// 1..3 of each outgoing packet longer than 3 bytes.
	reserved  []byte

	// readQueue hands read requests from the ReceiveFuncs to the
	// per-connection reader goroutines. Created by Open, closed by Close.
	readQueue chan *netReadInfo
}
    38  
    39  func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
    40  	ipStr, port, _, err := splitAddrPort(s)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	var addr net.IP
    46  	if IsDomainName(ipStr) {
    47  		ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
    48  		if err != nil {
    49  			return nil, err
    50  		} else if len(ips) == 0 {
    51  			return nil, dns.ErrEmptyResponse
    52  		}
    53  		addr = ips[0]
    54  	} else {
    55  		addr = net.ParseIP(ipStr)
    56  	}
    57  	if addr == nil {
    58  		return nil, errors.New("failed to parse ip: " + ipStr)
    59  	}
    60  
    61  	var ip xnet.Address
    62  	if p4 := addr.To4(); len(p4) == net.IPv4len {
    63  		ip = xnet.IPAddress(p4[:])
    64  	} else {
    65  		ip = xnet.IPAddress(addr[:])
    66  	}
    67  
    68  	dst := xnet.Destination{
    69  		Address: ip,
    70  		Port:    xnet.Port(port),
    71  		Network: xnet.Network_UDP,
    72  	}
    73  
    74  	return &netEndpoint{
    75  		dst: dst,
    76  	}, nil
    77  }
    78  
    79  func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
    80  	bind.readQueue = make(chan *netReadInfo)
    81  
    82  	fun := func(buff [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
    83  		defer func() {
    84  			if r := recover(); r != nil {
    85  				count = 0
    86  				err = errors.New("channel closed")
    87  			}
    88  		}()
    89  
    90  		r := &netReadInfo{
    91  			buff: buff,
    92  		}
    93  		r.waiter.Add(1)
    94  		bind.readQueue <- r
    95  		r.waiter.Wait() // wait read goroutine done, or we will miss the result
    96  		return r.bytes, r.err
    97  	}
    98  	workers := bind.workers
    99  	if workers <= 0 {
   100  		workers = 1
   101  	}
   102  	arr := make([]conn.ReceiveFunc, workers)
   103  	for i := 0; i < workers; i++ {
   104  		arr[i] = fun
   105  	}
   106  
   107  	return arr, uint16(uport), nil
   108  }
   109  
   110  func (bind *netBindClient) Close() error {
   111  	if bind.readQueue != nil {
   112  		close(bind.readQueue)
   113  	}
   114  	return nil
   115  }
   116  
   117  func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
   118  	c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	endpoint.conn = c
   123  
   124  	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
   125  		for {
   126  			v, ok := <-readQueue
   127  			if !ok {
   128  				return
   129  			}
   130  			b := v.buff[0]
   131  			i, err := c.Read(b)
   132  
   133  			if i > 3 {
   134  				b[1] = 0
   135  				b[2] = 0
   136  				b[3] = 0
   137  			}
   138  
   139  			v.bytes = i
   140  			v.endpoint = endpoint
   141  			v.err = err
   142  			v.waiter.Done()
   143  			if err != nil && errors.Is(err, io.EOF) {
   144  				endpoint.conn = nil
   145  				return
   146  			}
   147  		}
   148  	}(bind.readQueue, endpoint)
   149  
   150  	return nil
   151  }
   152  
   153  func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
   154  	var err error
   155  
   156  	nend, ok := endpoint.(*netEndpoint)
   157  	if !ok {
   158  		return conn.ErrWrongEndpointType
   159  	}
   160  
   161  	if nend.conn == nil {
   162  		err = bind.connectTo(nend)
   163  		if err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	if len(buff[0]) > 3 && len(bind.reserved) == 3 {
   169  		copy(buff[0][1:], bind.reserved)
   170  	}
   171  
   172  	_, err = nend.conn.Write(buff[0])
   173  
   174  	return err
   175  }
   176  
   177  func (bind *netBindClient) SetMark(mark uint32) error {
   178  	return nil
   179  }
   180  
   181  func (bind *netBindClient) BatchSize() int {
   182  	return 1
   183  }
   184  
// netEndpoint is the endpoint type produced by ParseEndpoint.
type netEndpoint struct {
	// dst is the remote UDP destination (always an IP address when built by
	// ParseEndpoint).
	dst  xnet.Destination
	// conn is the connection dialed to dst by connectTo. It is nil until the
	// first Send and may be reset to nil by the reader goroutine started in
	// connectTo.
	conn net.Conn
}
   189  
   190  func (netEndpoint) ClearSrc() {}
   191  
   192  func (e netEndpoint) DstIP() netip.Addr {
   193  	return toNetIpAddr(e.dst.Address)
   194  }
   195  
   196  func (e netEndpoint) SrcIP() netip.Addr {
   197  	return netip.Addr{}
   198  }
   199  
   200  func (e netEndpoint) DstToBytes() []byte {
   201  	var dat []byte
   202  	if e.dst.Address.Family().IsIPv4() {
   203  		dat = e.dst.Address.IP().To4()[:]
   204  	} else {
   205  		dat = e.dst.Address.IP().To16()[:]
   206  	}
   207  	dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
   208  	return dat
   209  }
   210  
   211  func (e netEndpoint) DstToString() string {
   212  	return e.dst.NetAddr()
   213  }
   214  
   215  func (e netEndpoint) SrcToString() string {
   216  	return ""
   217  }
   218  
   219  func toNetIpAddr(addr xnet.Address) netip.Addr {
   220  	if addr.Family().IsIPv4() {
   221  		ip := addr.IP()
   222  		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
   223  	} else {
   224  		ip := addr.IP()
   225  		arr := [16]byte{}
   226  		for i := 0; i < 16; i++ {
   227  			arr[i] = ip[i]
   228  		}
   229  		return netip.AddrFrom16(arr)
   230  	}
   231  }
   232  
   233  func stringsLastIndexByte(s string, b byte) int {
   234  	for i := len(s) - 1; i >= 0; i-- {
   235  		if s[i] == b {
   236  			return i
   237  		}
   238  	}
   239  	return -1
   240  }
   241  
   242  func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
   243  	i := stringsLastIndexByte(s, ':')
   244  	if i == -1 {
   245  		return "", 0, false, errors.New("not an ip:port")
   246  	}
   247  
   248  	ip = s[:i]
   249  	portStr := s[i+1:]
   250  	if len(ip) == 0 {
   251  		return "", 0, false, errors.New("no IP")
   252  	}
   253  	if len(portStr) == 0 {
   254  		return "", 0, false, errors.New("no port")
   255  	}
   256  	port64, err := strconv.ParseUint(portStr, 10, 16)
   257  	if err != nil {
   258  		return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
   259  	}
   260  	port = uint16(port64)
   261  	if ip[0] == '[' {
   262  		if len(ip) < 2 || ip[len(ip)-1] != ']' {
   263  			return "", 0, false, errors.New("missing ]")
   264  		}
   265  		ip = ip[1 : len(ip)-1]
   266  		v6 = true
   267  	}
   268  
   269  	return ip, port, v6, nil
   270  }