github.com/chwjbn/xclash@v0.2.0/transport/socks5/socks5.go (about)

     1  package socks5
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"strconv"
    10  
    11  	"github.com/chwjbn/xclash/component/auth"
    12  )
    13  
    14  // Error represents a SOCKS error
    15  type Error byte
    16  
    17  func (err Error) Error() string {
    18  	return "SOCKS error: " + strconv.Itoa(int(err))
    19  }
    20  
    21  // Command is request commands as defined in RFC 1928 section 4.
    22  type Command = uint8
    23  
    24  const Version = 5
    25  
    26  // SOCKS request commands as defined in RFC 1928 section 4.
    27  const (
    28  	CmdConnect      Command = 1
    29  	CmdBind         Command = 2
    30  	CmdUDPAssociate Command = 3
    31  )
    32  
    33  // SOCKS address types as defined in RFC 1928 section 5.
    34  const (
    35  	AtypIPv4       = 1
    36  	AtypDomainName = 3
    37  	AtypIPv6       = 4
    38  )
    39  
    40  // MaxAddrLen is the maximum size of SOCKS address in bytes.
    41  const MaxAddrLen = 1 + 1 + 255 + 2
    42  
    43  // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
    44  const MaxAuthLen = 255
    45  
    46  // Addr represents a SOCKS address as defined in RFC 1928 section 5.
    47  type Addr []byte
    48  
    49  func (a Addr) String() string {
    50  	var host, port string
    51  
    52  	switch a[0] {
    53  	case AtypDomainName:
    54  		hostLen := uint16(a[1])
    55  		host = string(a[2 : 2+hostLen])
    56  		port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
    57  	case AtypIPv4:
    58  		host = net.IP(a[1 : 1+net.IPv4len]).String()
    59  		port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
    60  	case AtypIPv6:
    61  		host = net.IP(a[1 : 1+net.IPv6len]).String()
    62  		port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
    63  	}
    64  
    65  	return net.JoinHostPort(host, port)
    66  }
    67  
    68  // UDPAddr converts a socks5.Addr to *net.UDPAddr
    69  func (a Addr) UDPAddr() *net.UDPAddr {
    70  	if len(a) == 0 {
    71  		return nil
    72  	}
    73  	switch a[0] {
    74  	case AtypIPv4:
    75  		var ip [net.IPv4len]byte
    76  		copy(ip[0:], a[1:1+net.IPv4len])
    77  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
    78  	case AtypIPv6:
    79  		var ip [net.IPv6len]byte
    80  		copy(ip[0:], a[1:1+net.IPv6len])
    81  		return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
    82  	}
    83  	// Other Atyp
    84  	return nil
    85  }
    86  
    87  // SOCKS errors as defined in RFC 1928 section 6.
    88  const (
    89  	ErrGeneralFailure       = Error(1)
    90  	ErrConnectionNotAllowed = Error(2)
    91  	ErrNetworkUnreachable   = Error(3)
    92  	ErrHostUnreachable      = Error(4)
    93  	ErrConnectionRefused    = Error(5)
    94  	ErrTTLExpired           = Error(6)
    95  	ErrCommandNotSupported  = Error(7)
    96  	ErrAddressNotSupported  = Error(8)
    97  )
    98  
    99  // Auth errors used to return a specific "Auth failed" error
   100  var ErrAuth = errors.New("auth failed")
   101  
   102  type User struct {
   103  	Username string
   104  	Password string
   105  }
   106  
   107  // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
   108  func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (authUser *auth.AuthUser, addr Addr, command Command, err error) {
   109  	// Read RFC 1928 for request and reply structure and sizes.
   110  	buf := make([]byte, MaxAddrLen)
   111  	// read VER, NMETHODS, METHODS
   112  	if _, err = io.ReadFull(rw, buf[:2]); err != nil {
   113  		return
   114  	}
   115  	nmethods := buf[1]
   116  	if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
   117  		return
   118  	}
   119  
   120  	// write VER METHOD
   121  	if authenticator != nil {
   122  		if _, err = rw.Write([]byte{5, 2}); err != nil {
   123  			return
   124  		}
   125  
   126  		// Get header
   127  		header := make([]byte, 2)
   128  		if _, err = io.ReadFull(rw, header); err != nil {
   129  			return
   130  		}
   131  
   132  		authBuf := make([]byte, MaxAuthLen)
   133  		// Get username
   134  		userLen := int(header[1])
   135  		if userLen <= 0 {
   136  			rw.Write([]byte{1, 1})
   137  			err = ErrAuth
   138  			return
   139  		}
   140  		if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
   141  			return
   142  		}
   143  		user := string(authBuf[:userLen])
   144  
   145  		// Get password
   146  		if _, err = rw.Read(header[:1]); err != nil {
   147  			return
   148  		}
   149  		passLen := int(header[0])
   150  		if passLen <= 0 {
   151  			rw.Write([]byte{1, 1})
   152  			err = ErrAuth
   153  			return
   154  		}
   155  		if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
   156  			return
   157  		}
   158  		pass := string(authBuf[:passLen])
   159  
   160  		// Verify
   161  		if ok := authenticator.Verify(string(user), string(pass)); !ok {
   162  			rw.Write([]byte{1, 1})
   163  			err = ErrAuth
   164  			return
   165  		}
   166  
   167  		//Record Account
   168  		authUser = &auth.AuthUser{}
   169  		authUser.User = user
   170  		authUser.Pass = pass
   171  
   172  		// Response auth state
   173  		if _, err = rw.Write([]byte{1, 0}); err != nil {
   174  			return
   175  		}
   176  	} else {
   177  		if _, err = rw.Write([]byte{5, 0}); err != nil {
   178  			return
   179  		}
   180  	}
   181  
   182  	// read VER CMD RSV ATYP DST.ADDR DST.PORT
   183  	if _, err = io.ReadFull(rw, buf[:3]); err != nil {
   184  		return
   185  	}
   186  
   187  	command = buf[1]
   188  	addr, err = ReadAddr(rw, buf)
   189  	if err != nil {
   190  		return
   191  	}
   192  
   193  	switch command {
   194  	case CmdConnect, CmdUDPAssociate:
   195  		// Acquire server listened address info
   196  		localAddr := ParseAddr(rw.LocalAddr().String())
   197  		if localAddr == nil {
   198  			err = ErrAddressNotSupported
   199  		} else {
   200  			// write VER REP RSV ATYP BND.ADDR BND.PORT
   201  			_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
   202  		}
   203  	case CmdBind:
   204  		fallthrough
   205  	default:
   206  		err = ErrCommandNotSupported
   207  	}
   208  
   209  	return
   210  }
   211  
   212  // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
   213  func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
   214  	buf := make([]byte, MaxAddrLen)
   215  	var err error
   216  
   217  	// VER, NMETHODS, METHODS
   218  	if user != nil {
   219  		_, err = rw.Write([]byte{5, 1, 2})
   220  	} else {
   221  		_, err = rw.Write([]byte{5, 1, 0})
   222  	}
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	// VER, METHOD
   228  	if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   229  		return nil, err
   230  	}
   231  
   232  	if buf[0] != 5 {
   233  		return nil, errors.New("SOCKS version error")
   234  	}
   235  
   236  	if buf[1] == 2 {
   237  		if user == nil {
   238  			return nil, ErrAuth
   239  		}
   240  
   241  		// password protocol version
   242  		authMsg := &bytes.Buffer{}
   243  		authMsg.WriteByte(1)
   244  		authMsg.WriteByte(uint8(len(user.Username)))
   245  		authMsg.WriteString(user.Username)
   246  		authMsg.WriteByte(uint8(len(user.Password)))
   247  		authMsg.WriteString(user.Password)
   248  
   249  		if _, err := rw.Write(authMsg.Bytes()); err != nil {
   250  			return nil, err
   251  		}
   252  
   253  		if _, err := io.ReadFull(rw, buf[:2]); err != nil {
   254  			return nil, err
   255  		}
   256  
   257  		if buf[1] != 0 {
   258  			return nil, errors.New("rejected username/password")
   259  		}
   260  	} else if buf[1] != 0 {
   261  		return nil, errors.New("SOCKS need auth")
   262  	}
   263  
   264  	// VER, CMD, RSV, ADDR
   265  	if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
   266  		return nil, err
   267  	}
   268  
   269  	// VER, REP, RSV
   270  	if _, err := io.ReadFull(rw, buf[:3]); err != nil {
   271  		return nil, err
   272  	}
   273  
   274  	return ReadAddr(rw, buf)
   275  }
   276  
   277  func ReadAddr(r io.Reader, b []byte) (Addr, error) {
   278  	if len(b) < MaxAddrLen {
   279  		return nil, io.ErrShortBuffer
   280  	}
   281  	_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
   282  	if err != nil {
   283  		return nil, err
   284  	}
   285  
   286  	switch b[0] {
   287  	case AtypDomainName:
   288  		_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
   289  		if err != nil {
   290  			return nil, err
   291  		}
   292  		domainLength := uint16(b[1])
   293  		_, err = io.ReadFull(r, b[2:2+domainLength+2])
   294  		return b[:1+1+domainLength+2], err
   295  	case AtypIPv4:
   296  		_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
   297  		return b[:1+net.IPv4len+2], err
   298  	case AtypIPv6:
   299  		_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
   300  		return b[:1+net.IPv6len+2], err
   301  	}
   302  
   303  	return nil, ErrAddressNotSupported
   304  }
   305  
   306  // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
   307  func SplitAddr(b []byte) Addr {
   308  	addrLen := 1
   309  	if len(b) < addrLen {
   310  		return nil
   311  	}
   312  
   313  	switch b[0] {
   314  	case AtypDomainName:
   315  		if len(b) < 2 {
   316  			return nil
   317  		}
   318  		addrLen = 1 + 1 + int(b[1]) + 2
   319  	case AtypIPv4:
   320  		addrLen = 1 + net.IPv4len + 2
   321  	case AtypIPv6:
   322  		addrLen = 1 + net.IPv6len + 2
   323  	default:
   324  		return nil
   325  
   326  	}
   327  
   328  	if len(b) < addrLen {
   329  		return nil
   330  	}
   331  
   332  	return b[:addrLen]
   333  }
   334  
   335  // ParseAddr parses the address in string s. Returns nil if failed.
   336  func ParseAddr(s string) Addr {
   337  	var addr Addr
   338  	host, port, err := net.SplitHostPort(s)
   339  	if err != nil {
   340  		return nil
   341  	}
   342  	if ip := net.ParseIP(host); ip != nil {
   343  		if ip4 := ip.To4(); ip4 != nil {
   344  			addr = make([]byte, 1+net.IPv4len+2)
   345  			addr[0] = AtypIPv4
   346  			copy(addr[1:], ip4)
   347  		} else {
   348  			addr = make([]byte, 1+net.IPv6len+2)
   349  			addr[0] = AtypIPv6
   350  			copy(addr[1:], ip)
   351  		}
   352  	} else {
   353  		if len(host) > 255 {
   354  			return nil
   355  		}
   356  		addr = make([]byte, 1+1+len(host)+2)
   357  		addr[0] = AtypDomainName
   358  		addr[1] = byte(len(host))
   359  		copy(addr[2:], host)
   360  	}
   361  
   362  	portnum, err := strconv.ParseUint(port, 10, 16)
   363  	if err != nil {
   364  		return nil
   365  	}
   366  
   367  	addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
   368  
   369  	return addr
   370  }
   371  
   372  // ParseAddrToSocksAddr parse a socks addr from net.addr
   373  // This is a fast path of ParseAddr(addr.String())
   374  func ParseAddrToSocksAddr(addr net.Addr) Addr {
   375  	var hostip net.IP
   376  	var port int
   377  	if udpaddr, ok := addr.(*net.UDPAddr); ok {
   378  		hostip = udpaddr.IP
   379  		port = udpaddr.Port
   380  	} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
   381  		hostip = tcpaddr.IP
   382  		port = tcpaddr.Port
   383  	}
   384  
   385  	// fallback parse
   386  	if hostip == nil {
   387  		return ParseAddr(addr.String())
   388  	}
   389  
   390  	var parsed Addr
   391  	if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
   392  		parsed = make([]byte, 1+net.IPv4len+2)
   393  		parsed[0] = AtypIPv4
   394  		copy(parsed[1:], ip4)
   395  		binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
   396  
   397  	} else {
   398  		parsed = make([]byte, 1+net.IPv6len+2)
   399  		parsed[0] = AtypIPv6
   400  		copy(parsed[1:], hostip)
   401  		binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
   402  	}
   403  	return parsed
   404  }
   405  
   406  // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
   407  func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
   408  	if len(packet) < 5 {
   409  		err = errors.New("insufficient length of packet")
   410  		return
   411  	}
   412  
   413  	// packet[0] and packet[1] are reserved
   414  	if !bytes.Equal(packet[:2], []byte{0, 0}) {
   415  		err = errors.New("reserved fields should be zero")
   416  		return
   417  	}
   418  
   419  	if packet[2] != 0 /* fragments */ {
   420  		err = errors.New("discarding fragmented payload")
   421  		return
   422  	}
   423  
   424  	addr = SplitAddr(packet[3:])
   425  	if addr == nil {
   426  		err = errors.New("failed to read UDP header")
   427  	}
   428  
   429  	payload = packet[3+len(addr):]
   430  	return
   431  }
   432  
   433  func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
   434  	if addr == nil {
   435  		err = errors.New("address is invalid")
   436  		return
   437  	}
   438  	packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
   439  	return
   440  }