github.com/chwjbn/xclash@v0.2.0/transport/socks5/socks5.go (about) 1 package socks5 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "io" 8 "net" 9 "strconv" 10 11 "github.com/chwjbn/xclash/component/auth" 12 ) 13 14 // Error represents a SOCKS error 15 type Error byte 16 17 func (err Error) Error() string { 18 return "SOCKS error: " + strconv.Itoa(int(err)) 19 } 20 21 // Command is request commands as defined in RFC 1928 section 4. 22 type Command = uint8 23 24 const Version = 5 25 26 // SOCKS request commands as defined in RFC 1928 section 4. 27 const ( 28 CmdConnect Command = 1 29 CmdBind Command = 2 30 CmdUDPAssociate Command = 3 31 ) 32 33 // SOCKS address types as defined in RFC 1928 section 5. 34 const ( 35 AtypIPv4 = 1 36 AtypDomainName = 3 37 AtypIPv6 = 4 38 ) 39 40 // MaxAddrLen is the maximum size of SOCKS address in bytes. 41 const MaxAddrLen = 1 + 1 + 255 + 2 42 43 // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth 44 const MaxAuthLen = 255 45 46 // Addr represents a SOCKS address as defined in RFC 1928 section 5. 47 type Addr []byte 48 49 func (a Addr) String() string { 50 var host, port string 51 52 switch a[0] { 53 case AtypDomainName: 54 hostLen := uint16(a[1]) 55 host = string(a[2 : 2+hostLen]) 56 port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1])) 57 case AtypIPv4: 58 host = net.IP(a[1 : 1+net.IPv4len]).String() 59 port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1])) 60 case AtypIPv6: 61 host = net.IP(a[1 : 1+net.IPv6len]).String() 62 port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1])) 63 } 64 65 return net.JoinHostPort(host, port) 66 } 67 68 // UDPAddr converts a socks5.Addr to *net.UDPAddr 69 func (a Addr) UDPAddr() *net.UDPAddr { 70 if len(a) == 0 { 71 return nil 72 } 73 switch a[0] { 74 case AtypIPv4: 75 var ip [net.IPv4len]byte 76 copy(ip[0:], a[1:1+net.IPv4len]) 77 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} 78 case AtypIPv6: 79 var ip [net.IPv6len]byte 80 copy(ip[0:], a[1:1+net.IPv6len]) 81 return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} 82 } 83 // Other Atyp 84 return nil 85 } 86 87 // SOCKS errors as defined in RFC 1928 section 6. 88 const ( 89 ErrGeneralFailure = Error(1) 90 ErrConnectionNotAllowed = Error(2) 91 ErrNetworkUnreachable = Error(3) 92 ErrHostUnreachable = Error(4) 93 ErrConnectionRefused = Error(5) 94 ErrTTLExpired = Error(6) 95 ErrCommandNotSupported = Error(7) 96 ErrAddressNotSupported = Error(8) 97 ) 98 99 // Auth errors used to return a specific "Auth failed" error 100 var ErrAuth = errors.New("auth failed") 101 102 type User struct { 103 Username string 104 Password string 105 } 106 107 // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side. 108 func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (authUser *auth.AuthUser, addr Addr, command Command, err error) { 109 // Read RFC 1928 for request and reply structure and sizes. 110 buf := make([]byte, MaxAddrLen) 111 // read VER, NMETHODS, METHODS 112 if _, err = io.ReadFull(rw, buf[:2]); err != nil { 113 return 114 } 115 nmethods := buf[1] 116 if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil { 117 return 118 } 119 120 // write VER METHOD 121 if authenticator != nil { 122 if _, err = rw.Write([]byte{5, 2}); err != nil { 123 return 124 } 125 126 // Get header 127 header := make([]byte, 2) 128 if _, err = io.ReadFull(rw, header); err != nil { 129 return 130 } 131 132 authBuf := make([]byte, MaxAuthLen) 133 // Get username 134 userLen := int(header[1]) 135 if userLen <= 0 { 136 rw.Write([]byte{1, 1}) 137 err = ErrAuth 138 return 139 } 140 if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil { 141 return 142 } 143 user := string(authBuf[:userLen]) 144 145 // Get password 146 if _, err = rw.Read(header[:1]); err != nil { 147 return 148 } 149 passLen := int(header[0]) 150 if passLen <= 0 { 151 rw.Write([]byte{1, 1}) 152 err = ErrAuth 153 return 154 } 155 if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil { 156 return 157 } 158 pass := string(authBuf[:passLen]) 159 160 // Verify 161 if ok := authenticator.Verify(string(user), string(pass)); !ok { 162 rw.Write([]byte{1, 1}) 163 err = ErrAuth 164 return 165 } 166 167 //Record Account 168 authUser = &auth.AuthUser{} 169 authUser.User = user 170 authUser.Pass = pass 171 172 // Response auth state 173 if _, err = rw.Write([]byte{1, 0}); err != nil { 174 return 175 } 176 } else { 177 if _, err = rw.Write([]byte{5, 0}); err != nil { 178 return 179 } 180 } 181 182 // read VER CMD RSV ATYP DST.ADDR DST.PORT 183 if _, err = io.ReadFull(rw, buf[:3]); err != nil { 184 return 185 } 186 187 command = buf[1] 188 addr, err = ReadAddr(rw, buf) 189 if err != nil { 190 return 191 } 192 193 switch command { 194 case CmdConnect, CmdUDPAssociate: 195 // Acquire server listened address info 196 localAddr := ParseAddr(rw.LocalAddr().String()) 197 if localAddr == nil { 198 err = ErrAddressNotSupported 199 } else { 200 // write VER REP RSV ATYP BND.ADDR BND.PORT 201 _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{})) 202 } 203 case CmdBind: 204 fallthrough 205 default: 206 err = ErrCommandNotSupported 207 } 208 209 return 210 } 211 212 // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side. 213 func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) { 214 buf := make([]byte, MaxAddrLen) 215 var err error 216 217 // VER, NMETHODS, METHODS 218 if user != nil { 219 _, err = rw.Write([]byte{5, 1, 2}) 220 } else { 221 _, err = rw.Write([]byte{5, 1, 0}) 222 } 223 if err != nil { 224 return nil, err 225 } 226 227 // VER, METHOD 228 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 229 return nil, err 230 } 231 232 if buf[0] != 5 { 233 return nil, errors.New("SOCKS version error") 234 } 235 236 if buf[1] == 2 { 237 if user == nil { 238 return nil, ErrAuth 239 } 240 241 // password protocol version 242 authMsg := &bytes.Buffer{} 243 authMsg.WriteByte(1) 244 authMsg.WriteByte(uint8(len(user.Username))) 245 authMsg.WriteString(user.Username) 246 authMsg.WriteByte(uint8(len(user.Password))) 247 authMsg.WriteString(user.Password) 248 249 if _, err := rw.Write(authMsg.Bytes()); err != nil { 250 return nil, err 251 } 252 253 if _, err := io.ReadFull(rw, buf[:2]); err != nil { 254 return nil, err 255 } 256 257 if buf[1] != 0 { 258 return nil, errors.New("rejected username/password") 259 } 260 } else if buf[1] != 0 { 261 return nil, errors.New("SOCKS need auth") 262 } 263 264 // VER, CMD, RSV, ADDR 265 if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil { 266 return nil, err 267 } 268 269 // VER, REP, RSV 270 if _, err := io.ReadFull(rw, buf[:3]); err != nil { 271 return nil, err 272 } 273 274 return ReadAddr(rw, buf) 275 } 276 277 func ReadAddr(r io.Reader, b []byte) (Addr, error) { 278 if len(b) < MaxAddrLen { 279 return nil, io.ErrShortBuffer 280 } 281 _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type 282 if err != nil { 283 return nil, err 284 } 285 286 switch b[0] { 287 case AtypDomainName: 288 _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length 289 if err != nil { 290 return nil, err 291 } 292 domainLength := uint16(b[1]) 293 _, err = io.ReadFull(r, b[2:2+domainLength+2]) 294 return b[:1+1+domainLength+2], err 295 case AtypIPv4: 296 _, err = io.ReadFull(r, b[1:1+net.IPv4len+2]) 297 return b[:1+net.IPv4len+2], err 298 case AtypIPv6: 299 _, err = io.ReadFull(r, b[1:1+net.IPv6len+2]) 300 return b[:1+net.IPv6len+2], err 301 } 302 303 return nil, ErrAddressNotSupported 304 } 305 306 // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed. 307 func SplitAddr(b []byte) Addr { 308 addrLen := 1 309 if len(b) < addrLen { 310 return nil 311 } 312 313 switch b[0] { 314 case AtypDomainName: 315 if len(b) < 2 { 316 return nil 317 } 318 addrLen = 1 + 1 + int(b[1]) + 2 319 case AtypIPv4: 320 addrLen = 1 + net.IPv4len + 2 321 case AtypIPv6: 322 addrLen = 1 + net.IPv6len + 2 323 default: 324 return nil 325 326 } 327 328 if len(b) < addrLen { 329 return nil 330 } 331 332 return b[:addrLen] 333 } 334 335 // ParseAddr parses the address in string s. Returns nil if failed. 336 func ParseAddr(s string) Addr { 337 var addr Addr 338 host, port, err := net.SplitHostPort(s) 339 if err != nil { 340 return nil 341 } 342 if ip := net.ParseIP(host); ip != nil { 343 if ip4 := ip.To4(); ip4 != nil { 344 addr = make([]byte, 1+net.IPv4len+2) 345 addr[0] = AtypIPv4 346 copy(addr[1:], ip4) 347 } else { 348 addr = make([]byte, 1+net.IPv6len+2) 349 addr[0] = AtypIPv6 350 copy(addr[1:], ip) 351 } 352 } else { 353 if len(host) > 255 { 354 return nil 355 } 356 addr = make([]byte, 1+1+len(host)+2) 357 addr[0] = AtypDomainName 358 addr[1] = byte(len(host)) 359 copy(addr[2:], host) 360 } 361 362 portnum, err := strconv.ParseUint(port, 10, 16) 363 if err != nil { 364 return nil 365 } 366 367 addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum) 368 369 return addr 370 } 371 372 // ParseAddrToSocksAddr parse a socks addr from net.addr 373 // This is a fast path of ParseAddr(addr.String()) 374 func ParseAddrToSocksAddr(addr net.Addr) Addr { 375 var hostip net.IP 376 var port int 377 if udpaddr, ok := addr.(*net.UDPAddr); ok { 378 hostip = udpaddr.IP 379 port = udpaddr.Port 380 } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { 381 hostip = tcpaddr.IP 382 port = tcpaddr.Port 383 } 384 385 // fallback parse 386 if hostip == nil { 387 return ParseAddr(addr.String()) 388 } 389 390 var parsed Addr 391 if ip4 := hostip.To4(); ip4.DefaultMask() != nil { 392 parsed = make([]byte, 1+net.IPv4len+2) 393 parsed[0] = AtypIPv4 394 copy(parsed[1:], ip4) 395 binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port)) 396 397 } else { 398 parsed = make([]byte, 1+net.IPv6len+2) 399 parsed[0] = AtypIPv6 400 copy(parsed[1:], hostip) 401 binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port)) 402 } 403 return parsed 404 } 405 406 // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` 407 func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { 408 if len(packet) < 5 { 409 err = errors.New("insufficient length of packet") 410 return 411 } 412 413 // packet[0] and packet[1] are reserved 414 if !bytes.Equal(packet[:2], []byte{0, 0}) { 415 err = errors.New("reserved fields should be zero") 416 return 417 } 418 419 if packet[2] != 0 /* fragments */ { 420 err = errors.New("discarding fragmented payload") 421 return 422 } 423 424 addr = SplitAddr(packet[3:]) 425 if addr == nil { 426 err = errors.New("failed to read UDP header") 427 } 428 429 payload = packet[3+len(addr):] 430 return 431 } 432 433 func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { 434 if addr == nil { 435 err = errors.New("address is invalid") 436 return 437 } 438 packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{}) 439 return 440 }