github.com/xraypb/Xray-core@v1.8.1/proxy/wireguard/bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"strconv"
    10  	"sync"
    11  
    12  	"github.com/sagernet/wireguard-go/conn"
    13  	xnet "github.com/xraypb/Xray-core/common/net"
    14  	"github.com/xraypb/Xray-core/features/dns"
    15  	"github.com/xraypb/Xray-core/transport/internet"
    16  )
    17  
    18  type netReadInfo struct {
    19  	// status
    20  	waiter sync.WaitGroup
    21  	// param
    22  	buff []byte
    23  	// result
    24  	bytes    int
    25  	endpoint conn.Endpoint
    26  	err      error
    27  }
    28  
    29  type netBindClient struct {
    30  	workers   int
    31  	dialer    internet.Dialer
    32  	dns       dns.Client
    33  	dnsOption dns.IPOption
    34  	reserved  []byte
    35  
    36  	readQueue chan *netReadInfo
    37  }
    38  
    39  func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
    40  	ipStr, port, _, err := splitAddrPort(s)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	var addr net.IP
    46  	if IsDomainName(ipStr) {
    47  		ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
    48  		if err != nil {
    49  			return nil, err
    50  		} else if len(ips) == 0 {
    51  			return nil, dns.ErrEmptyResponse
    52  		}
    53  		addr = ips[0]
    54  	} else {
    55  		addr = net.ParseIP(ipStr)
    56  	}
    57  	if addr == nil {
    58  		return nil, errors.New("failed to parse ip: " + ipStr)
    59  	}
    60  
    61  	var ip xnet.Address
    62  	if p4 := addr.To4(); len(p4) == net.IPv4len {
    63  		ip = xnet.IPAddress(p4[:])
    64  	} else {
    65  		ip = xnet.IPAddress(addr[:])
    66  	}
    67  
    68  	dst := xnet.Destination{
    69  		Address: ip,
    70  		Port:    xnet.Port(port),
    71  		Network: xnet.Network_UDP,
    72  	}
    73  
    74  	return &netEndpoint{
    75  		dst: dst,
    76  	}, nil
    77  }
    78  
    79  func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
    80  	bind.readQueue = make(chan *netReadInfo)
    81  
    82  	fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
    83  		defer func() {
    84  			if r := recover(); r != nil {
    85  				cap = 0
    86  				ep = nil
    87  				err = errors.New("channel closed")
    88  			}
    89  		}()
    90  
    91  		r := &netReadInfo{
    92  			buff: buff,
    93  		}
    94  		r.waiter.Add(1)
    95  		bind.readQueue <- r
    96  		r.waiter.Wait() // wait read goroutine done, or we will miss the result
    97  		return r.bytes, r.endpoint, r.err
    98  	}
    99  	workers := bind.workers
   100  	if workers <= 0 {
   101  		workers = 1
   102  	}
   103  	arr := make([]conn.ReceiveFunc, workers)
   104  	for i := 0; i < workers; i++ {
   105  		arr[i] = fun
   106  	}
   107  
   108  	return arr, uint16(uport), nil
   109  }
   110  
   111  func (bind *netBindClient) Close() error {
   112  	if bind.readQueue != nil {
   113  		close(bind.readQueue)
   114  	}
   115  	return nil
   116  }
   117  
   118  func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
   119  	c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	endpoint.conn = c
   124  
   125  	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
   126  		for {
   127  			v, ok := <-readQueue
   128  			if !ok {
   129  				return
   130  			}
   131  			i, err := c.Read(v.buff)
   132  
   133  			if i > 3 {
   134  				v.buff[1] = 0
   135  				v.buff[2] = 0
   136  				v.buff[3] = 0
   137  			}
   138  
   139  			v.bytes = i
   140  			v.endpoint = endpoint
   141  			v.err = err
   142  			v.waiter.Done()
   143  			if err != nil && errors.Is(err, io.EOF) {
   144  				endpoint.conn = nil
   145  				return
   146  			}
   147  		}
   148  	}(bind.readQueue, endpoint)
   149  
   150  	return nil
   151  }
   152  
   153  func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
   154  	var err error
   155  
   156  	nend, ok := endpoint.(*netEndpoint)
   157  	if !ok {
   158  		return conn.ErrWrongEndpointType
   159  	}
   160  
   161  	if nend.conn == nil {
   162  		err = bind.connectTo(nend)
   163  		if err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	if len(buff) > 3 && len(bind.reserved) == 3 {
   169  		copy(buff[1:], bind.reserved)
   170  	}
   171  
   172  	_, err = nend.conn.Write(buff)
   173  
   174  	return err
   175  }
   176  
   177  func (bind *netBindClient) SetMark(mark uint32) error {
   178  	return nil
   179  }
   180  
   181  type netEndpoint struct {
   182  	dst  xnet.Destination
   183  	conn net.Conn
   184  }
   185  
   186  func (netEndpoint) ClearSrc() {}
   187  
   188  func (e netEndpoint) DstIP() netip.Addr {
   189  	return toNetIpAddr(e.dst.Address)
   190  }
   191  
   192  func (e netEndpoint) SrcIP() netip.Addr {
   193  	return netip.Addr{}
   194  }
   195  
   196  func (e netEndpoint) DstToBytes() []byte {
   197  	var dat []byte
   198  	if e.dst.Address.Family().IsIPv4() {
   199  		dat = e.dst.Address.IP().To4()[:]
   200  	} else {
   201  		dat = e.dst.Address.IP().To16()[:]
   202  	}
   203  	dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
   204  	return dat
   205  }
   206  
   207  func (e netEndpoint) DstToString() string {
   208  	return e.dst.NetAddr()
   209  }
   210  
   211  func (e netEndpoint) SrcToString() string {
   212  	return ""
   213  }
   214  
   215  func toNetIpAddr(addr xnet.Address) netip.Addr {
   216  	if addr.Family().IsIPv4() {
   217  		ip := addr.IP()
   218  		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
   219  	} else {
   220  		ip := addr.IP()
   221  		arr := [16]byte{}
   222  		for i := 0; i < 16; i++ {
   223  			arr[i] = ip[i]
   224  		}
   225  		return netip.AddrFrom16(arr)
   226  	}
   227  }
   228  
   229  func stringsLastIndexByte(s string, b byte) int {
   230  	for i := len(s) - 1; i >= 0; i-- {
   231  		if s[i] == b {
   232  			return i
   233  		}
   234  	}
   235  	return -1
   236  }
   237  
   238  func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
   239  	i := stringsLastIndexByte(s, ':')
   240  	if i == -1 {
   241  		return "", 0, false, errors.New("not an ip:port")
   242  	}
   243  
   244  	ip = s[:i]
   245  	portStr := s[i+1:]
   246  	if len(ip) == 0 {
   247  		return "", 0, false, errors.New("no IP")
   248  	}
   249  	if len(portStr) == 0 {
   250  		return "", 0, false, errors.New("no port")
   251  	}
   252  	port64, err := strconv.ParseUint(portStr, 10, 16)
   253  	if err != nil {
   254  		return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
   255  	}
   256  	port = uint16(port64)
   257  	if ip[0] == '[' {
   258  		if len(ip) < 2 || ip[len(ip)-1] != ']' {
   259  			return "", 0, false, errors.New("missing ]")
   260  		}
   261  		ip = ip[1 : len(ip)-1]
   262  		v6 = true
   263  	}
   264  
   265  	return ip, port, v6, nil
   266  }