// Source: github.com/yaling888/clash@v1.53.0/transport/socks5/socks5.go

package socks5

import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
	"net"
	"net/netip"
	"strconv"

	"github.com/yaling888/clash/common/pool"
	"github.com/yaling888/clash/component/auth"
)

// Error represents a SOCKS error: a numeric reply code such as those listed
// in RFC 1928 section 6 (see ErrCommandNotSupported / ErrAddressNotSupported).
type Error byte

// Error implements the error interface, rendering the numeric code,
// e.g. "SOCKS error: 7".
func (err Error) Error() string {
	return "SOCKS error: " + strconv.Itoa(int(err))
}

// Command is request commands as defined in RFC 1928 section 4 (the CMD
// field). It is an alias of uint8, so a raw byte read from the wire can be
// used directly as a Command.
type Command = uint8

// Version is the SOCKS protocol version number carried in the VER field.
const Version = 5

// SOCKS request commands as defined in RFC 1928 section 4.
const (
	CmdConnect      Command = 1
	CmdBind         Command = 2
	CmdUDPAssociate Command = 3
)

// SOCKS address types (the ATYP field) as defined in RFC 1928 section 5.
const (
	AtypIPv4       = 1
	AtypDomainName = 3
	AtypIPv6       = 4
)

// MaxAddrLen is the maximum size of SOCKS address in bytes:
// ATYP (1) + domain length (1) + longest domain (255) + port (2).
const MaxAddrLen = 1 + 1 + 255 + 2

// MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth.
// Each field is prefixed by a single length byte, hence 255.
const MaxAuthLen = 255

// Addr represents a SOCKS address as defined in RFC 1928 section 5.
49 type Addr []byte 50 51 func (a Addr) String() string { 52 var host, port string 53 54 switch a[0] { 55 case AtypDomainName: 56 hostLen := uint16(a[1]) 57 host = string(a[2 : 2+hostLen]) 58 port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) 59 case AtypIPv4: 60 host = net.IP(a[1 : 1+net.IPv4len]).String() 61 port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) 62 case AtypIPv6: 63 host = net.IP(a[1 : 1+net.IPv6len]).String() 64 port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) 65 } 66 67 return net.JoinHostPort(host, port) 68 } 69 70 // UDPAddr converts a socks5.Addr to *net.UDPAddr 71 func (a Addr) UDPAddr() *net.UDPAddr { 72 if len(a) == 0 { 73 return nil 74 } 75 switch a[0] { 76 case AtypIPv4: 77 var ip [net.IPv4len]byte 78 copy(ip[0:], a[1:1+net.IPv4len]) 79 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} 80 case AtypIPv6: 81 var ip [net.IPv6len]byte 82 copy(ip[0:], a[1:1+net.IPv6len]) 83 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} 84 } 85 // Other Atyp 86 return nil 87 } 88 89 // SOCKS errors as defined in RFC 1928 section 6. 90 const ( 91 ErrCommandNotSupported = Error(7) 92 ErrAddressNotSupported = Error(8) 93 ) 94 95 // ErrAuth errors used to return a specific "Auth failed" error 96 var ErrAuth = errors.New("auth failed") 97 98 type User struct { 99 Username string 100 Password string 101 } 102 103 // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. 104 func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) { 105 // Read RFC 1928 for request and reply structure and sizes. 
106 buf := make([]byte, MaxAddrLen) 107 // read VER, NMETHODS, METHODS 108 if _, err = io.ReadFull(rw, buf[:2]); err != nil { 109 return 110 } 111 nmethods := buf[1] 112 if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { 113 return 114 } 115 116 // write VER METHOD 117 if authenticator != nil { 118 if _, err = rw.Write([]byte{5, 2}); err != nil { 119 return 120 } 121 122 // Get header 123 header := make([]byte, 2) 124 if _, err = io.ReadFull(rw, header); err != nil { 125 return 126 } 127 128 authBuf := make([]byte, MaxAuthLen) 129 // Get username 130 userLen := int(header[1]) 131 if userLen <= 0 { 132 _, _ = rw.Write([]byte{1, 1}) 133 err = ErrAuth 134 return 135 } 136 if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { 137 return 138 } 139 user := make([]byte, userLen) 140 copy(user, authBuf[:userLen]) 141 142 // Get password 143 if _, err = rw.Read(header[:1]); err != nil { 144 return 145 } 146 passLen := int(header[0]) 147 if passLen <= 0 { 148 _, _ = rw.Write([]byte{1, 1}) 149 err = ErrAuth 150 return 151 } 152 if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil { 153 return 154 } 155 pass := authBuf[:passLen] 156 157 // Verify 158 if ok := authenticator.Verify(user, pass); !ok { 159 _, _ = rw.Write([]byte{1, 1}) 160 err = ErrAuth 161 return 162 } 163 164 // Response auth state 165 if _, err = rw.Write([]byte{1, 0}); err != nil { 166 return 167 } 168 } else { 169 if _, err = rw.Write([]byte{5, 0}); err != nil { 170 return 171 } 172 } 173 174 // read VER CMD RSV ATYP DST.ADDR DST.PORT 175 if _, err = io.ReadFull(rw, buf[:3]); err != nil { 176 return 177 } 178 179 command = buf[1] 180 addr, err = ReadAddr(rw, buf) 181 if err != nil { 182 return 183 } 184 185 switch command { 186 case CmdConnect, CmdUDPAssociate: 187 // Acquire server listened address info 188 localAddr := ParseAddr(rw.LocalAddr().String()) 189 if localAddr == nil { 190 err = ErrAddressNotSupported 191 } else { 192 // write VER REP RSV ATYP BND.ADDR BND.PORT 193 _, err = 
rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) 194 } 195 case CmdBind: 196 fallthrough 197 default: 198 err = ErrCommandNotSupported 199 } 200 201 return 202 } 203 204 // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. 205 func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { 206 buf := make([]byte, MaxAddrLen) 207 var err error 208 209 // VER, NMETHODS, METHODS 210 if user != nil { 211 _, err = rw.Write([]byte{5, 1, 2}) 212 } else { 213 _, err = rw.Write([]byte{5, 1, 0}) 214 } 215 if err != nil { 216 return nil, err 217 } 218 219 // VER, METHOD 220 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 221 return nil, err 222 } 223 224 if buf[0] != 5 { 225 return nil, errors.New("SOCKS version error") 226 } 227 228 if buf[1] == 2 { 229 if user == nil { 230 return nil, ErrAuth 231 } 232 233 // password protocol version 234 authMsg := pool.BufferWriter{} 235 authMsg.PutUint8(1) 236 authMsg.PutUint8(uint8(len(user.Username))) 237 authMsg.PutString(user.Username) 238 authMsg.PutUint8(uint8(len(user.Password))) 239 authMsg.PutString(user.Password) 240 241 if _, err := rw.Write(authMsg.Bytes()); err != nil { 242 return nil, err 243 } 244 245 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 246 return nil, err 247 } 248 249 if buf[1] != 0 { 250 return nil, errors.New("rejected username/password") 251 } 252 } else if buf[1] != 0 { 253 return nil, errors.New("SOCKS need auth") 254 } 255 256 // VER, CMD, RSV, ADDR 257 if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { 258 return nil, err 259 } 260 261 // VER, REP, RSV 262 if _, err := io.ReadFull(rw, buf[:3]); err != nil { 263 return nil, err 264 } 265 266 return ReadAddr(rw, buf) 267 } 268 269 func ReadAddr(r io.Reader, b []byte) (Addr, error) { 270 if len(b) < MaxAddrLen { 271 return nil, io.ErrShortBuffer 272 } 273 _, err := io.ReadFull(r, b[:1]) // read 1st byte for address 
type 274 if err != nil { 275 return nil, err 276 } 277 278 switch b[0] { 279 case AtypDomainName: 280 _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length 281 if err != nil { 282 return nil, err 283 } 284 domainLength := uint16(b[1]) 285 _, err = io.ReadFull(r, b[2:2+domainLength+2]) 286 return b[:1+1+domainLength+2], err 287 case AtypIPv4: 288 _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) 289 return b[:1+net.IPv4len+2], err 290 case AtypIPv6: 291 _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) 292 return b[:1+net.IPv6len+2], err 293 } 294 295 return nil, ErrAddressNotSupported 296 } 297 298 // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. 299 func SplitAddr(b []byte) Addr { 300 addrLen := 1 301 if len(b) < addrLen { 302 return nil 303 } 304 305 switch b[0] { 306 case AtypDomainName: 307 if len(b) < 2 { 308 return nil 309 } 310 addrLen = 1 + 1 + int(b[1]) + 2 311 case AtypIPv4: 312 addrLen = 1 + net.IPv4len + 2 313 case AtypIPv6: 314 addrLen = 1 + net.IPv6len + 2 315 default: 316 return nil 317 } 318 319 if len(b) < addrLen { 320 return nil 321 } 322 323 return b[:addrLen] 324 } 325 326 // ParseAddr parses the address in string s. Returns nil if failed. 
327 func ParseAddr(s string) Addr { 328 buf := pool.BufferWriter{} 329 host, port, err := net.SplitHostPort(s) 330 if err != nil { 331 return nil 332 } 333 if ip, err := netip.ParseAddr(host); err == nil { 334 if ip.Is4() { 335 buf.PutUint8(AtypIPv4) 336 } else { 337 buf.PutUint8(AtypIPv6) 338 } 339 buf.PutSlice(ip.AsSlice()) 340 } else { 341 if len(host) > 255 { 342 return nil 343 } 344 buf.PutUint8(AtypDomainName) 345 buf.PutUint8(byte(len(host))) 346 buf.PutString(host) 347 } 348 349 portNum, err := strconv.ParseUint(port, 10, 16) 350 if err != nil { 351 return nil 352 } 353 354 buf.PutUint16be(uint16(portNum)) 355 return buf.Bytes() 356 } 357 358 // ParseAddrToSocksAddr parse a socks addr from net.addr 359 // This is a fast path of ParseAddr(addr.String()) 360 func ParseAddrToSocksAddr(addr net.Addr) Addr { 361 var hostip net.IP 362 var port int 363 if udpaddr, ok := addr.(*net.UDPAddr); ok { 364 hostip = udpaddr.IP 365 port = udpaddr.Port 366 } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { 367 hostip = tcpaddr.IP 368 port = tcpaddr.Port 369 } 370 371 // fallback parse 372 if hostip == nil { 373 return ParseAddr(addr.String()) 374 } 375 376 var parsed pool.BufferWriter 377 if ip4 := hostip.To4(); ip4 != nil { 378 parsed = make([]byte, 0, 1+net.IPv4len+2) 379 parsed.PutUint8(AtypIPv4) 380 parsed.PutSlice(ip4) 381 } else { 382 parsed = make([]byte, 0, 1+net.IPv6len+2) 383 parsed.PutUint8(AtypIPv6) 384 parsed.PutSlice(hostip) 385 } 386 387 parsed.PutUint16be(uint16(port)) 388 return parsed.Bytes() 389 } 390 391 func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr { 392 addr := addrPort.Addr().Unmap() 393 if !addr.IsValid() { 394 return nil 395 } 396 397 buf := pool.BufferWriter{} 398 if addr.Is4() { 399 buf.PutUint8(AtypIPv4) 400 } else { 401 buf.PutUint8(AtypIPv6) 402 } 403 404 buf.PutSlice(addr.AsSlice()) 405 buf.PutUint16be(addrPort.Port()) 406 return buf.Bytes() 407 } 408 409 // DecodeUDPPacket split `packet` to addr payload, and this function is mutable 
with `packet` 410 func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { 411 r := pool.BufferReader(packet) 412 413 if r.Len() < 5 { 414 err = errors.New("insufficient length of packet") 415 return 416 } 417 418 // packet[0] and packet[1] are reserved 419 reserved, r := r.SplitAt(2) 420 if !bytes.Equal(reserved, []byte{0, 0}) { 421 err = errors.New("reserved fields should be zero") 422 return 423 } 424 425 if r.ReadUint8() != 0 /* fragments */ { 426 err = errors.New("discarding fragmented payload") 427 return 428 } 429 430 addr = SplitAddr(r) 431 if addr == nil { 432 err = errors.New("failed to read UDP header") 433 } 434 435 _, payload = r.SplitAt(len(addr)) 436 return 437 } 438 439 func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { 440 if addr == nil { 441 err = errors.New("address is invalid") 442 return 443 } 444 w := pool.BufferWriter{} 445 w.PutSlice([]byte{0, 0, 0}) 446 w.PutSlice(addr) 447 w.PutSlice(payload) 448 packet = w.Bytes() 449 return 450 }