github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/client_conn.go (about) 1 package mux 2 3 import ( 4 "encoding/binary" 5 "io" 6 "net" 7 "sync" 8 9 "github.com/sagernet/sing/common" 10 "github.com/sagernet/sing/common/buf" 11 E "github.com/sagernet/sing/common/exceptions" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 ) 15 16 type clientConn struct { 17 net.Conn 18 destination M.Socksaddr 19 requestWritten bool 20 responseRead bool 21 } 22 23 func (c *clientConn) NeedHandshake() bool { 24 return !c.requestWritten 25 } 26 27 func (c *clientConn) readResponse() error { 28 response, err := ReadStreamResponse(c.Conn) 29 if err != nil { 30 return err 31 } 32 if response.Status == statusError { 33 return E.New("remote error: ", response.Message) 34 } 35 return nil 36 } 37 38 func (c *clientConn) Read(b []byte) (n int, err error) { 39 if !c.responseRead { 40 err = c.readResponse() 41 if err != nil { 42 return 43 } 44 c.responseRead = true 45 } 46 return c.Conn.Read(b) 47 } 48 49 func (c *clientConn) Write(b []byte) (n int, err error) { 50 if c.requestWritten { 51 return c.Conn.Write(b) 52 } 53 request := StreamRequest{ 54 Network: N.NetworkTCP, 55 Destination: c.destination, 56 } 57 buffer := buf.NewSize(streamRequestLen(request) + len(b)) 58 defer buffer.Release() 59 err = EncodeStreamRequest(request, buffer) 60 if err != nil { 61 return 62 } 63 buffer.Write(b) 64 _, err = c.Conn.Write(buffer.Bytes()) 65 if err != nil { 66 return 67 } 68 c.requestWritten = true 69 return len(b), nil 70 } 71 72 func (c *clientConn) LocalAddr() net.Addr { 73 return c.Conn.LocalAddr() 74 } 75 76 func (c *clientConn) RemoteAddr() net.Addr { 77 return c.destination.TCPAddr() 78 } 79 80 func (c *clientConn) ReaderReplaceable() bool { 81 return c.responseRead 82 } 83 84 func (c *clientConn) WriterReplaceable() bool { 85 return c.requestWritten 86 } 87 88 func (c *clientConn) NeedAdditionalReadDeadline() bool { 89 return true 90 } 91 92 func (c *clientConn) Upstream() any { 93 return c.Conn 94 } 95 96 var _ N.NetPacketConn = (*clientPacketConn)(nil) 97 98 type clientPacketConn struct { 99 N.AbstractConn 100 conn N.ExtendedConn 101 access sync.Mutex 102 destination M.Socksaddr 103 requestWritten bool 104 responseRead bool 105 readWaitOptions N.ReadWaitOptions 106 } 107 108 func (c *clientPacketConn) NeedHandshake() bool { 109 return !c.requestWritten 110 } 111 112 func (c *clientPacketConn) readResponse() error { 113 response, err := ReadStreamResponse(c.conn) 114 if err != nil { 115 return err 116 } 117 if response.Status == statusError { 118 return E.New("remote error: ", response.Message) 119 } 120 return nil 121 } 122 123 func (c *clientPacketConn) Read(b []byte) (n int, err error) { 124 if !c.responseRead { 125 err = c.readResponse() 126 if err != nil { 127 return 128 } 129 c.responseRead = true 130 } 131 var length uint16 132 err = binary.Read(c.conn, binary.BigEndian, &length) 133 if err != nil { 134 return 135 } 136 if cap(b) < int(length) { 137 return 0, io.ErrShortBuffer 138 } 139 return io.ReadFull(c.conn, b[:length]) 140 } 141 142 func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) { 143 request := StreamRequest{ 144 Network: N.NetworkUDP, 145 Destination: c.destination, 146 } 147 rLen := streamRequestLen(request) 148 if len(payload) > 0 { 149 rLen += 2 + len(payload) 150 } 151 buffer := buf.NewSize(rLen) 152 defer buffer.Release() 153 err = EncodeStreamRequest(request, buffer) 154 if err != nil { 155 return 156 } 157 if len(payload) > 0 { 158 common.Must( 159 binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 160 common.Error(buffer.Write(payload)), 161 ) 162 } 163 _, err = c.conn.Write(buffer.Bytes()) 164 if err != nil { 165 return 166 } 167 c.requestWritten = true 168 return len(payload), nil 169 } 170 171 func (c *clientPacketConn) Write(b []byte) (n int, err error) { 172 if !c.requestWritten { 173 c.access.Lock() 174 if c.requestWritten { 175 c.access.Unlock() 176 } else { 177 defer c.access.Unlock() 178 return c.writeRequest(b) 179 } 180 } 181 err = binary.Write(c.conn, binary.BigEndian, uint16(len(b))) 182 if err != nil { 183 return 184 } 185 return c.conn.Write(b) 186 } 187 188 func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) { 189 if !c.responseRead { 190 err = c.readResponse() 191 if err != nil { 192 return 193 } 194 c.responseRead = true 195 } 196 var length uint16 197 err = binary.Read(c.conn, binary.BigEndian, &length) 198 if err != nil { 199 return 200 } 201 _, err = buffer.ReadFullFrom(c.conn, int(length)) 202 return 203 } 204 205 func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error { 206 if !c.requestWritten { 207 c.access.Lock() 208 if c.requestWritten { 209 c.access.Unlock() 210 } else { 211 defer c.access.Unlock() 212 defer buffer.Release() 213 return common.Error(c.writeRequest(buffer.Bytes())) 214 } 215 } 216 bLen := buffer.Len() 217 binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen)) 218 return c.conn.WriteBuffer(buffer) 219 } 220 221 func (c *clientPacketConn) FrontHeadroom() int { 222 return 2 223 } 224 225 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 226 if !c.responseRead { 227 err = c.readResponse() 228 if err != nil { 229 return 230 } 231 c.responseRead = true 232 } 233 var length uint16 234 err = binary.Read(c.conn, binary.BigEndian, &length) 235 if err != nil { 236 return 237 } 238 if cap(p) < int(length) { 239 return 0, nil, io.ErrShortBuffer 240 } 241 n, err = io.ReadFull(c.conn, p[:length]) 242 return 243 } 244 245 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 246 if !c.requestWritten { 247 c.access.Lock() 248 if c.requestWritten { 249 c.access.Unlock() 250 } else { 251 defer c.access.Unlock() 252 return c.writeRequest(p) 253 } 254 } 255 err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) 256 if err != nil { 257 return 258 } 259 return c.conn.Write(p) 260 } 261 262 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 263 err = c.ReadBuffer(buffer) 264 return 265 } 266 267 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 268 return c.WriteBuffer(buffer) 269 } 270 271 func (c *clientPacketConn) LocalAddr() net.Addr { 272 return c.conn.LocalAddr() 273 } 274 275 func (c *clientPacketConn) RemoteAddr() net.Addr { 276 return c.destination.UDPAddr() 277 } 278 279 func (c *clientPacketConn) NeedAdditionalReadDeadline() bool { 280 return true 281 } 282 283 func (c *clientPacketConn) Upstream() any { 284 return c.conn 285 } 286 287 var _ N.NetPacketConn = (*clientPacketAddrConn)(nil) 288 289 type clientPacketAddrConn struct { 290 N.AbstractConn 291 conn N.ExtendedConn 292 access sync.Mutex 293 destination M.Socksaddr 294 requestWritten bool 295 responseRead bool 296 readWaitOptions N.ReadWaitOptions 297 } 298 299 func (c *clientPacketAddrConn) NeedHandshake() bool { 300 return !c.requestWritten 301 } 302 303 func (c *clientPacketAddrConn) readResponse() error { 304 response, err := ReadStreamResponse(c.conn) 305 if err != nil { 306 return err 307 } 308 if response.Status == statusError { 309 return E.New("remote error: ", response.Message) 310 } 311 return nil 312 } 313 314 func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 315 if !c.responseRead { 316 err = c.readResponse() 317 if err != nil { 318 return 319 } 320 c.responseRead = true 321 } 322 destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn) 323 if err != nil { 324 return 325 } 326 if destination.IsFqdn() { 327 addr = destination 328 } else { 329 addr = destination.UDPAddr() 330 } 331 var length uint16 332 err = binary.Read(c.conn, binary.BigEndian, &length) 333 if err != nil { 334 return 335 } 336 if cap(p) < int(length) { 337 return 0, nil, io.ErrShortBuffer 338 } 339 n, err = io.ReadFull(c.conn, p[:length]) 340 return 341 } 342 343 func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { 344 request := StreamRequest{ 345 Network: N.NetworkUDP, 346 Destination: c.destination, 347 PacketAddr: true, 348 } 349 rLen := streamRequestLen(request) 350 if len(payload) > 0 { 351 rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload) 352 } 353 buffer := buf.NewSize(rLen) 354 defer buffer.Release() 355 err = EncodeStreamRequest(request, buffer) 356 if err != nil { 357 return 358 } 359 if len(payload) > 0 { 360 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 361 if err != nil { 362 return 363 } 364 common.Must( 365 binary.Write(buffer, binary.BigEndian, uint16(len(payload))), 366 common.Error(buffer.Write(payload)), 367 ) 368 } 369 _, err = c.conn.Write(buffer.Bytes()) 370 if err != nil { 371 return 372 } 373 c.requestWritten = true 374 return len(payload), nil 375 } 376 377 func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 378 if !c.requestWritten { 379 c.access.Lock() 380 if c.requestWritten { 381 c.access.Unlock() 382 } else { 383 defer c.access.Unlock() 384 return c.writeRequest(p, M.SocksaddrFromNet(addr)) 385 } 386 } 387 err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr)) 388 if err != nil { 389 return 390 } 391 err = binary.Write(c.conn, binary.BigEndian, uint16(len(p))) 392 if err != nil { 393 return 394 } 395 return c.conn.Write(p) 396 } 397 398 func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 399 if !c.responseRead { 400 err = c.readResponse() 401 if err != nil { 402 return 403 } 404 c.responseRead = true 405 } 406 destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn) 407 if err != nil { 408 return 409 } 410 var length uint16 411 err = binary.Read(c.conn, binary.BigEndian, &length) 412 if err != nil { 413 return 414 } 415 _, err = buffer.ReadFullFrom(c.conn, int(length)) 416 return 417 } 418 419 func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 420 if !c.requestWritten { 421 c.access.Lock() 422 if c.requestWritten { 423 c.access.Unlock() 424 } else { 425 defer c.access.Unlock() 426 defer buffer.Release() 427 return common.Error(c.writeRequest(buffer.Bytes(), destination)) 428 } 429 } 430 bLen := buffer.Len() 431 header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2)) 432 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 433 if err != nil { 434 return err 435 } 436 common.Must(binary.Write(header, binary.BigEndian, uint16(bLen))) 437 return c.conn.WriteBuffer(buffer) 438 } 439 440 func (c *clientPacketAddrConn) LocalAddr() net.Addr { 441 return c.conn.LocalAddr() 442 } 443 444 func (c *clientPacketAddrConn) FrontHeadroom() int { 445 return 2 + M.MaxSocksaddrLength 446 } 447 448 func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool { 449 return true 450 } 451 452 func (c *clientPacketAddrConn) Upstream() any { 453 return c.conn 454 }