github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/wireguard/bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"strconv"
    10  	"sync"
    11  
    12  	"golang.zx2c4.com/wireguard/conn"
    13  
    14  	xnet "github.com/xtls/xray-core/common/net"
    15  	"github.com/xtls/xray-core/features/dns"
    16  	"github.com/xtls/xray-core/transport/internet"
    17  )
    18  
// netReadInfo is a single read request handed from a ReceiveFunc (built in
// netBind.Open) to a reader goroutine over netBind.readQueue, such as the one
// started by netBindClient.connectTo. The requester blocks on waiter until the
// reader has filled in the result fields.
type netReadInfo struct {
	// status: Add(1) is called by the requester before queueing; the reader
	// calls Done() after writing bytes, endpoint and err.
	waiter sync.WaitGroup
	// param: buffer the reader goroutine reads the packet into.
	buff []byte
	// result: number of bytes read, the endpoint the data is attributed to,
	// and the read error (nil on success).
	bytes    int
	endpoint conn.Endpoint
	err      error
}
    29  
// netBind holds the conn.Bind state shared by netBindClient and
// netBindServer. Both embed it, which avoids duplicating this code.
type netBind struct {
	// dns and dnsOption are used by ParseEndpoint to resolve endpoints whose
	// host is a domain name.
	dns       dns.Client
	dnsOption dns.IPOption

	// workers is the number of ReceiveFuncs returned by Open; values <= 0
	// are treated as 1.
	workers   int
	// readQueue carries read requests from the ReceiveFuncs to reader
	// goroutines. Open creates it and Close closes it.
	readQueue chan *netReadInfo
}
    38  
    39  // SetMark implements conn.Bind
    40  func (bind *netBind) SetMark(mark uint32) error {
    41  	return nil
    42  }
    43  
    44  // ParseEndpoint implements conn.Bind
    45  func (n *netBind) ParseEndpoint(s string) (conn.Endpoint, error) {
    46  	ipStr, port, err := net.SplitHostPort(s)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	portNum, err := strconv.Atoi(port)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	addr := xnet.ParseAddress(ipStr)
    56  	if addr.Family() == xnet.AddressFamilyDomain {
    57  		ips, err := n.dns.LookupIP(addr.Domain(), n.dnsOption)
    58  		if err != nil {
    59  			return nil, err
    60  		} else if len(ips) == 0 {
    61  			return nil, dns.ErrEmptyResponse
    62  		}
    63  		addr = xnet.IPAddress(ips[0])
    64  	}
    65  
    66  	dst := xnet.Destination{
    67  		Address: addr,
    68  		Port:    xnet.Port(portNum),
    69  		Network: xnet.Network_UDP,
    70  	}
    71  
    72  	return &netEndpoint{
    73  		dst: dst,
    74  	}, nil
    75  }
    76  
    77  // BatchSize implements conn.Bind
    78  func (bind *netBind) BatchSize() int {
    79  	return 1
    80  }
    81  
    82  // Open implements conn.Bind
    83  func (bind *netBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
    84  	bind.readQueue = make(chan *netReadInfo)
    85  
    86  	fun := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
    87  		defer func() {
    88  			if r := recover(); r != nil {
    89  				n = 0
    90  				err = errors.New("channel closed")
    91  			}
    92  		}()
    93  
    94  		r := &netReadInfo{
    95  			buff: bufs[0],
    96  		}
    97  		r.waiter.Add(1)
    98  		bind.readQueue <- r
    99  		r.waiter.Wait() // wait read goroutine done, or we will miss the result
   100  		sizes[0], eps[0] = r.bytes, r.endpoint
   101  		return 1, r.err
   102  	}
   103  	workers := bind.workers
   104  	if workers <= 0 {
   105  		workers = 1
   106  	}
   107  	arr := make([]conn.ReceiveFunc, workers)
   108  	for i := 0; i < workers; i++ {
   109  		arr[i] = fun
   110  	}
   111  
   112  	return arr, uint16(uport), nil
   113  }
   114  
   115  // Close implements conn.Bind
   116  func (bind *netBind) Close() error {
   117  	if bind.readQueue != nil {
   118  		close(bind.readQueue)
   119  	}
   120  	return nil
   121  }
   122  
// netBindClient is a conn.Bind that dials an endpoint lazily: the first Send
// to an endpoint without a connection calls connectTo, which dials it and
// starts a goroutine reading replies from that connection.
type netBindClient struct {
	netBind

	// dialer opens the connection to an endpoint's destination (connectTo).
	dialer   internet.Dialer
	// reserved, when exactly 3 bytes long, is copied into bytes 1-3 of every
	// outgoing packet longer than 3 bytes (Send).
	reserved []byte
}
   129  
   130  func (bind *netBindClient) connectTo(endpoint *netEndpoint) error {
   131  	c, err := bind.dialer.Dial(context.Background(), endpoint.dst)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	endpoint.conn = c
   136  
   137  	go func(readQueue <-chan *netReadInfo, endpoint *netEndpoint) {
   138  		for {
   139  			v, ok := <-readQueue
   140  			if !ok {
   141  				return
   142  			}
   143  			i, err := c.Read(v.buff)
   144  
   145  			if i > 3 {
   146  				v.buff[1] = 0
   147  				v.buff[2] = 0
   148  				v.buff[3] = 0
   149  			}
   150  
   151  			v.bytes = i
   152  			v.endpoint = endpoint
   153  			v.err = err
   154  			v.waiter.Done()
   155  			if err != nil && errors.Is(err, io.EOF) {
   156  				endpoint.conn = nil
   157  				return
   158  			}
   159  		}
   160  	}(bind.readQueue, endpoint)
   161  
   162  	return nil
   163  }
   164  
   165  func (bind *netBindClient) Send(buff [][]byte, endpoint conn.Endpoint) error {
   166  	var err error
   167  
   168  	nend, ok := endpoint.(*netEndpoint)
   169  	if !ok {
   170  		return conn.ErrWrongEndpointType
   171  	}
   172  
   173  	if nend.conn == nil {
   174  		err = bind.connectTo(nend)
   175  		if err != nil {
   176  			return err
   177  		}
   178  	}
   179  
   180  	for _, buff := range buff {
   181  		if len(buff) > 3 && len(bind.reserved) == 3 {
   182  			copy(buff[1:], bind.reserved)
   183  		}
   184  		if _, err = nend.conn.Write(buff); err != nil {
   185  			return err
   186  		}
   187  	}
   188  	return nil
   189  }
   190  
// netBindServer is a conn.Bind whose Send only writes to a connection already
// attached to the endpoint; it never dials.
// NOTE(review): the code that attaches connections to server endpoints and
// serves readQueue is not part of this section — confirm where it lives.
type netBindServer struct {
	netBind
}
   194  
   195  func (bind *netBindServer) Send(buff [][]byte, endpoint conn.Endpoint) error {
   196  	var err error
   197  
   198  	nend, ok := endpoint.(*netEndpoint)
   199  	if !ok {
   200  		return conn.ErrWrongEndpointType
   201  	}
   202  
   203  	if nend.conn == nil {
   204  		return newError("connection not open yet")
   205  	}
   206  
   207  	for _, buff := range buff {
   208  		if _, err = nend.conn.Write(buff); err != nil {
   209  			return err
   210  		}
   211  	}
   212  
   213  	return err
   214  }
   215  
// netEndpoint is the conn.Endpoint used by netBind: a destination plus the
// connection (if any) currently used to reach it.
type netEndpoint struct {
	// dst is the remote destination; ParseEndpoint always sets Network_UDP.
	dst  xnet.Destination
	// conn is nil until a connection is attached. netBindClient.connectTo
	// dials one on demand and clears it again when a read returns io.EOF.
	conn net.Conn
}
   220  
   221  func (netEndpoint) ClearSrc() {}
   222  
   223  func (e netEndpoint) DstIP() netip.Addr {
   224  	return netip.Addr{}
   225  }
   226  
   227  func (e netEndpoint) SrcIP() netip.Addr {
   228  	return netip.Addr{}
   229  }
   230  
   231  func (e netEndpoint) DstToBytes() []byte {
   232  	var dat []byte
   233  	if e.dst.Address.Family().IsIPv4() {
   234  		dat = e.dst.Address.IP().To4()[:]
   235  	} else {
   236  		dat = e.dst.Address.IP().To16()[:]
   237  	}
   238  	dat = append(dat, byte(e.dst.Port), byte(e.dst.Port>>8))
   239  	return dat
   240  }
   241  
   242  func (e netEndpoint) DstToString() string {
   243  	return e.dst.NetAddr()
   244  }
   245  
   246  func (e netEndpoint) SrcToString() string {
   247  	return ""
   248  }
   249  
   250  func toNetIpAddr(addr xnet.Address) netip.Addr {
   251  	if addr.Family().IsIPv4() {
   252  		ip := addr.IP()
   253  		return netip.AddrFrom4([4]byte{ip[0], ip[1], ip[2], ip[3]})
   254  	} else {
   255  		ip := addr.IP()
   256  		arr := [16]byte{}
   257  		for i := 0; i < 16; i++ {
   258  			arr[i] = ip[i]
   259  		}
   260  		return netip.AddrFrom16(arr)
   261  	}
   262  }