github.com/sagernet/sing-box@v1.2.7/transport/wireguard/client_bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "net" 6 "sync" 7 8 "github.com/sagernet/sing/common" 9 M "github.com/sagernet/sing/common/metadata" 10 N "github.com/sagernet/sing/common/network" 11 "github.com/sagernet/wireguard-go/conn" 12 ) 13 14 var _ conn.Bind = (*ClientBind)(nil) 15 16 type ClientBind struct { 17 ctx context.Context 18 dialer N.Dialer 19 peerAddr M.Socksaddr 20 reserved [3]uint8 21 connAccess sync.Mutex 22 conn *wireConn 23 done chan struct{} 24 } 25 26 func NewClientBind(ctx context.Context, dialer N.Dialer, peerAddr M.Socksaddr, reserved [3]uint8) *ClientBind { 27 return &ClientBind{ 28 ctx: ctx, 29 dialer: dialer, 30 peerAddr: peerAddr, 31 reserved: reserved, 32 } 33 } 34 35 func (c *ClientBind) connect() (*wireConn, error) { 36 serverConn := c.conn 37 if serverConn != nil { 38 select { 39 case <-serverConn.done: 40 serverConn = nil 41 default: 42 return serverConn, nil 43 } 44 } 45 c.connAccess.Lock() 46 defer c.connAccess.Unlock() 47 serverConn = c.conn 48 if serverConn != nil { 49 select { 50 case <-serverConn.done: 51 serverConn = nil 52 default: 53 return serverConn, nil 54 } 55 } 56 udpConn, err := c.dialer.DialContext(c.ctx, "udp", c.peerAddr) 57 if err != nil { 58 return nil, &wireError{err} 59 } 60 c.conn = &wireConn{ 61 Conn: udpConn, 62 done: make(chan struct{}), 63 } 64 return c.conn, nil 65 } 66 67 func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 68 select { 69 case <-c.done: 70 err = net.ErrClosed 71 return 72 default: 73 } 74 return []conn.ReceiveFunc{c.receive}, 0, nil 75 } 76 77 func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { 78 udpConn, err := c.connect() 79 if err != nil { 80 err = &wireError{err} 81 return 82 } 83 n, err = udpConn.Read(b) 84 if err != nil { 85 udpConn.Close() 86 select { 87 case <-c.done: 88 default: 89 err = &wireError{err} 90 } 91 return 92 } 93 if n > 3 { 94 b[1] = 0 95 b[2] = 0 96 b[3] = 0 97 } 98 ep = Endpoint(c.peerAddr) 99 return 100 } 101 102 func (c *ClientBind) Reset() { 103 common.Close(common.PtrOrNil(c.conn)) 104 } 105 106 func (c *ClientBind) Close() error { 107 common.Close(common.PtrOrNil(c.conn)) 108 if c.done == nil { 109 c.done = make(chan struct{}) 110 return nil 111 } 112 select { 113 case <-c.done: 114 return net.ErrClosed 115 default: 116 close(c.done) 117 } 118 return nil 119 } 120 121 func (c *ClientBind) SetMark(mark uint32) error { 122 return nil 123 } 124 125 func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error { 126 udpConn, err := c.connect() 127 if err != nil { 128 return err 129 } 130 if len(b) > 3 { 131 b[1] = c.reserved[0] 132 b[2] = c.reserved[1] 133 b[3] = c.reserved[2] 134 } 135 _, err = udpConn.Write(b) 136 if err != nil { 137 udpConn.Close() 138 } 139 return err 140 } 141 142 func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { 143 return Endpoint(c.peerAddr), nil 144 } 145 146 func (c *ClientBind) Endpoint() conn.Endpoint { 147 return Endpoint(c.peerAddr) 148 } 149 150 type wireConn struct { 151 net.Conn 152 access sync.Mutex 153 done chan struct{} 154 } 155 156 func (w *wireConn) Close() error { 157 w.access.Lock() 158 defer w.access.Unlock() 159 select { 160 case <-w.done: 161 return net.ErrClosed 162 default: 163 } 164 w.Conn.Close() 165 close(w.done) 166 return nil 167 }