github.com/kelleygo/clashcore@v1.0.2/transport/socks5/socks5.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/netip"
    10  	"strconv"
    11  
    12  	"github.com/kelleygo/clashcore/component/auth"
    13  )
    14  
    15  // Error represents a SOCKS error
    16  type Error byte
    17  
    18  func (err Error) Error() string {
    19  	return "SOCKS error: " + strconv.Itoa(int(err))
    20  }
    21  
    22  // Command is request commands as defined in RFC 1928 section 4.
    23  type Command = uint8
    24  
    25  const Version = 5
    26  
    27  // SOCKS request commands as defined in RFC 1928 section 4.
    28  const (
    29  	CmdConnect      Command = 1
    30  	CmdBind         Command = 2
    31  	CmdUDPAssociate Command = 3
    32  )
    33  
    34  // SOCKS address types as defined in RFC 1928 section 5.
    35  const (
    36  	AtypIPv4       = 1
    37  	AtypDomainName = 3
    38  	AtypIPv6       = 4
    39  )
    40  
    41  // MaxAddrLen is the maximum size of SOCKS address in bytes.
    42  const MaxAddrLen = 1 + 1 + 255 + 2
    43  
    44  // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
    45  const MaxAuthLen = 255
    46  
    47  // Addr represents a SOCKS address as defined in RFC 1928 section 5.
    48  type Addr []byte
    49  
    50  func (a Addr) String() string {
    51  	var host, port string
    52  
    53  	switch a[0] {
    54  	case AtypDomainName:
    55  		hostLen := uint16(a[1])
    56  		host = string(a[2 : 2+hostLen])
    57  		port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
    58  	case AtypIPv4:
    59  		host = net.IP(a[1 : 1+net.IPv4len]).String()
    60  		port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
    61  	case AtypIPv6:
    62  		host = net.IP(a[1 : 1+net.IPv6len]).String()
    63  		port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
    64  	}
    65  
    66  	return net.JoinHostPort(host, port)
    67  }
    68  
    69  // UDPAddr converts a socks5.Addr to *net.UDPAddr
    70  func (a Addr) UDPAddr() *net.UDPAddr {
    71  	if len(a) == 0 {
    72  		return nil
    73  	}
    74  	switch a[0] {
    75  	case AtypIPv4:
    76  		var ip [net.IPv4len]byte
    77  		copy(ip[0:], a[1:1+net.IPv4len])
    78  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
    79  	case AtypIPv6:
    80  		var ip [net.IPv6len]byte
    81  		copy(ip[0:], a[1:1+net.IPv6len])
    82  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
    83  	}
    84  	// Other Atyp
    85  	return nil
    86  }
    87  
    88  // SOCKS errors as defined in RFC 1928 section 6.
    89  const (
    90  	ErrGeneralFailure       = Error(1)
    91  	ErrConnectionNotAllowed = Error(2)
    92  	ErrNetworkUnreachable   = Error(3)
    93  	ErrHostUnreachable      = Error(4)
    94  	ErrConnectionRefused    = Error(5)
    95  	ErrTTLExpired           = Error(6)
    96  	ErrCommandNotSupported  = Error(7)
    97  	ErrAddressNotSupported  = Error(8)
    98  )
    99  
   100  // Auth errors used to return a specific "Auth failed" error
   101  var ErrAuth = errors.New("auth failed")
   102  
   103  type User struct {
   104  	Username string
   105  	Password string
   106  }
   107  
   108  // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
   109  func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) {
   110  	// Read RFC 1928 for request and reply structure and sizes.
   111  	buf := make([]byte, MaxAddrLen)
   112  	// read VER, NMETHODS, METHODS
   113  	if _, err = io.ReadFull(rw, buf[:2]); err != nil {
   114  		return
   115  	}
   116  	nmethods := buf[1]
   117  	if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
   118  		return
   119  	}
   120  
   121  	// write VER METHOD
   122  	if authenticator != nil {
   123  		if _, err = rw.Write([]byte{5, 2}); err != nil {
   124  			return
   125  		}
   126  
   127  		// Get header
   128  		header := make([]byte, 2)
   129  		if _, err = io.ReadFull(rw, header); err != nil {
   130  			return
   131  		}
   132  
   133  		authBuf := make([]byte, MaxAuthLen)
   134  		// Get username
   135  		userLen := int(header[1])
   136  		if userLen <= 0 {
   137  			rw.Write([]byte{1, 1})
   138  			err = ErrAuth
   139  			return
   140  		}
   141  		if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
   142  			return
   143  		}
   144  		user := string(authBuf[:userLen])
   145  
   146  		// Get password
   147  		if _, err = rw.Read(header[:1]); err != nil {
   148  			return
   149  		}
   150  		passLen := int(header[0])
   151  		if passLen <= 0 {
   152  			rw.Write([]byte{1, 1})
   153  			err = ErrAuth
   154  			return
   155  		}
   156  		if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
   157  			return
   158  		}
   159  		pass := string(authBuf[:passLen])
   160  
   161  		// Verify
   162  		if ok := authenticator.Verify(string(user), string(pass)); !ok {
   163  			rw.Write([]byte{1, 1})
   164  			err = ErrAuth
   165  			return
   166  		}
   167  
   168  		// Response auth state
   169  		if _, err = rw.Write([]byte{1, 0}); err != nil {
   170  			return
   171  		}
   172  	} else {
   173  		if _, err = rw.Write([]byte{5, 0}); err != nil {
   174  			return
   175  		}
   176  	}
   177  
   178  	// read VER CMD RSV ATYP DST.ADDR DST.PORT
   179  	if _, err = io.ReadFull(rw, buf[:3]); err != nil {
   180  		return
   181  	}
   182  
   183  	command = buf[1]
   184  	addr, err = ReadAddr(rw, buf)
   185  	if err != nil {
   186  		return
   187  	}
   188  
   189  	switch command {
   190  	case CmdConnect, CmdUDPAssociate:
   191  		// Acquire server listened address info
   192  		localAddr := ParseAddr(rw.LocalAddr().String())
   193  		if localAddr == nil {
   194  			err = ErrAddressNotSupported
   195  		} else {
   196  			// write VER REP RSV ATYP BND.ADDR BND.PORT
   197  			_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
   198  		}
   199  	case CmdBind:
   200  		fallthrough
   201  	default:
   202  		err = ErrCommandNotSupported
   203  	}
   204  
   205  	return
   206  }
   207  
   208  // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
   209  func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
   210  	buf := make([]byte, MaxAddrLen)
   211  	var err error
   212  
   213  	// VER, NMETHODS, METHODS
   214  	if user != nil {
   215  		_, err = rw.Write([]byte{5, 1, 2})
   216  	} else {
   217  		_, err = rw.Write([]byte{5, 1, 0})
   218  	}
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	// VER, METHOD
   224  	if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	if buf[0] != 5 {
   229  		return nil, errors.New("SOCKS version error")
   230  	}
   231  
   232  	if buf[1] == 2 {
   233  		if user == nil {
   234  			return nil, ErrAuth
   235  		}
   236  
   237  		// password protocol version
   238  		authMsg := &bytes.Buffer{}
   239  		authMsg.WriteByte(1)
   240  		authMsg.WriteByte(uint8(len(user.Username)))
   241  		authMsg.WriteString(user.Username)
   242  		authMsg.WriteByte(uint8(len(user.Password)))
   243  		authMsg.WriteString(user.Password)
   244  
   245  		if _, err := rw.Write(authMsg.Bytes()); err != nil {
   246  			return nil, err
   247  		}
   248  
   249  		if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   250  			return nil, err
   251  		}
   252  
   253  		if buf[1] != 0 {
   254  			return nil, errors.New("rejected username/password")
   255  		}
   256  	} else if buf[1] != 0 {
   257  		return nil, errors.New("SOCKS need auth")
   258  	}
   259  
   260  	// VER, CMD, RSV, ADDR
   261  	if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
   262  		return nil, err
   263  	}
   264  
   265  	// VER, REP, RSV
   266  	if _, err := io.ReadFull(rw, buf[:3]); err != nil {
   267  		return nil, err
   268  	}
   269  
   270  	return ReadAddr(rw, buf)
   271  }
   272  
   273  func ReadAddr(r io.Reader, b []byte) (Addr, error) {
   274  	if len(b) < MaxAddrLen {
   275  		return nil, io.ErrShortBuffer
   276  	}
   277  	_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  
   282  	switch b[0] {
   283  	case AtypDomainName:
   284  		_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  		domainLength := uint16(b[1])
   289  		_, err = io.ReadFull(r, b[2:2+domainLength+2])
   290  		return b[:1+1+domainLength+2], err
   291  	case AtypIPv4:
   292  		_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
   293  		return b[:1+net.IPv4len+2], err
   294  	case AtypIPv6:
   295  		_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
   296  		return b[:1+net.IPv6len+2], err
   297  	}
   298  
   299  	return nil, ErrAddressNotSupported
   300  }
   301  
   302  func ReadAddr0(r io.Reader) (Addr, error) {
   303  	aType, err := ReadByte(r) // read 1st byte for address type
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  
   308  	switch aType {
   309  	case AtypDomainName:
   310  		var domainLength byte
   311  		domainLength, err = ReadByte(r) // read 2nd byte for domain length
   312  		if err != nil {
   313  			return nil, err
   314  		}
   315  		b := make([]byte, 1+1+uint16(domainLength)+2)
   316  		_, err = io.ReadFull(r, b[2:])
   317  		b[0] = aType
   318  		b[1] = domainLength
   319  		return b, err
   320  	case AtypIPv4:
   321  		var b [1 + net.IPv4len + 2]byte
   322  		_, err = io.ReadFull(r, b[1:])
   323  		b[0] = aType
   324  		return b[:], err
   325  	case AtypIPv6:
   326  		var b [1 + net.IPv6len + 2]byte
   327  		_, err = io.ReadFull(r, b[1:])
   328  		b[0] = aType
   329  		return b[:], err
   330  	}
   331  
   332  	return nil, ErrAddressNotSupported
   333  }
   334  
   335  func ReadByte(reader io.Reader) (byte, error) {
   336  	if br, isBr := reader.(io.ByteReader); isBr {
   337  		return br.ReadByte()
   338  	}
   339  	var b [1]byte
   340  	if _, err := io.ReadFull(reader, b[:]); err != nil {
   341  		return 0, err
   342  	}
   343  	return b[0], nil
   344  }
   345  
   346  // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
   347  func SplitAddr(b []byte) Addr {
   348  	addrLen := 1
   349  	if len(b) < addrLen {
   350  		return nil
   351  	}
   352  
   353  	switch b[0] {
   354  	case AtypDomainName:
   355  		if len(b) < 2 {
   356  			return nil
   357  		}
   358  		addrLen = 1 + 1 + int(b[1]) + 2
   359  	case AtypIPv4:
   360  		addrLen = 1 + net.IPv4len + 2
   361  	case AtypIPv6:
   362  		addrLen = 1 + net.IPv6len + 2
   363  	default:
   364  		return nil
   365  
   366  	}
   367  
   368  	if len(b) < addrLen {
   369  		return nil
   370  	}
   371  
   372  	return b[:addrLen]
   373  }
   374  
   375  // ParseAddr parses the address in string s. Returns nil if failed.
   376  func ParseAddr(s string) Addr {
   377  	var addr Addr
   378  	host, port, err := net.SplitHostPort(s)
   379  	if err != nil {
   380  		return nil
   381  	}
   382  	if ip := net.ParseIP(host); ip != nil {
   383  		if ip4 := ip.To4(); ip4 != nil {
   384  			addr = make([]byte, 1+net.IPv4len+2)
   385  			addr[0] = AtypIPv4
   386  			copy(addr[1:], ip4)
   387  		} else {
   388  			addr = make([]byte, 1+net.IPv6len+2)
   389  			addr[0] = AtypIPv6
   390  			copy(addr[1:], ip)
   391  		}
   392  	} else {
   393  		if len(host) > 255 {
   394  			return nil
   395  		}
   396  		addr = make([]byte, 1+1+len(host)+2)
   397  		addr[0] = AtypDomainName
   398  		addr[1] = byte(len(host))
   399  		copy(addr[2:], host)
   400  	}
   401  
   402  	portnum, err := strconv.ParseUint(port, 10, 16)
   403  	if err != nil {
   404  		return nil
   405  	}
   406  
   407  	addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
   408  
   409  	return addr
   410  }
   411  
   412  // ParseAddrToSocksAddr parse a socks addr from net.addr
   413  // This is a fast path of ParseAddr(addr.String())
   414  func ParseAddrToSocksAddr(addr net.Addr) Addr {
   415  	var hostip net.IP
   416  	var port int
   417  	if udpaddr, ok := addr.(*net.UDPAddr); ok {
   418  		hostip = udpaddr.IP
   419  		port = udpaddr.Port
   420  	} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
   421  		hostip = tcpaddr.IP
   422  		port = tcpaddr.Port
   423  	}
   424  
   425  	// fallback parse
   426  	if hostip == nil {
   427  		return ParseAddr(addr.String())
   428  	}
   429  
   430  	var parsed Addr
   431  	if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
   432  		parsed = make([]byte, 1+net.IPv4len+2)
   433  		parsed[0] = AtypIPv4
   434  		copy(parsed[1:], ip4)
   435  		binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
   436  
   437  	} else {
   438  		parsed = make([]byte, 1+net.IPv6len+2)
   439  		parsed[0] = AtypIPv6
   440  		copy(parsed[1:], hostip)
   441  		binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
   442  	}
   443  	return parsed
   444  }
   445  
   446  func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
   447  	addr := addrPort.Addr()
   448  	if addr.Is4() {
   449  		ip4 := addr.As4()
   450  		return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())}
   451  	}
   452  
   453  	buf := make([]byte, 1+net.IPv6len+2)
   454  	buf[0] = AtypIPv6
   455  	copy(buf[1:], addr.AsSlice())
   456  	buf[1+net.IPv6len] = byte(addrPort.Port() >> 8)
   457  	buf[1+net.IPv6len+1] = byte(addrPort.Port())
   458  	return buf
   459  }
   460  
   461  // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
   462  func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
   463  	if len(packet) < 5 {
   464  		err = errors.New("insufficient length of packet")
   465  		return
   466  	}
   467  
   468  	// packet[0] and packet[1] are reserved
   469  	if !bytes.Equal(packet[:2], []byte{0, 0}) {
   470  		err = errors.New("reserved fields should be zero")
   471  		return
   472  	}
   473  
   474  	if packet[2] != 0 /* fragments */ {
   475  		err = errors.New("discarding fragmented payload")
   476  		return
   477  	}
   478  
   479  	addr = SplitAddr(packet[3:])
   480  	if addr == nil {
   481  		err = errors.New("failed to read UDP header")
   482  	}
   483  
   484  	payload = packet[3+len(addr):]
   485  	return
   486  }
   487  
   488  func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
   489  	if addr == nil {
   490  		err = errors.New("address is invalid")
   491  		return
   492  	}
   493  	packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
   494  	return
   495  }