github.com/igoogolx/clash@v1.19.8/transport/socks5/socks5.go (about) 1 package socks5 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "io" 8 "net" 9 "net/netip" 10 "strconv" 11 12 "github.com/igoogolx/clash/component/auth" 13 14 "github.com/Dreamacro/protobytes" 15 ) 16 17 // Error represents a SOCKS error 18 type Error byte 19 20 func (err Error) Error() string { 21 return "SOCKS error: " + strconv.Itoa(int(err)) 22 } 23 24 // Command is request commands as defined in RFC 1928 section 4. 25 type Command = uint8 26 27 const Version = 5 28 29 // SOCKS request commands as defined in RFC 1928 section 4. 30 const ( 31 CmdConnect Command = 1 32 CmdBind Command = 2 33 CmdUDPAssociate Command = 3 34 ) 35 36 // SOCKS address types as defined in RFC 1928 section 5. 37 const ( 38 AtypIPv4 = 1 39 AtypDomainName = 3 40 AtypIPv6 = 4 41 ) 42 43 // MaxAddrLen is the maximum size of SOCKS address in bytes. 44 const MaxAddrLen = 1 + 1 + 255 + 2 45 46 // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth 47 const MaxAuthLen = 255 48 49 // Addr represents a SOCKS address as defined in RFC 1928 section 5. 50 type Addr []byte 51 52 func (a Addr) String() string { 53 var host, port string 54 55 switch a[0] { 56 case AtypDomainName: 57 hostLen := uint16(a[1]) 58 host = string(a[2 : 2+hostLen]) 59 port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) 60 case AtypIPv4: 61 host = net.IP(a[1 : 1+net.IPv4len]).String() 62 port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) 63 case AtypIPv6: 64 host = net.IP(a[1 : 1+net.IPv6len]).String() 65 port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) 66 } 67 68 return net.JoinHostPort(host, port) 69 } 70 71 // UDPAddr converts a socks5.Addr to *net.UDPAddr 72 func (a Addr) UDPAddr() *net.UDPAddr { 73 if len(a) == 0 { 74 return nil 75 } 76 switch a[0] { 77 case AtypIPv4: 78 var ip [net.IPv4len]byte 79 copy(ip[0:], a[1:1+net.IPv4len]) 80 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} 81 case AtypIPv6: 82 var ip [net.IPv6len]byte 83 copy(ip[0:], a[1:1+net.IPv6len]) 84 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} 85 } 86 // Other Atyp 87 return nil 88 } 89 90 // SOCKS errors as defined in RFC 1928 section 6. 91 const ( 92 ErrGeneralFailure = Error(1) 93 ErrConnectionNotAllowed = Error(2) 94 ErrNetworkUnreachable = Error(3) 95 ErrHostUnreachable = Error(4) 96 ErrConnectionRefused = Error(5) 97 ErrTTLExpired = Error(6) 98 ErrCommandNotSupported = Error(7) 99 ErrAddressNotSupported = Error(8) 100 ) 101 102 // Auth errors used to return a specific "Auth failed" error 103 var ErrAuth = errors.New("auth failed") 104 105 type User struct { 106 Username string 107 Password string 108 } 109 110 // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. 111 func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) { 112 // Read RFC 1928 for request and reply structure and sizes. 113 buf := make([]byte, MaxAddrLen) 114 // read VER, NMETHODS, METHODS 115 if _, err = io.ReadFull(rw, buf[:2]); err != nil { 116 return 117 } 118 nmethods := buf[1] 119 if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { 120 return 121 } 122 123 // write VER METHOD 124 if authenticator != nil { 125 if _, err = rw.Write([]byte{5, 2}); err != nil { 126 return 127 } 128 129 // Get header 130 header := make([]byte, 2) 131 if _, err = io.ReadFull(rw, header); err != nil { 132 return 133 } 134 135 authBuf := make([]byte, MaxAuthLen) 136 // Get username 137 userLen := int(header[1]) 138 if userLen <= 0 { 139 rw.Write([]byte{1, 1}) 140 err = ErrAuth 141 return 142 } 143 if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { 144 return 145 } 146 user := string(authBuf[:userLen]) 147 148 // Get password 149 if _, err = rw.Read(header[:1]); err != nil { 150 return 151 } 152 passLen := int(header[0]) 153 if passLen <= 0 { 154 rw.Write([]byte{1, 1}) 155 err = ErrAuth 156 return 157 } 158 if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil { 159 return 160 } 161 pass := string(authBuf[:passLen]) 162 163 // Verify 164 if ok := authenticator.Verify(user, pass); !ok { 165 rw.Write([]byte{1, 1}) 166 err = ErrAuth 167 return 168 } 169 170 // Response auth state 171 if _, err = rw.Write([]byte{1, 0}); err != nil { 172 return 173 } 174 } else { 175 if _, err = rw.Write([]byte{5, 0}); err != nil { 176 return 177 } 178 } 179 180 // read VER CMD RSV ATYP DST.ADDR DST.PORT 181 if _, err = io.ReadFull(rw, buf[:3]); err != nil { 182 return 183 } 184 185 command = buf[1] 186 addr, err = ReadAddr(rw, buf) 187 if err != nil { 188 return 189 } 190 191 switch command { 192 case CmdConnect, CmdUDPAssociate: 193 // Acquire server listened address info 194 localAddr := ParseAddr(rw.LocalAddr().String()) 195 if localAddr == nil { 196 err = ErrAddressNotSupported 197 } else { 198 // write VER REP RSV ATYP BND.ADDR BND.PORT 199 _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) 200 } 201 case CmdBind: 202 fallthrough 203 default: 204 err = ErrCommandNotSupported 205 } 206 207 return 208 } 209 210 // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. 211 func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { 212 buf := make([]byte, MaxAddrLen) 213 var err error 214 215 // VER, NMETHODS, METHODS 216 if user != nil { 217 _, err = rw.Write([]byte{5, 1, 2}) 218 } else { 219 _, err = rw.Write([]byte{5, 1, 0}) 220 } 221 if err != nil { 222 return nil, err 223 } 224 225 // VER, METHOD 226 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 227 return nil, err 228 } 229 230 if buf[0] != 5 { 231 return nil, errors.New("SOCKS version error") 232 } 233 234 if buf[1] == 2 { 235 if user == nil { 236 return nil, ErrAuth 237 } 238 239 // password protocol version 240 authMsg := protobytes.BytesWriter{} 241 authMsg.PutUint8(1) 242 authMsg.PutUint8(uint8(len(user.Username))) 243 authMsg.PutString(user.Username) 244 authMsg.PutUint8(uint8(len(user.Password))) 245 authMsg.PutString(user.Password) 246 247 if _, err := rw.Write(authMsg.Bytes()); err != nil { 248 return nil, err 249 } 250 251 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 252 return nil, err 253 } 254 255 if buf[1] != 0 { 256 return nil, errors.New("rejected username/password") 257 } 258 } else if buf[1] != 0 { 259 return nil, errors.New("SOCKS need auth") 260 } 261 262 // VER, CMD, RSV, ADDR 263 if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { 264 return nil, err 265 } 266 267 // VER, REP, RSV 268 if _, err := io.ReadFull(rw, buf[:3]); err != nil { 269 return nil, err 270 } 271 272 return ReadAddr(rw, buf) 273 } 274 275 func ReadAddr(r io.Reader, b []byte) (Addr, error) { 276 if len(b) < MaxAddrLen { 277 return nil, io.ErrShortBuffer 278 } 279 _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type 280 if err != nil { 281 return nil, err 282 } 283 284 switch b[0] { 285 case AtypDomainName: 286 _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length 287 if err != nil { 288 return nil, err 289 } 290 domainLength := uint16(b[1]) 291 _, err = io.ReadFull(r, b[2:2+domainLength+2]) 292 return b[:1+1+domainLength+2], err 293 case AtypIPv4: 294 _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) 295 return b[:1+net.IPv4len+2], err 296 case AtypIPv6: 297 _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) 298 return b[:1+net.IPv6len+2], err 299 } 300 301 return nil, ErrAddressNotSupported 302 } 303 304 // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. 305 func SplitAddr(b []byte) Addr { 306 addrLen := 1 307 if len(b) < addrLen { 308 return nil 309 } 310 311 switch b[0] { 312 case AtypDomainName: 313 if len(b) < 2 { 314 return nil 315 } 316 addrLen = 1 + 1 + int(b[1]) + 2 317 case AtypIPv4: 318 addrLen = 1 + net.IPv4len + 2 319 case AtypIPv6: 320 addrLen = 1 + net.IPv6len + 2 321 default: 322 return nil 323 324 } 325 326 if len(b) < addrLen { 327 return nil 328 } 329 330 return b[:addrLen] 331 } 332 333 // ParseAddr parses the address in string s. Returns nil if failed. 334 func ParseAddr(s string) Addr { 335 buf := protobytes.BytesWriter{} 336 host, port, err := net.SplitHostPort(s) 337 if err != nil { 338 return nil 339 } 340 if ip := net.ParseIP(host); ip != nil { 341 if ip4 := ip.To4(); ip4 != nil { 342 buf.PutUint8(AtypIPv4) 343 buf.PutSlice(ip4) 344 } else { 345 buf.PutUint8(AtypIPv6) 346 buf.PutSlice(ip) 347 } 348 } else { 349 if len(host) > 255 { 350 return nil 351 } 352 buf.PutUint8(AtypDomainName) 353 buf.PutUint8(byte(len(host))) 354 buf.PutString(host) 355 } 356 357 portnum, err := strconv.ParseUint(port, 10, 16) 358 if err != nil { 359 return nil 360 } 361 362 buf.PutUint16be(uint16(portnum)) 363 return Addr(buf.Bytes()) 364 } 365 366 // ParseAddrToSocksAddr parse a socks addr from net.addr 367 // This is a fast path of ParseAddr(addr.String()) 368 func ParseAddrToSocksAddr(addr net.Addr) Addr { 369 var hostip net.IP 370 var port int 371 if udpaddr, ok := addr.(*net.UDPAddr); ok { 372 hostip = udpaddr.IP 373 port = udpaddr.Port 374 } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { 375 hostip = tcpaddr.IP 376 port = tcpaddr.Port 377 } 378 379 // fallback parse 380 if hostip == nil { 381 return ParseAddr(addr.String()) 382 } 383 384 var parsed protobytes.BytesWriter 385 if ip4 := hostip.To4(); ip4.DefaultMask() != nil { 386 parsed = make([]byte, 0, 1+net.IPv4len+2) 387 parsed.PutUint8(AtypIPv4) 388 parsed.PutSlice(ip4) 389 parsed.PutUint16be(uint16(port)) 390 } else { 391 parsed = make([]byte, 0, 1+net.IPv6len+2) 392 parsed.PutUint8(AtypIPv6) 393 parsed.PutSlice(hostip) 394 parsed.PutUint16be(uint16(port)) 395 } 396 return Addr(parsed) 397 } 398 399 func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr { 400 addr := addrPort.Addr() 401 if addr.Is4() { 402 ip4 := addr.As4() 403 return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())} 404 } 405 406 buf := make([]byte, 1+net.IPv6len+2) 407 buf[0] = AtypIPv6 408 copy(buf[1:], addr.AsSlice()) 409 buf[1+net.IPv6len] = byte(addrPort.Port() >> 8) 410 buf[1+net.IPv6len+1] = byte(addrPort.Port()) 411 return buf 412 } 413 414 // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` 415 func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { 416 r := protobytes.BytesReader(packet) 417 418 if r.Len() < 5 { 419 err = errors.New("insufficient length of packet") 420 return 421 } 422 423 // packet[0] and packet[1] are reserved 424 reserved, r := r.SplitAt(2) 425 if !bytes.Equal(reserved, []byte{0, 0}) { 426 err = errors.New("reserved fields should be zero") 427 return 428 } 429 430 if r.ReadUint8() != 0 /* fragments */ { 431 err = errors.New("discarding fragmented payload") 432 return 433 } 434 435 addr = SplitAddr(r) 436 if addr == nil { 437 err = errors.New("failed to read UDP header") 438 } 439 440 _, payload = r.SplitAt(len(addr)) 441 return 442 } 443 444 func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { 445 if addr == nil { 446 err = errors.New("address is invalid") 447 return 448 } 449 w := protobytes.BytesWriter{} 450 w.PutSlice([]byte{0, 0, 0}) 451 w.PutSlice(addr) 452 w.PutSlice(payload) 453 packet = w.Bytes() 454 return 455 }