github.com/igoogolx/clash@v1.19.8/transport/socks5/socks5.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/netip"
    10  	"strconv"
    11  
    12  	"github.com/igoogolx/clash/component/auth"
    13  
    14  	"github.com/Dreamacro/protobytes"
    15  )
    16  
    17  // Error represents a SOCKS error
    18  type Error byte
    19  
    20  func (err Error) Error() string {
    21  	return "SOCKS error: " + strconv.Itoa(int(err))
    22  }
    23  
    24  // Command is request commands as defined in RFC 1928 section 4.
    25  type Command = uint8
    26  
    27  const Version = 5
    28  
    29  // SOCKS request commands as defined in RFC 1928 section 4.
    30  const (
    31  	CmdConnect      Command = 1
    32  	CmdBind         Command = 2
    33  	CmdUDPAssociate Command = 3
    34  )
    35  
    36  // SOCKS address types as defined in RFC 1928 section 5.
    37  const (
    38  	AtypIPv4       = 1
    39  	AtypDomainName = 3
    40  	AtypIPv6       = 4
    41  )
    42  
    43  // MaxAddrLen is the maximum size of SOCKS address in bytes.
    44  const MaxAddrLen = 1 + 1 + 255 + 2
    45  
    46  // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
    47  const MaxAuthLen = 255
    48  
    49  // Addr represents a SOCKS address as defined in RFC 1928 section 5.
    50  type Addr []byte
    51  
    52  func (a Addr) String() string {
    53  	var host, port string
    54  
    55  	switch a[0] {
    56  	case AtypDomainName:
    57  		hostLen := uint16(a[1])
    58  		host = string(a[2 : 2+hostLen])
    59  		port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
    60  	case AtypIPv4:
    61  		host = net.IP(a[1 : 1+net.IPv4len]).String()
    62  		port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
    63  	case AtypIPv6:
    64  		host = net.IP(a[1 : 1+net.IPv6len]).String()
    65  		port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
    66  	}
    67  
    68  	return net.JoinHostPort(host, port)
    69  }
    70  
    71  // UDPAddr converts a socks5.Addr to *net.UDPAddr
    72  func (a Addr) UDPAddr() *net.UDPAddr {
    73  	if len(a) == 0 {
    74  		return nil
    75  	}
    76  	switch a[0] {
    77  	case AtypIPv4:
    78  		var ip [net.IPv4len]byte
    79  		copy(ip[0:], a[1:1+net.IPv4len])
    80  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
    81  	case AtypIPv6:
    82  		var ip [net.IPv6len]byte
    83  		copy(ip[0:], a[1:1+net.IPv6len])
    84  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
    85  	}
    86  	// Other Atyp
    87  	return nil
    88  }
    89  
    90  // SOCKS errors as defined in RFC 1928 section 6.
    91  const (
    92  	ErrGeneralFailure       = Error(1)
    93  	ErrConnectionNotAllowed = Error(2)
    94  	ErrNetworkUnreachable   = Error(3)
    95  	ErrHostUnreachable      = Error(4)
    96  	ErrConnectionRefused    = Error(5)
    97  	ErrTTLExpired           = Error(6)
    98  	ErrCommandNotSupported  = Error(7)
    99  	ErrAddressNotSupported  = Error(8)
   100  )
   101  
   102  // Auth errors used to return a specific "Auth failed" error
   103  var ErrAuth = errors.New("auth failed")
   104  
   105  type User struct {
   106  	Username string
   107  	Password string
   108  }
   109  
   110  // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
   111  func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) {
   112  	// Read RFC 1928 for request and reply structure and sizes.
   113  	buf := make([]byte, MaxAddrLen)
   114  	// read VER, NMETHODS, METHODS
   115  	if _, err = io.ReadFull(rw, buf[:2]); err != nil {
   116  		return
   117  	}
   118  	nmethods := buf[1]
   119  	if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
   120  		return
   121  	}
   122  
   123  	// write VER METHOD
   124  	if authenticator != nil {
   125  		if _, err = rw.Write([]byte{5, 2}); err != nil {
   126  			return
   127  		}
   128  
   129  		// Get header
   130  		header := make([]byte, 2)
   131  		if _, err = io.ReadFull(rw, header); err != nil {
   132  			return
   133  		}
   134  
   135  		authBuf := make([]byte, MaxAuthLen)
   136  		// Get username
   137  		userLen := int(header[1])
   138  		if userLen <= 0 {
   139  			rw.Write([]byte{1, 1})
   140  			err = ErrAuth
   141  			return
   142  		}
   143  		if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
   144  			return
   145  		}
   146  		user := string(authBuf[:userLen])
   147  
   148  		// Get password
   149  		if _, err = rw.Read(header[:1]); err != nil {
   150  			return
   151  		}
   152  		passLen := int(header[0])
   153  		if passLen <= 0 {
   154  			rw.Write([]byte{1, 1})
   155  			err = ErrAuth
   156  			return
   157  		}
   158  		if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
   159  			return
   160  		}
   161  		pass := string(authBuf[:passLen])
   162  
   163  		// Verify
   164  		if ok := authenticator.Verify(user, pass); !ok {
   165  			rw.Write([]byte{1, 1})
   166  			err = ErrAuth
   167  			return
   168  		}
   169  
   170  		// Response auth state
   171  		if _, err = rw.Write([]byte{1, 0}); err != nil {
   172  			return
   173  		}
   174  	} else {
   175  		if _, err = rw.Write([]byte{5, 0}); err != nil {
   176  			return
   177  		}
   178  	}
   179  
   180  	// read VER CMD RSV ATYP DST.ADDR DST.PORT
   181  	if _, err = io.ReadFull(rw, buf[:3]); err != nil {
   182  		return
   183  	}
   184  
   185  	command = buf[1]
   186  	addr, err = ReadAddr(rw, buf)
   187  	if err != nil {
   188  		return
   189  	}
   190  
   191  	switch command {
   192  	case CmdConnect, CmdUDPAssociate:
   193  		// Acquire server listened address info
   194  		localAddr := ParseAddr(rw.LocalAddr().String())
   195  		if localAddr == nil {
   196  			err = ErrAddressNotSupported
   197  		} else {
   198  			// write VER REP RSV ATYP BND.ADDR BND.PORT
   199  			_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
   200  		}
   201  	case CmdBind:
   202  		fallthrough
   203  	default:
   204  		err = ErrCommandNotSupported
   205  	}
   206  
   207  	return
   208  }
   209  
   210  // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
   211  func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
   212  	buf := make([]byte, MaxAddrLen)
   213  	var err error
   214  
   215  	// VER, NMETHODS, METHODS
   216  	if user != nil {
   217  		_, err = rw.Write([]byte{5, 1, 2})
   218  	} else {
   219  		_, err = rw.Write([]byte{5, 1, 0})
   220  	}
   221  	if err != nil {
   222  		return nil, err
   223  	}
   224  
   225  	// VER, METHOD
   226  	if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   227  		return nil, err
   228  	}
   229  
   230  	if buf[0] != 5 {
   231  		return nil, errors.New("SOCKS version error")
   232  	}
   233  
   234  	if buf[1] == 2 {
   235  		if user == nil {
   236  			return nil, ErrAuth
   237  		}
   238  
   239  		// password protocol version
   240  		authMsg := protobytes.BytesWriter{}
   241  		authMsg.PutUint8(1)
   242  		authMsg.PutUint8(uint8(len(user.Username)))
   243  		authMsg.PutString(user.Username)
   244  		authMsg.PutUint8(uint8(len(user.Password)))
   245  		authMsg.PutString(user.Password)
   246  
   247  		if _, err := rw.Write(authMsg.Bytes()); err != nil {
   248  			return nil, err
   249  		}
   250  
   251  		if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   252  			return nil, err
   253  		}
   254  
   255  		if buf[1] != 0 {
   256  			return nil, errors.New("rejected username/password")
   257  		}
   258  	} else if buf[1] != 0 {
   259  		return nil, errors.New("SOCKS need auth")
   260  	}
   261  
   262  	// VER, CMD, RSV, ADDR
   263  	if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	// VER, REP, RSV
   268  	if _, err := io.ReadFull(rw, buf[:3]); err != nil {
   269  		return nil, err
   270  	}
   271  
   272  	return ReadAddr(rw, buf)
   273  }
   274  
   275  func ReadAddr(r io.Reader, b []byte) (Addr, error) {
   276  	if len(b) < MaxAddrLen {
   277  		return nil, io.ErrShortBuffer
   278  	}
   279  	_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  
   284  	switch b[0] {
   285  	case AtypDomainName:
   286  		_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
   287  		if err != nil {
   288  			return nil, err
   289  		}
   290  		domainLength := uint16(b[1])
   291  		_, err = io.ReadFull(r, b[2:2+domainLength+2])
   292  		return b[:1+1+domainLength+2], err
   293  	case AtypIPv4:
   294  		_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
   295  		return b[:1+net.IPv4len+2], err
   296  	case AtypIPv6:
   297  		_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
   298  		return b[:1+net.IPv6len+2], err
   299  	}
   300  
   301  	return nil, ErrAddressNotSupported
   302  }
   303  
   304  // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
   305  func SplitAddr(b []byte) Addr {
   306  	addrLen := 1
   307  	if len(b) < addrLen {
   308  		return nil
   309  	}
   310  
   311  	switch b[0] {
   312  	case AtypDomainName:
   313  		if len(b) < 2 {
   314  			return nil
   315  		}
   316  		addrLen = 1 + 1 + int(b[1]) + 2
   317  	case AtypIPv4:
   318  		addrLen = 1 + net.IPv4len + 2
   319  	case AtypIPv6:
   320  		addrLen = 1 + net.IPv6len + 2
   321  	default:
   322  		return nil
   323  
   324  	}
   325  
   326  	if len(b) < addrLen {
   327  		return nil
   328  	}
   329  
   330  	return b[:addrLen]
   331  }
   332  
   333  // ParseAddr parses the address in string s. Returns nil if failed.
   334  func ParseAddr(s string) Addr {
   335  	buf := protobytes.BytesWriter{}
   336  	host, port, err := net.SplitHostPort(s)
   337  	if err != nil {
   338  		return nil
   339  	}
   340  	if ip := net.ParseIP(host); ip != nil {
   341  		if ip4 := ip.To4(); ip4 != nil {
   342  			buf.PutUint8(AtypIPv4)
   343  			buf.PutSlice(ip4)
   344  		} else {
   345  			buf.PutUint8(AtypIPv6)
   346  			buf.PutSlice(ip)
   347  		}
   348  	} else {
   349  		if len(host) > 255 {
   350  			return nil
   351  		}
   352  		buf.PutUint8(AtypDomainName)
   353  		buf.PutUint8(byte(len(host)))
   354  		buf.PutString(host)
   355  	}
   356  
   357  	portnum, err := strconv.ParseUint(port, 10, 16)
   358  	if err != nil {
   359  		return nil
   360  	}
   361  
   362  	buf.PutUint16be(uint16(portnum))
   363  	return Addr(buf.Bytes())
   364  }
   365  
   366  // ParseAddrToSocksAddr parse a socks addr from net.addr
   367  // This is a fast path of ParseAddr(addr.String())
   368  func ParseAddrToSocksAddr(addr net.Addr) Addr {
   369  	var hostip net.IP
   370  	var port int
   371  	if udpaddr, ok := addr.(*net.UDPAddr); ok {
   372  		hostip = udpaddr.IP
   373  		port = udpaddr.Port
   374  	} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
   375  		hostip = tcpaddr.IP
   376  		port = tcpaddr.Port
   377  	}
   378  
   379  	// fallback parse
   380  	if hostip == nil {
   381  		return ParseAddr(addr.String())
   382  	}
   383  
   384  	var parsed protobytes.BytesWriter
   385  	if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
   386  		parsed = make([]byte, 0, 1+net.IPv4len+2)
   387  		parsed.PutUint8(AtypIPv4)
   388  		parsed.PutSlice(ip4)
   389  		parsed.PutUint16be(uint16(port))
   390  	} else {
   391  		parsed = make([]byte, 0, 1+net.IPv6len+2)
   392  		parsed.PutUint8(AtypIPv6)
   393  		parsed.PutSlice(hostip)
   394  		parsed.PutUint16be(uint16(port))
   395  	}
   396  	return Addr(parsed)
   397  }
   398  
   399  func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
   400  	addr := addrPort.Addr()
   401  	if addr.Is4() {
   402  		ip4 := addr.As4()
   403  		return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())}
   404  	}
   405  
   406  	buf := make([]byte, 1+net.IPv6len+2)
   407  	buf[0] = AtypIPv6
   408  	copy(buf[1:], addr.AsSlice())
   409  	buf[1+net.IPv6len] = byte(addrPort.Port() >> 8)
   410  	buf[1+net.IPv6len+1] = byte(addrPort.Port())
   411  	return buf
   412  }
   413  
   414  // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
   415  func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
   416  	r := protobytes.BytesReader(packet)
   417  
   418  	if r.Len() < 5 {
   419  		err = errors.New("insufficient length of packet")
   420  		return
   421  	}
   422  
   423  	// packet[0] and packet[1] are reserved
   424  	reserved, r := r.SplitAt(2)
   425  	if !bytes.Equal(reserved, []byte{0, 0}) {
   426  		err = errors.New("reserved fields should be zero")
   427  		return
   428  	}
   429  
   430  	if r.ReadUint8() != 0 /* fragments */ {
   431  		err = errors.New("discarding fragmented payload")
   432  		return
   433  	}
   434  
   435  	addr = SplitAddr(r)
   436  	if addr == nil {
   437  		err = errors.New("failed to read UDP header")
   438  	}
   439  
   440  	_, payload = r.SplitAt(len(addr))
   441  	return
   442  }
   443  
   444  func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
   445  	if addr == nil {
   446  		err = errors.New("address is invalid")
   447  		return
   448  	}
   449  	w := protobytes.BytesWriter{}
   450  	w.PutSlice([]byte{0, 0, 0})
   451  	w.PutSlice(addr)
   452  	w.PutSlice(payload)
   453  	packet = w.Bytes()
   454  	return
   455  }