github.com/yaling888/clash@v1.53.0/transport/socks5/socks5.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/netip"
    10  	"strconv"
    11  
    12  	"github.com/yaling888/clash/common/pool"
    13  	"github.com/yaling888/clash/component/auth"
    14  )
    15  
    16  // Error represents a SOCKS error
    17  type Error byte
    18  
    19  func (err Error) Error() string {
    20  	return "SOCKS error: " + strconv.Itoa(int(err))
    21  }
    22  
    23  // Command is request commands as defined in RFC 1928 section 4.
    24  type Command = uint8
    25  
    26  const Version = 5
    27  
    28  // SOCKS request commands as defined in RFC 1928 section 4.
    29  const (
    30  	CmdConnect      Command = 1
    31  	CmdBind         Command = 2
    32  	CmdUDPAssociate Command = 3
    33  )
    34  
    35  // SOCKS address types as defined in RFC 1928 section 5.
    36  const (
    37  	AtypIPv4       = 1
    38  	AtypDomainName = 3
    39  	AtypIPv6       = 4
    40  )
    41  
    42  // MaxAddrLen is the maximum size of SOCKS address in bytes.
    43  const MaxAddrLen = 1 + 1 + 255 + 2
    44  
    45  // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
    46  const MaxAuthLen = 255
    47  
    48  // Addr represents a SOCKS address as defined in RFC 1928 section 5.
    49  type Addr []byte
    50  
    51  func (a Addr) String() string {
    52  	var host, port string
    53  
    54  	switch a[0] {
    55  	case AtypDomainName:
    56  		hostLen := uint16(a[1])
    57  		host = string(a[2 : 2+hostLen])
    58  		port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
    59  	case AtypIPv4:
    60  		host = net.IP(a[1 : 1+net.IPv4len]).String()
    61  		port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
    62  	case AtypIPv6:
    63  		host = net.IP(a[1 : 1+net.IPv6len]).String()
    64  		port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
    65  	}
    66  
    67  	return net.JoinHostPort(host, port)
    68  }
    69  
    70  // UDPAddr converts a socks5.Addr to *net.UDPAddr
    71  func (a Addr) UDPAddr() *net.UDPAddr {
    72  	if len(a) == 0 {
    73  		return nil
    74  	}
    75  	switch a[0] {
    76  	case AtypIPv4:
    77  		var ip [net.IPv4len]byte
    78  		copy(ip[0:], a[1:1+net.IPv4len])
    79  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
    80  	case AtypIPv6:
    81  		var ip [net.IPv6len]byte
    82  		copy(ip[0:], a[1:1+net.IPv6len])
    83  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
    84  	}
    85  	// Other Atyp
    86  	return nil
    87  }
    88  
    89  // SOCKS errors as defined in RFC 1928 section 6.
    90  const (
    91  	ErrCommandNotSupported = Error(7)
    92  	ErrAddressNotSupported = Error(8)
    93  )
    94  
    95  // ErrAuth errors used to return a specific "Auth failed" error
    96  var ErrAuth = errors.New("auth failed")
    97  
    98  type User struct {
    99  	Username string
   100  	Password string
   101  }
   102  
   103  // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
   104  func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, err error) {
   105  	// Read RFC 1928 for request and reply structure and sizes.
   106  	buf := make([]byte, MaxAddrLen)
   107  	// read VER, NMETHODS, METHODS
   108  	if _, err = io.ReadFull(rw, buf[:2]); err != nil {
   109  		return
   110  	}
   111  	nmethods := buf[1]
   112  	if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
   113  		return
   114  	}
   115  
   116  	// write VER METHOD
   117  	if authenticator != nil {
   118  		if _, err = rw.Write([]byte{5, 2}); err != nil {
   119  			return
   120  		}
   121  
   122  		// Get header
   123  		header := make([]byte, 2)
   124  		if _, err = io.ReadFull(rw, header); err != nil {
   125  			return
   126  		}
   127  
   128  		authBuf := make([]byte, MaxAuthLen)
   129  		// Get username
   130  		userLen := int(header[1])
   131  		if userLen <= 0 {
   132  			_, _ = rw.Write([]byte{1, 1})
   133  			err = ErrAuth
   134  			return
   135  		}
   136  		if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
   137  			return
   138  		}
   139  		user := make([]byte, userLen)
   140  		copy(user, authBuf[:userLen])
   141  
   142  		// Get password
   143  		if _, err = rw.Read(header[:1]); err != nil {
   144  			return
   145  		}
   146  		passLen := int(header[0])
   147  		if passLen <= 0 {
   148  			_, _ = rw.Write([]byte{1, 1})
   149  			err = ErrAuth
   150  			return
   151  		}
   152  		if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
   153  			return
   154  		}
   155  		pass := authBuf[:passLen]
   156  
   157  		// Verify
   158  		if ok := authenticator.Verify(user, pass); !ok {
   159  			_, _ = rw.Write([]byte{1, 1})
   160  			err = ErrAuth
   161  			return
   162  		}
   163  
   164  		// Response auth state
   165  		if _, err = rw.Write([]byte{1, 0}); err != nil {
   166  			return
   167  		}
   168  	} else {
   169  		if _, err = rw.Write([]byte{5, 0}); err != nil {
   170  			return
   171  		}
   172  	}
   173  
   174  	// read VER CMD RSV ATYP DST.ADDR DST.PORT
   175  	if _, err = io.ReadFull(rw, buf[:3]); err != nil {
   176  		return
   177  	}
   178  
   179  	command = buf[1]
   180  	addr, err = ReadAddr(rw, buf)
   181  	if err != nil {
   182  		return
   183  	}
   184  
   185  	switch command {
   186  	case CmdConnect, CmdUDPAssociate:
   187  		// Acquire server listened address info
   188  		localAddr := ParseAddr(rw.LocalAddr().String())
   189  		if localAddr == nil {
   190  			err = ErrAddressNotSupported
   191  		} else {
   192  			// write VER REP RSV ATYP BND.ADDR BND.PORT
   193  			_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
   194  		}
   195  	case CmdBind:
   196  		fallthrough
   197  	default:
   198  		err = ErrCommandNotSupported
   199  	}
   200  
   201  	return
   202  }
   203  
   204  // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
   205  func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
   206  	buf := make([]byte, MaxAddrLen)
   207  	var err error
   208  
   209  	// VER, NMETHODS, METHODS
   210  	if user != nil {
   211  		_, err = rw.Write([]byte{5, 1, 2})
   212  	} else {
   213  		_, err = rw.Write([]byte{5, 1, 0})
   214  	}
   215  	if err != nil {
   216  		return nil, err
   217  	}
   218  
   219  	// VER, METHOD
   220  	if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	if buf[0] != 5 {
   225  		return nil, errors.New("SOCKS version error")
   226  	}
   227  
   228  	if buf[1] == 2 {
   229  		if user == nil {
   230  			return nil, ErrAuth
   231  		}
   232  
   233  		// password protocol version
   234  		authMsg := pool.BufferWriter{}
   235  		authMsg.PutUint8(1)
   236  		authMsg.PutUint8(uint8(len(user.Username)))
   237  		authMsg.PutString(user.Username)
   238  		authMsg.PutUint8(uint8(len(user.Password)))
   239  		authMsg.PutString(user.Password)
   240  
   241  		if _, err := rw.Write(authMsg.Bytes()); err != nil {
   242  			return nil, err
   243  		}
   244  
   245  		if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   246  			return nil, err
   247  		}
   248  
   249  		if buf[1] != 0 {
   250  			return nil, errors.New("rejected username/password")
   251  		}
   252  	} else if buf[1] != 0 {
   253  		return nil, errors.New("SOCKS need auth")
   254  	}
   255  
   256  	// VER, CMD, RSV, ADDR
   257  	if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	// VER, REP, RSV
   262  	if _, err := io.ReadFull(rw, buf[:3]); err != nil {
   263  		return nil, err
   264  	}
   265  
   266  	return ReadAddr(rw, buf)
   267  }
   268  
   269  func ReadAddr(r io.Reader, b []byte) (Addr, error) {
   270  	if len(b) < MaxAddrLen {
   271  		return nil, io.ErrShortBuffer
   272  	}
   273  	_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	switch b[0] {
   279  	case AtypDomainName:
   280  		_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
   281  		if err != nil {
   282  			return nil, err
   283  		}
   284  		domainLength := uint16(b[1])
   285  		_, err = io.ReadFull(r, b[2:2+domainLength+2])
   286  		return b[:1+1+domainLength+2], err
   287  	case AtypIPv4:
   288  		_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
   289  		return b[:1+net.IPv4len+2], err
   290  	case AtypIPv6:
   291  		_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
   292  		return b[:1+net.IPv6len+2], err
   293  	}
   294  
   295  	return nil, ErrAddressNotSupported
   296  }
   297  
   298  // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
   299  func SplitAddr(b []byte) Addr {
   300  	addrLen := 1
   301  	if len(b) < addrLen {
   302  		return nil
   303  	}
   304  
   305  	switch b[0] {
   306  	case AtypDomainName:
   307  		if len(b) < 2 {
   308  			return nil
   309  		}
   310  		addrLen = 1 + 1 + int(b[1]) + 2
   311  	case AtypIPv4:
   312  		addrLen = 1 + net.IPv4len + 2
   313  	case AtypIPv6:
   314  		addrLen = 1 + net.IPv6len + 2
   315  	default:
   316  		return nil
   317  	}
   318  
   319  	if len(b) < addrLen {
   320  		return nil
   321  	}
   322  
   323  	return b[:addrLen]
   324  }
   325  
   326  // ParseAddr parses the address in string s. Returns nil if failed.
   327  func ParseAddr(s string) Addr {
   328  	buf := pool.BufferWriter{}
   329  	host, port, err := net.SplitHostPort(s)
   330  	if err != nil {
   331  		return nil
   332  	}
   333  	if ip, err := netip.ParseAddr(host); err == nil {
   334  		if ip.Is4() {
   335  			buf.PutUint8(AtypIPv4)
   336  		} else {
   337  			buf.PutUint8(AtypIPv6)
   338  		}
   339  		buf.PutSlice(ip.AsSlice())
   340  	} else {
   341  		if len(host) > 255 {
   342  			return nil
   343  		}
   344  		buf.PutUint8(AtypDomainName)
   345  		buf.PutUint8(byte(len(host)))
   346  		buf.PutString(host)
   347  	}
   348  
   349  	portNum, err := strconv.ParseUint(port, 10, 16)
   350  	if err != nil {
   351  		return nil
   352  	}
   353  
   354  	buf.PutUint16be(uint16(portNum))
   355  	return buf.Bytes()
   356  }
   357  
   358  // ParseAddrToSocksAddr parse a socks addr from net.addr
   359  // This is a fast path of ParseAddr(addr.String())
   360  func ParseAddrToSocksAddr(addr net.Addr) Addr {
   361  	var hostip net.IP
   362  	var port int
   363  	if udpaddr, ok := addr.(*net.UDPAddr); ok {
   364  		hostip = udpaddr.IP
   365  		port = udpaddr.Port
   366  	} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
   367  		hostip = tcpaddr.IP
   368  		port = tcpaddr.Port
   369  	}
   370  
   371  	// fallback parse
   372  	if hostip == nil {
   373  		return ParseAddr(addr.String())
   374  	}
   375  
   376  	var parsed pool.BufferWriter
   377  	if ip4 := hostip.To4(); ip4 != nil {
   378  		parsed = make([]byte, 0, 1+net.IPv4len+2)
   379  		parsed.PutUint8(AtypIPv4)
   380  		parsed.PutSlice(ip4)
   381  	} else {
   382  		parsed = make([]byte, 0, 1+net.IPv6len+2)
   383  		parsed.PutUint8(AtypIPv6)
   384  		parsed.PutSlice(hostip)
   385  	}
   386  
   387  	parsed.PutUint16be(uint16(port))
   388  	return parsed.Bytes()
   389  }
   390  
   391  func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
   392  	addr := addrPort.Addr().Unmap()
   393  	if !addr.IsValid() {
   394  		return nil
   395  	}
   396  
   397  	buf := pool.BufferWriter{}
   398  	if addr.Is4() {
   399  		buf.PutUint8(AtypIPv4)
   400  	} else {
   401  		buf.PutUint8(AtypIPv6)
   402  	}
   403  
   404  	buf.PutSlice(addr.AsSlice())
   405  	buf.PutUint16be(addrPort.Port())
   406  	return buf.Bytes()
   407  }
   408  
   409  // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
   410  func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
   411  	r := pool.BufferReader(packet)
   412  
   413  	if r.Len() < 5 {
   414  		err = errors.New("insufficient length of packet")
   415  		return
   416  	}
   417  
   418  	// packet[0] and packet[1] are reserved
   419  	reserved, r := r.SplitAt(2)
   420  	if !bytes.Equal(reserved, []byte{0, 0}) {
   421  		err = errors.New("reserved fields should be zero")
   422  		return
   423  	}
   424  
   425  	if r.ReadUint8() != 0 /* fragments */ {
   426  		err = errors.New("discarding fragmented payload")
   427  		return
   428  	}
   429  
   430  	addr = SplitAddr(r)
   431  	if addr == nil {
   432  		err = errors.New("failed to read UDP header")
   433  	}
   434  
   435  	_, payload = r.SplitAt(len(addr))
   436  	return
   437  }
   438  
   439  func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
   440  	if addr == nil {
   441  		err = errors.New("address is invalid")
   442  		return
   443  	}
   444  	w := pool.BufferWriter{}
   445  	w.PutSlice([]byte{0, 0, 0})
   446  	w.PutSlice(addr)
   447  	w.PutSlice(payload)
   448  	packet = w.Bytes()
   449  	return
   450  }