github.com/sagernet/sing-box@v1.2.7/transport/wireguard/client_bind.go (about)

     1  package wireguard
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"sync"
     7  
     8  	"github.com/sagernet/sing/common"
     9  	M "github.com/sagernet/sing/common/metadata"
    10  	N "github.com/sagernet/sing/common/network"
    11  	"github.com/sagernet/wireguard-go/conn"
    12  )
    13  
// Compile-time check that *ClientBind satisfies conn.Bind.
var _ conn.Bind = (*ClientBind)(nil)
    15  
// ClientBind is a conn.Bind that carries all traffic for a single peer
// (peerAddr) over one "udp" connection obtained from dialer.
type ClientBind struct {
	// ctx is passed to every DialContext call made by connect.
	ctx context.Context
	// dialer opens the "udp" connection to peerAddr.
	dialer N.Dialer
	// peerAddr is the only remote address; every endpoint resolves to it.
	peerAddr M.Socksaddr
	// reserved is copied into bytes 1..3 of outgoing packets by Send.
	reserved [3]uint8
	// connAccess serializes dialing and replacing conn in connect.
	connAccess sync.Mutex
	// conn is the current upstream connection; nil until the first dial and
	// replaced whenever the previous one has been closed.
	conn *wireConn
	// done is nil until the first Close, which only allocates it; the second
	// Close closes it, marking the bind as shut down.
	// NOTE(review): done is written by Close without holding connAccess and is
	// read elsewhere without synchronization.
	done chan struct{}
}
    25  
    26  func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind {
    27  	return &ClientBind{
    28  		ctx:      ctx,
    29  		dialer:   dialer,
    30  		peerAddr: peerAddr,
    31  		reserved: reserved,
    32  	}
    33  }
    34  
    35  func (c *ClientBind) connect() (*wireConn, error) {
    36  	serverConn := c.conn
    37  	if serverConn != nil {
    38  		select {
    39  		case <-serverConn.done:
    40  			serverConn = nil
    41  		default:
    42  			return serverConn, nil
    43  		}
    44  	}
    45  	c.connAccess.Lock()
    46  	defer c.connAccess.Unlock()
    47  	serverConn = c.conn
    48  	if serverConn != nil {
    49  		select {
    50  		case <-serverConn.done:
    51  			serverConn = nil
    52  		default:
    53  			return serverConn, nil
    54  		}
    55  	}
    56  	udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr)
    57  	if err != nil {
    58  		return nil, &wireError{err}
    59  	}
    60  	c.conn = &wireConn{
    61  		Conn: udpConn,
    62  		done: make(chan struct{}),
    63  	}
    64  	return c.conn, nil
    65  }
    66  
    67  func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
    68  	select {
    69  	case <-c.done:
    70  		err = net.ErrClosed
    71  		return
    72  	default:
    73  	}
    74  	return []conn.ReceiveFunc{c.receive}, 0, nil
    75  }
    76  
    77  func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
    78  	udpConn, err := c.connect()
    79  	if err != nil {
    80  		err = &wireError{err}
    81  		return
    82  	}
    83  	n, err = udpConn.Read(b)
    84  	if err != nil {
    85  		udpConn.Close()
    86  		select {
    87  		case <-c.done:
    88  		default:
    89  			err = &wireError{err}
    90  		}
    91  		return
    92  	}
    93  	if n > 3 {
    94  		b[1] = 0
    95  		b[2] = 0
    96  		b[3] = 0
    97  	}
    98  	ep = Endpoint(c.peerAddr)
    99  	return
   100  }
   101  
   102  func (c *ClientBind) Reset() {
   103  	common.Close(common.PtrOrNil(c.conn))
   104  }
   105  
   106  func (c *ClientBind) Close() error {
   107  	common.Close(common.PtrOrNil(c.conn))
   108  	if c.done == nil {
   109  		c.done = make(chan struct{})
   110  		return nil
   111  	}
   112  	select {
   113  	case <-c.done:
   114  		return net.ErrClosed
   115  	default:
   116  		close(c.done)
   117  	}
   118  	return nil
   119  }
   120  
   121  func (c *ClientBind) SetMark(mark uint32) error {
   122  	return nil
   123  }
   124  
   125  func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
   126  	udpConn, err := c.connect()
   127  	if err != nil {
   128  		return err
   129  	}
   130  	if len(b) > 3 {
   131  		b[1] = c.reserved[0]
   132  		b[2] = c.reserved[1]
   133  		b[3] = c.reserved[2]
   134  	}
   135  	_, err = udpConn.Write(b)
   136  	if err != nil {
   137  		udpConn.Close()
   138  	}
   139  	return err
   140  }
   141  
   142  func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
   143  	return Endpoint(c.peerAddr), nil
   144  }
   145  
   146  func (c *ClientBind) Endpoint() conn.Endpoint {
   147  	return Endpoint(c.peerAddr)
   148  }
   149  
   150  type wireConn struct {
   151  	net.Conn
   152  	access sync.Mutex
   153  	done   chan struct{}
   154  }
   155  
   156  func (w *wireConn) Close() error {
   157  	w.access.Lock()
   158  	defer w.access.Unlock()
   159  	select {
   160  	case <-w.done:
   161  		return net.ErrClosed
   162  	default:
   163  	}
   164  	w.Conn.Close()
   165  	close(w.done)
   166  	return nil
   167  }