github.com/sagernet/sing-box@v1.9.0-rc.20/transport/wireguard/client_bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/bufio"
    12  	E "github.com/sagernet/sing/common/exceptions"
    13  	M "github.com/sagernet/sing/common/metadata"
    14  	N "github.com/sagernet/sing/common/network"
    15  	"github.com/sagernet/wireguard-go/conn"
    16  )
    17  
    18  var _ conn.Bind = (*ClientBind)(nil)
    19  
    20  type ClientBind struct {
    21  	ctx                 context.Context
    22  	errorHandler        E.Handler
    23  	dialer              N.Dialer
    24  	reservedForEndpoint map[netip.AddrPort][3]uint8
    25  	connAccess          sync.Mutex
    26  	conn                *wireConn
    27  	done                chan struct{}
    28  	isConnect           bool
    29  	connectAddr         netip.AddrPort
    30  	reserved            [3]uint8
    31  }
    32  
    33  func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
    34  	return &ClientBind{
    35  		ctx:                 ctx,
    36  		errorHandler:        errorHandler,
    37  		dialer:              dialer,
    38  		reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
    39  		done:                make(chan struct{}),
    40  		isConnect:           isConnect,
    41  		connectAddr:         connectAddr,
    42  		reserved:            reserved,
    43  	}
    44  }
    45  
    46  func (c *ClientBind) connect() (*wireConn, error) {
    47  	serverConn := c.conn
    48  	if serverConn != nil {
    49  		select {
    50  		case <-serverConn.done:
    51  			serverConn = nil
    52  		default:
    53  			return serverConn, nil
    54  		}
    55  	}
    56  	c.connAccess.Lock()
    57  	defer c.connAccess.Unlock()
    58  	serverConn = c.conn
    59  	if serverConn != nil {
    60  		select {
    61  		case <-serverConn.done:
    62  			serverConn = nil
    63  		default:
    64  			return serverConn, nil
    65  		}
    66  	}
    67  	if c.isConnect {
    68  		udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  		c.conn = &wireConn{
    73  			PacketConn: bufio.NewUnbindPacketConn(udpConn),
    74  			done:       make(chan struct{}),
    75  		}
    76  	} else {
    77  		udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		c.conn = &wireConn{
    82  			PacketConn: bufio.NewPacketConn(udpConn),
    83  			done:       make(chan struct{}),
    84  		}
    85  	}
    86  	return c.conn, nil
    87  }
    88  
    89  func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    90  	select {
    91  	case <-c.done:
    92  		c.done = make(chan struct{})
    93  	default:
    94  	}
    95  	return []conn.ReceiveFunc{c.receive}, 0, nil
    96  }
    97  
    98  func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
    99  	udpConn, err := c.connect()
   100  	if err != nil {
   101  		select {
   102  		case <-c.done:
   103  			return
   104  		default:
   105  		}
   106  		c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server"))
   107  		err = nil
   108  		time.Sleep(time.Second)
   109  		return
   110  	}
   111  	n, addr, err := udpConn.ReadFrom(packets[0])
   112  	if err != nil {
   113  		udpConn.Close()
   114  		select {
   115  		case <-c.done:
   116  		default:
   117  			c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
   118  			err = nil
   119  		}
   120  		return
   121  	}
   122  	sizes[0] = n
   123  	if n > 3 {
   124  		b := packets[0]
   125  		common.ClearArray(b[1:4])
   126  	}
   127  	eps[0] = Endpoint(M.AddrPortFromNet(addr))
   128  	count = 1
   129  	return
   130  }
   131  
   132  func (c *ClientBind) Close() error {
   133  	common.Close(common.PtrOrNil(c.conn))
   134  	select {
   135  	case <-c.done:
   136  	default:
   137  		close(c.done)
   138  	}
   139  	return nil
   140  }
   141  
   142  func (c *ClientBind) SetMark(mark uint32) error {
   143  	return nil
   144  }
   145  
   146  func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
   147  	udpConn, err := c.connect()
   148  	if err != nil {
   149  		return err
   150  	}
   151  	destination := netip.AddrPort(ep.(Endpoint))
   152  	for _, b := range bufs {
   153  		if len(b) > 3 {
   154  			reserved, loaded := c.reservedForEndpoint[destination]
   155  			if !loaded {
   156  				reserved = c.reserved
   157  			}
   158  			copy(b[1:4], reserved[:])
   159  		}
   160  		_, err = udpConn.WriteToUDPAddrPort(b, destination)
   161  		if err != nil {
   162  			udpConn.Close()
   163  			return err
   164  		}
   165  	}
   166  	return nil
   167  }
   168  
   169  func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   170  	ap, err := netip.ParseAddrPort(s)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	return Endpoint(ap), nil
   175  }
   176  
   177  func (c *ClientBind) BatchSize() int {
   178  	return 1
   179  }
   180  
   181  func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
   182  	c.reservedForEndpoint[destination] = reserved
   183  }
   184  
   185  type wireConn struct {
   186  	net.PacketConn
   187  	conn   net.Conn
   188  	access sync.Mutex
   189  	done   chan struct{}
   190  }
   191  
   192  func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
   193  	if w.conn != nil {
   194  		return w.conn.Write(b)
   195  	}
   196  	return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
   197  }
   198  
   199  func (w *wireConn) Close() error {
   200  	w.access.Lock()
   201  	defer w.access.Unlock()
   202  	select {
   203  	case <-w.done:
   204  		return net.ErrClosed
   205  	default:
   206  	}
   207  	w.PacketConn.Close()
   208  	close(w.done)
   209  	return nil
   210  }