github.com/sagernet/sing-box@v1.9.0-rc.20/transport/wireguard/client_bind.go (about) 1 package wireguard 2 3 import ( 4 "context" 5 "net" 6 "net/netip" 7 "sync" 8 "time" 9 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/bufio" 12 E "github.com/sagernet/sing/common/exceptions" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 "github.com/sagernet/wireguard-go/conn" 16 ) 17 18 var _ conn.Bind = (*ClientBind)(nil) 19 20 type ClientBind struct { 21 ctx context.Context 22 errorHandler E.Handler 23 dialer N.Dialer 24 reservedForEndpoint map[netip.AddrPort][3]uint8 25 connAccess sync.Mutex 26 conn *wireConn 27 done chan struct{} 28 isConnect bool 29 connectAddr netip.AddrPort 30 reserved [3]uint8 31 } 32 33 func NewClientBind(ctx context.Context, errorHandler E.Handler, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind { 34 return &ClientBind{ 35 ctx: ctx, 36 errorHandler: errorHandler, 37 dialer: dialer, 38 reservedForEndpoint: make(map[netip.AddrPort][3]uint8), 39 done: make(chan struct{}), 40 isConnect: isConnect, 41 connectAddr: connectAddr, 42 reserved: reserved, 43 } 44 } 45 46 func (c *ClientBind) connect() (*wireConn, error) { 47 serverConn := c.conn 48 if serverConn != nil { 49 select { 50 case <-serverConn.done: 51 serverConn = nil 52 default: 53 return serverConn, nil 54 } 55 } 56 c.connAccess.Lock() 57 defer c.connAccess.Unlock() 58 serverConn = c.conn 59 if serverConn != nil { 60 select { 61 case <-serverConn.done: 62 serverConn = nil 63 default: 64 return serverConn, nil 65 } 66 } 67 if c.isConnect { 68 udpConn, err := c.dialer.DialContext(c.ctx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr)) 69 if err != nil { 70 return nil, err 71 } 72 c.conn = &wireConn{ 73 PacketConn: bufio.NewUnbindPacketConn(udpConn), 74 done: make(chan struct{}), 75 } 76 } else { 77 udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) 78 if err != nil { 79 return nil, err 80 } 81 c.conn = &wireConn{ 82 PacketConn: bufio.NewPacketConn(udpConn), 83 done: make(chan struct{}), 84 } 85 } 86 return c.conn, nil 87 } 88 89 func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { 90 select { 91 case <-c.done: 92 c.done = make(chan struct{}) 93 default: 94 } 95 return []conn.ReceiveFunc{c.receive}, 0, nil 96 } 97 98 func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) { 99 udpConn, err := c.connect() 100 if err != nil { 101 select { 102 case <-c.done: 103 return 104 default: 105 } 106 c.errorHandler.NewError(context.Background(), E.Cause(err, "connect to server")) 107 err = nil 108 time.Sleep(time.Second) 109 return 110 } 111 n, addr, err := udpConn.ReadFrom(packets[0]) 112 if err != nil { 113 udpConn.Close() 114 select { 115 case <-c.done: 116 default: 117 c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet")) 118 err = nil 119 } 120 return 121 } 122 sizes[0] = n 123 if n > 3 { 124 b := packets[0] 125 common.ClearArray(b[1:4]) 126 } 127 eps[0] = Endpoint(M.AddrPortFromNet(addr)) 128 count = 1 129 return 130 } 131 132 func (c *ClientBind) Close() error { 133 common.Close(common.PtrOrNil(c.conn)) 134 select { 135 case <-c.done: 136 default: 137 close(c.done) 138 } 139 return nil 140 } 141 142 func (c *ClientBind) SetMark(mark uint32) error { 143 return nil 144 } 145 146 func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error { 147 udpConn, err := c.connect() 148 if err != nil { 149 return err 150 } 151 destination := netip.AddrPort(ep.(Endpoint)) 152 for _, b := range bufs { 153 if len(b) > 3 { 154 reserved, loaded := c.reservedForEndpoint[destination] 155 if !loaded { 156 reserved = c.reserved 157 } 158 copy(b[1:4], reserved[:]) 159 } 160 _, err = udpConn.WriteToUDPAddrPort(b, destination) 161 if err != nil { 162 udpConn.Close() 163 return err 164 } 165 } 166 return nil 167 } 168 169 func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) { 170 ap, err := netip.ParseAddrPort(s) 171 if err != nil { 172 return nil, err 173 } 174 return Endpoint(ap), nil 175 } 176 177 func (c *ClientBind) BatchSize() int { 178 return 1 179 } 180 181 func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) { 182 c.reservedForEndpoint[destination] = reserved 183 } 184 185 type wireConn struct { 186 net.PacketConn 187 conn net.Conn 188 access sync.Mutex 189 done chan struct{} 190 } 191 192 func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { 193 if w.conn != nil { 194 return w.conn.Write(b) 195 } 196 return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr()) 197 } 198 199 func (w *wireConn) Close() error { 200 w.access.Lock() 201 defer w.access.Unlock() 202 select { 203 case <-w.done: 204 return net.ErrClosed 205 default: 206 } 207 w.PacketConn.Close() 208 close(w.done) 209 return nil 210 }