// Package socks5 provides SOCKS5 handshake helpers and utilities for encoding
// and decoding SOCKS addresses and UDP packets, following RFC 1928.
package socks5

import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
	"net"
	"net/netip"
	"strconv"

	"github.com/metacubex/mihomo/component/auth"
)

// Error represents a SOCKS error.
// The underlying byte is the numeric code that Error() reports.
type Error byte

// Error implements the error interface; the code is rendered in decimal
// after the fixed prefix "SOCKS error: ".
func (err Error) Error() string {
	return "SOCKS error: " + strconv.Itoa(int(err))
}

// Command is request commands as defined in RFC 1928 section 4.
// It is an alias for uint8 rather than a distinct type, so a raw protocol
// byte can be used as a Command without conversion.
type Command = uint8

// Version is the SOCKS protocol version number.
const Version = 5

// SOCKS request commands as defined in RFC 1928 section 4.
const (
	CmdConnect      Command = 1
	CmdBind         Command = 2
	CmdUDPAssociate Command = 3
)

// SOCKS address types as defined in RFC 1928 section 5.
// The value is the first byte of an encoded address.
const (
	AtypIPv4       = 1
	AtypDomainName = 3
	AtypIPv6       = 4
)

// MaxAddrLen is the maximum size of SOCKS address in bytes:
// 1 (address type) + 1 (domain length) + 255 (longest domain) + 2 (port).
const MaxAddrLen = 1 + 1 + 255 + 2

// MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
const MaxAuthLen = 255

// Addr represents a SOCKS address as defined in RFC 1928 section 5.
48 type Addr []byte 49 50 func (a Addr) String() string { 51 var host, port string 52 53 switch a[0] { 54 case AtypDomainName: 55 hostLen := uint16(a[1]) 56 host = string(a[2 : 2+hostLen]) 57 port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) 58 case AtypIPv4: 59 host = net.IP(a[1 : 1+net.IPv4len]).String() 60 port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) 61 case AtypIPv6: 62 host = net.IP(a[1 : 1+net.IPv6len]).String() 63 port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) 64 } 65 66 return net.JoinHostPort(host, port) 67 } 68 69 // UDPAddr converts a socks5.Addr to *net.UDPAddr 70 func (a Addr) UDPAddr() *net.UDPAddr { 71 if len(a) == 0 { 72 return nil 73 } 74 switch a[0] { 75 case AtypIPv4: 76 var ip [net.IPv4len]byte 77 copy(ip[0:], a[1:1+net.IPv4len]) 78 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} 79 case AtypIPv6: 80 var ip [net.IPv6len]byte 81 copy(ip[0:], a[1:1+net.IPv6len]) 82 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} 83 } 84 // Other Atyp 85 return nil 86 } 87 88 // SOCKS errors as defined in RFC 1928 section 6. 89 const ( 90 ErrGeneralFailure = Error(1) 91 ErrConnectionNotAllowed = Error(2) 92 ErrNetworkUnreachable = Error(3) 93 ErrHostUnreachable = Error(4) 94 ErrConnectionRefused = Error(5) 95 ErrTTLExpired = Error(6) 96 ErrCommandNotSupported = Error(7) 97 ErrAddressNotSupported = Error(8) 98 ) 99 100 // Auth errors used to return a specific "Auth failed" error 101 var ErrAuth = errors.New("auth failed") 102 103 type User struct { 104 Username string 105 Password string 106 } 107 108 // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. 
109 func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, user string, err error) { 110 // Read RFC 1928 for request and reply structure and sizes. 111 buf := make([]byte, MaxAddrLen) 112 // read VER, NMETHODS, METHODS 113 if _, err = io.ReadFull(rw, buf[:2]); err != nil { 114 return 115 } 116 nmethods := buf[1] 117 if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { 118 return 119 } 120 121 // write VER METHOD 122 if authenticator != nil { 123 if _, err = rw.Write([]byte{5, 2}); err != nil { 124 return 125 } 126 127 // Get header 128 header := make([]byte, 2) 129 if _, err = io.ReadFull(rw, header); err != nil { 130 return 131 } 132 133 authBuf := make([]byte, MaxAuthLen) 134 // Get username 135 userLen := int(header[1]) 136 if userLen <= 0 { 137 rw.Write([]byte{1, 1}) 138 err = ErrAuth 139 return 140 } 141 if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { 142 return 143 } 144 user = string(authBuf[:userLen]) 145 146 // Get password 147 if _, err = rw.Read(header[:1]); err != nil { 148 return 149 } 150 passLen := int(header[0]) 151 if passLen <= 0 { 152 rw.Write([]byte{1, 1}) 153 err = ErrAuth 154 return 155 } 156 if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil { 157 return 158 } 159 pass := string(authBuf[:passLen]) 160 161 // Verify 162 if ok := authenticator.Verify(string(user), string(pass)); !ok { 163 rw.Write([]byte{1, 1}) 164 err = ErrAuth 165 return 166 } 167 168 // Response auth state 169 if _, err = rw.Write([]byte{1, 0}); err != nil { 170 return 171 } 172 } else { 173 if _, err = rw.Write([]byte{5, 0}); err != nil { 174 return 175 } 176 } 177 178 // read VER CMD RSV ATYP DST.ADDR DST.PORT 179 if _, err = io.ReadFull(rw, buf[:3]); err != nil { 180 return 181 } 182 183 command = buf[1] 184 addr, err = ReadAddr(rw, buf) 185 if err != nil { 186 return 187 } 188 189 switch command { 190 case CmdConnect, CmdUDPAssociate: 191 // Acquire server listened address info 192 localAddr := 
ParseAddr(rw.LocalAddr().String()) 193 if localAddr == nil { 194 err = ErrAddressNotSupported 195 } else { 196 // write VER REP RSV ATYP BND.ADDR BND.PORT 197 _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) 198 } 199 case CmdBind: 200 fallthrough 201 default: 202 err = ErrCommandNotSupported 203 } 204 205 return 206 } 207 208 // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. 209 func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { 210 buf := make([]byte, MaxAddrLen) 211 var err error 212 213 // VER, NMETHODS, METHODS 214 if user != nil { 215 _, err = rw.Write([]byte{5, 1, 2}) 216 } else { 217 _, err = rw.Write([]byte{5, 1, 0}) 218 } 219 if err != nil { 220 return nil, err 221 } 222 223 // VER, METHOD 224 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 225 return nil, err 226 } 227 228 if buf[0] != 5 { 229 return nil, errors.New("SOCKS version error") 230 } 231 232 if buf[1] == 2 { 233 if user == nil { 234 return nil, ErrAuth 235 } 236 237 // password protocol version 238 authMsg := &bytes.Buffer{} 239 authMsg.WriteByte(1) 240 authMsg.WriteByte(uint8(len(user.Username))) 241 authMsg.WriteString(user.Username) 242 authMsg.WriteByte(uint8(len(user.Password))) 243 authMsg.WriteString(user.Password) 244 245 if _, err := rw.Write(authMsg.Bytes()); err != nil { 246 return nil, err 247 } 248 249 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 250 return nil, err 251 } 252 253 if buf[1] != 0 { 254 return nil, errors.New("rejected username/password") 255 } 256 } else if buf[1] != 0 { 257 return nil, errors.New("SOCKS need auth") 258 } 259 260 // VER, CMD, RSV, ADDR 261 if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { 262 return nil, err 263 } 264 265 // VER, REP, RSV 266 if _, err := io.ReadFull(rw, buf[:3]); err != nil { 267 return nil, err 268 } 269 270 return ReadAddr(rw, buf) 271 } 272 273 func ReadAddr(r 
io.Reader, b []byte) (Addr, error) { 274 if len(b) < MaxAddrLen { 275 return nil, io.ErrShortBuffer 276 } 277 _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type 278 if err != nil { 279 return nil, err 280 } 281 282 switch b[0] { 283 case AtypDomainName: 284 _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length 285 if err != nil { 286 return nil, err 287 } 288 domainLength := uint16(b[1]) 289 _, err = io.ReadFull(r, b[2:2+domainLength+2]) 290 return b[:1+1+domainLength+2], err 291 case AtypIPv4: 292 _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) 293 return b[:1+net.IPv4len+2], err 294 case AtypIPv6: 295 _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) 296 return b[:1+net.IPv6len+2], err 297 } 298 299 return nil, ErrAddressNotSupported 300 } 301 302 func ReadAddr0(r io.Reader) (Addr, error) { 303 aType, err := ReadByte(r) // read 1st byte for address type 304 if err != nil { 305 return nil, err 306 } 307 308 switch aType { 309 case AtypDomainName: 310 var domainLength byte 311 domainLength, err = ReadByte(r) // read 2nd byte for domain length 312 if err != nil { 313 return nil, err 314 } 315 b := make([]byte, 1+1+uint16(domainLength)+2) 316 _, err = io.ReadFull(r, b[2:]) 317 b[0] = aType 318 b[1] = domainLength 319 return b, err 320 case AtypIPv4: 321 var b [1 + net.IPv4len + 2]byte 322 _, err = io.ReadFull(r, b[1:]) 323 b[0] = aType 324 return b[:], err 325 case AtypIPv6: 326 var b [1 + net.IPv6len + 2]byte 327 _, err = io.ReadFull(r, b[1:]) 328 b[0] = aType 329 return b[:], err 330 } 331 332 return nil, ErrAddressNotSupported 333 } 334 335 func ReadByte(reader io.Reader) (byte, error) { 336 if br, isBr := reader.(io.ByteReader); isBr { 337 return br.ReadByte() 338 } 339 var b [1]byte 340 if _, err := io.ReadFull(reader, b[:]); err != nil { 341 return 0, err 342 } 343 return b[0], nil 344 } 345 346 // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. 
347 func SplitAddr(b []byte) Addr { 348 addrLen := 1 349 if len(b) < addrLen { 350 return nil 351 } 352 353 switch b[0] { 354 case AtypDomainName: 355 if len(b) < 2 { 356 return nil 357 } 358 addrLen = 1 + 1 + int(b[1]) + 2 359 case AtypIPv4: 360 addrLen = 1 + net.IPv4len + 2 361 case AtypIPv6: 362 addrLen = 1 + net.IPv6len + 2 363 default: 364 return nil 365 366 } 367 368 if len(b) < addrLen { 369 return nil 370 } 371 372 return b[:addrLen] 373 } 374 375 // ParseAddr parses the address in string s. Returns nil if failed. 376 func ParseAddr(s string) Addr { 377 var addr Addr 378 host, port, err := net.SplitHostPort(s) 379 if err != nil { 380 return nil 381 } 382 if ip := net.ParseIP(host); ip != nil { 383 if ip4 := ip.To4(); ip4 != nil { 384 addr = make([]byte, 1+net.IPv4len+2) 385 addr[0] = AtypIPv4 386 copy(addr[1:], ip4) 387 } else { 388 addr = make([]byte, 1+net.IPv6len+2) 389 addr[0] = AtypIPv6 390 copy(addr[1:], ip) 391 } 392 } else { 393 if len(host) > 255 { 394 return nil 395 } 396 addr = make([]byte, 1+1+len(host)+2) 397 addr[0] = AtypDomainName 398 addr[1] = byte(len(host)) 399 copy(addr[2:], host) 400 } 401 402 portnum, err := strconv.ParseUint(port, 10, 16) 403 if err != nil { 404 return nil 405 } 406 407 addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum) 408 409 return addr 410 } 411 412 // ParseAddrToSocksAddr parse a socks addr from net.addr 413 // This is a fast path of ParseAddr(addr.String()) 414 func ParseAddrToSocksAddr(addr net.Addr) Addr { 415 var hostip net.IP 416 var port int 417 if udpaddr, ok := addr.(*net.UDPAddr); ok { 418 hostip = udpaddr.IP 419 port = udpaddr.Port 420 } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { 421 hostip = tcpaddr.IP 422 port = tcpaddr.Port 423 } 424 425 // fallback parse 426 if hostip == nil { 427 return ParseAddr(addr.String()) 428 } 429 430 var parsed Addr 431 if ip4 := hostip.To4(); ip4.DefaultMask() != nil { 432 parsed = make([]byte, 1+net.IPv4len+2) 433 parsed[0] = AtypIPv4 434 
copy(parsed[1:], ip4) 435 binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port)) 436 437 } else { 438 parsed = make([]byte, 1+net.IPv6len+2) 439 parsed[0] = AtypIPv6 440 copy(parsed[1:], hostip) 441 binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port)) 442 } 443 return parsed 444 } 445 446 func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr { 447 addr := addrPort.Addr() 448 if addr.Is4() { 449 ip4 := addr.As4() 450 return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())} 451 } 452 453 buf := make([]byte, 1+net.IPv6len+2) 454 buf[0] = AtypIPv6 455 copy(buf[1:], addr.AsSlice()) 456 buf[1+net.IPv6len] = byte(addrPort.Port() >> 8) 457 buf[1+net.IPv6len+1] = byte(addrPort.Port()) 458 return buf 459 } 460 461 // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` 462 func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { 463 if len(packet) < 5 { 464 err = errors.New("insufficient length of packet") 465 return 466 } 467 468 // packet[0] and packet[1] are reserved 469 if !bytes.Equal(packet[:2], []byte{0, 0}) { 470 err = errors.New("reserved fields should be zero") 471 return 472 } 473 474 if packet[2] != 0 /* fragments */ { 475 err = errors.New("discarding fragmented payload") 476 return 477 } 478 479 addr = SplitAddr(packet[3:]) 480 if addr == nil { 481 err = errors.New("failed to read UDP header") 482 } 483 484 payload = packet[3+len(addr):] 485 return 486 } 487 488 func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { 489 if addr == nil { 490 err = errors.New("address is invalid") 491 return 492 } 493 packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{}) 494 return 495 }