github.com/xraypb/xray-core@v1.6.6/proxy/wireguard/bind.go (about)

     1  package wireguard
     2  
import (
	"context"
	"errors"
	"io"
	"net"
	"net/netip"
	"strconv"
	"strings"
	"sync"

	"github.com/sagernet/wireguard-go/conn"
	xnet "github.com/xraypb/xray-core/common/net"
	"github.com/xraypb/xray-core/features/dns"
	"github.com/xraypb/xray-core/transport/internet"
)
    17  
// netReadInfo is a single pending receive request. A receive function
// returned by Open sends it on the bind's read queue, and a reader goroutine
// started by connectTo fills in the result fields.
type netReadInfo struct {
	// status: the requester calls waiter.Add(1) before queueing the request;
	// the reader calls waiter.Done() after writing the result fields, so the
	// requester's waiter.Wait() guarantees the results are visible.
	waiter sync.WaitGroup
	// param: the caller-supplied buffer the reader reads into.
	buff []byte
	// result: byte count, source endpoint, and error from the reader's Read.
	bytes    int
	endpoint conn.Endpoint
	err      error
}
    28  
// netBindClient provides the bind operations (ParseEndpoint, Open, Close,
// Send, SetMark) over connections dialed through an xray internet.Dialer,
// one connection per endpoint. Presumably it is used as a wireguard-go
// conn.Bind — confirm against the caller.
type netBindClient struct {
	// workers is how many receive functions Open returns; values <= 0 are
	// treated as 1.
	workers   int
	// dialer opens the per-endpoint connection in connectTo.
	dialer    internet.Dialer
	// dns and dnsOption are used by ParseEndpoint to resolve domain names.
	dns       dns.Client
	dnsOption dns.IPOption

	// readQueue carries receive requests from the functions returned by Open
	// to the per-endpoint reader goroutines. Open creates it; Close closes it.
	readQueue chan *netReadInfo
}
    37  
    38  func (n *netBindClient) ParseEndpoint(s string) (conn.Endpoint, error) {
    39  	ipStr, port, _, err := splitAddrPort(s)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	var addr net.IP
    45  	if IsDomainName(ipStr) {
    46  		ips, err := n.dns.LookupIP(ipStr, n.dnsOption)
    47  		if err != nil {
    48  			return nil, err
    49  		} else if len(ips) == 0 {
    50  			return nil, dns.ErrEmptyResponse
    51  		}
    52  		addr = ips[0]
    53  	} else {
    54  		addr = net.ParseIP(ipStr)
    55  	}
    56  	if addr == nil {
    57  		return nil, errors.New("failed to parse ip: " + ipStr)
    58  	}
    59  
    60  	var ip xnet.Address
    61  	if p4 := addr.To4(); len(p4) == net.IPv4len {
    62  		ip = xnet.IPAddress(p4[:])
    63  	} else {
    64  		ip = xnet.IPAddress(addr[:])
    65  	}
    66  
    67  	dst := xnet.Destination{
    68  		Address: ip,
    69  		Port:    xnet.Port(port),
    70  		Network: xnet.Network_UDP,
    71  	}
    72  
    73  	return &netEndpoint{
    74  		dst: dst,
    75  	}, nil
    76  }
    77  
    78  func (bind *netBindClient) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
    79  	bind.readQueue = make(chan *netReadInfo)
    80  
    81  	fun := func(buff []byte) (cap int, ep conn.Endpoint, err error) {
    82  		defer func() {
    83  			if r := recover(); r != nil {
    84  				cap = 0
    85  				ep = nil
    86  				err = errors.New("channel closed")
    87  			}
    88  		}()
    89  
    90  		r := &netReadInfo{
    91  			buff: buff,
    92  		}
    93  		r.waiter.Add(1)
    94  		bind.readQueue <- r
    95  		r.waiter.Wait() // wait read goroutine done, or we will miss the result
    96  		return r.bytes, r.endpoint, r.err
    97  	}
    98  	workers := bind.workers
    99  	if workers <= 0 {
   100  		workers = 1
   101  	}
   102  	arr := make([]conn.ReceiveFunc, workers)
   103  	for i := 0; i < workers; i++ {
   104  		arr[i] = fun
   105  	}
   106  
   107  	return arr, uint16(uport), nil
   108  }
   109  
   110  func (bind *netBindClient) Close() error {
   111  	if bind.readQueue != nil {
   112  		close(bind.readQueue)
   113  	}
   114  	return nil
   115  }
   116  
   117  func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
   118  	c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	endpoint.conn = c
   123  
   124  	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
   125  		for {
   126  			v, ok := <-readQueue
   127  			if !ok {
   128  				return
   129  			}
   130  			i, err := c.Read(v.buff)
   131  			v.bytes = i
   132  			v.endpoint = endpoint
   133  			v.err = err
   134  			v.waiter.Done()
   135  			if err != nil && errors.Is(err, io.EOF) {
   136  				endpoint.conn = nil
   137  				return
   138  			}
   139  		}
   140  	}(bind.readQueue, endpoint)
   141  
   142  	return nil
   143  }
   144  
   145  func (bind *netBindClient) Send(buff []byte, endpoint conn.Endpoint) error {
   146  	var err error
   147  
   148  	nend, ok := endpoint.(*netEndpoint)
   149  	if !ok {
   150  		return conn.ErrWrongEndpointType
   151  	}
   152  
   153  	if nend.conn == nil {
   154  		err = bind.connectTo(nend)
   155  		if err != nil {
   156  			return err
   157  		}
   158  	}
   159  
   160  	_, err = nend.conn.Write(buff)
   161  
   162  	return err
   163  }
   164  
   165  func (bind *netBindClient) SetMark(mark uint32) error {
   166  	return nil
   167  }
   168  
   169  type netEndpoint struct {
   170  	dst  xnet.Destination
   171  	conn net.Conn
   172  }
   173  
   174  func (netEndpoint) ClearSrc() {}
   175  
   176  func (e netEndpoint) DstIP() netip.Addr {
   177  	return toNetIpAddr(e.dst.Address)
   178  }
   179  
   180  func (e netEndpoint) SrcIP() netip.Addr {
   181  	return netip.Addr{}
   182  }
   183  
   184  func (e netEndpoint) DstToBytes() []byte {
   185  	var dat []byte
   186  	if e.dst.Address.Family().IsIPv4() {
   187  		dat = e.dst.Address.IP().To4()[:]
   188  	} else {
   189  		dat = e.dst.Address.IP().To16()[:]
   190  	}
   191  	dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
   192  	return dat
   193  }
   194  
   195  func (e netEndpoint) DstToString() string {
   196  	return e.dst.NetAddr()
   197  }
   198  
   199  func (e netEndpoint) SrcToString() string {
   200  	return ""
   201  }
   202  
   203  func toNetIpAddr(addr xnet.Address) netip.Addr {
   204  	if addr.Family().IsIPv4() {
   205  		ip := addr.IP()
   206  		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
   207  	} else {
   208  		ip := addr.IP()
   209  		arr := [16]byte{}
   210  		for i := 0; i < 16; i++ {
   211  			arr[i] = ip[i]
   212  		}
   213  		return netip.AddrFrom16(arr)
   214  	}
   215  }
   216  
   217  func stringsLastIndexByte(s string, b byte) int {
   218  	for i := len(s) - 1; i >= 0; i-- {
   219  		if s[i] == b {
   220  			return i
   221  		}
   222  	}
   223  	return -1
   224  }
   225  
   226  func splitAddrPort(s string) (ip string, port uint16, v6 bool, err error) {
   227  	i := stringsLastIndexByte(s, ':')
   228  	if i == -1 {
   229  		return "", 0, false, errors.New("not an ip:port")
   230  	}
   231  
   232  	ip = s[:i]
   233  	portStr := s[i+1:]
   234  	if len(ip) == 0 {
   235  		return "", 0, false, errors.New("no IP")
   236  	}
   237  	if len(portStr) == 0 {
   238  		return "", 0, false, errors.New("no port")
   239  	}
   240  	port64, err := strconv.ParseUint(portStr, 10, 16)
   241  	if err != nil {
   242  		return "", 0, false, errors.New("invalid port " + strconv.Quote(portStr) + " parsing " + strconv.Quote(s))
   243  	}
   244  	port = uint16(port64)
   245  	if ip[0] == '[' {
   246  		if len(ip) < 2 || ip[len(ip)-1] != ']' {
   247  			return "", 0, false, errors.New("missing ]")
   248  		}
   249  		ip = ip[1 : len(ip)-1]
   250  		v6 = true
   251  	}
   252  
   253  	return ip, port, v6, nil
   254  }