github.com/database64128/shadowsocks-go@v1.7.0/socks5/addr.go (about)

     1  package socks5
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/netip"
     9  	"unsafe"
    10  
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  	"github.com/database64128/shadowsocks-go/slices"
    13  )
    14  
    15  // SOCKS address types as defined in RFC 1928 section 5.
    16  const (
    17  	AtypIPv4       = 1
    18  	AtypDomainName = 3
    19  	AtypIPv6       = 4
    20  )
    21  
    22  const (
    23  	// IPv4AddrLen is the size of an IPv4 SOCKS address in bytes.
    24  	IPv4AddrLen = 1 + 4 + 2
    25  
    26  	// IPv6AddrLen is the size of an IPv6 SOCKS address in bytes.
    27  	IPv6AddrLen = 1 + 16 + 2
    28  
    29  	// MaxAddrLen is the maximum size of a SOCKS address in bytes.
    30  	MaxAddrLen = 1 + 1 + 255 + 2
    31  )
    32  
    33  var (
    34  	// IPv4UnspecifiedAddr represents 0.0.0.0:0.
    35  	IPv4UnspecifiedAddr = [IPv4AddrLen]byte{AtypIPv4}
    36  
    37  	// IPv6UnspecifiedAddr represents [::]:0.
    38  	IPv6UnspecifiedAddr = [IPv6AddrLen]byte{AtypIPv6}
    39  )
    40  
    41  // AppendAddrFromAddrPort appends the netip.AddrPort to the buffer in the SOCKS address format.
    42  //
    43  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
    44  func AppendAddrFromAddrPort(b []byte, addrPort netip.AddrPort) []byte {
    45  	var ret, out []byte
    46  	ip := addrPort.Addr()
    47  	switch {
    48  	case ip.Is4() || ip.Is4In6():
    49  		ret, out = slices.Extend(b, 1+4+2)
    50  		out[0] = AtypIPv4
    51  		*(*[4]byte)(out[1:]) = ip.As4()
    52  	default:
    53  		ret, out = slices.Extend(b, 1+16+2)
    54  		out[0] = AtypIPv6
    55  		*(*[16]byte)(out[1:]) = ip.As16()
    56  	}
    57  	binary.BigEndian.PutUint16(out[len(out)-2:], addrPort.Port())
    58  	return ret
    59  }
    60  
    61  // WriteAddrFromAddrPort writes the netip.AddrPort to the buffer in the SOCKS address format
    62  // and returns the number of bytes written.
    63  //
    64  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
    65  //
    66  // This function does not check whether b has sufficient space for the address.
    67  // The caller may call [LengthOfAddrFromAddrPort] to get the required length.
    68  func WriteAddrFromAddrPort(b []byte, addrPort netip.AddrPort) (n int) {
    69  	ip := addrPort.Addr()
    70  	switch {
    71  	case ip.Is4() || ip.Is4In6():
    72  		b[0] = AtypIPv4
    73  		*(*[4]byte)(b[1:]) = ip.As4()
    74  		n = 1 + 4 + 2
    75  	default:
    76  		b[0] = AtypIPv6
    77  		*(*[16]byte)(b[1:]) = ip.As16()
    78  		n = 1 + 16 + 2
    79  	}
    80  	binary.BigEndian.PutUint16(b[n-2:], addrPort.Port())
    81  	return
    82  }
    83  
    84  // LengthOfAddrFromAddrPort returns the length of a SOCKS address converted from the netip.AddrPort.
    85  func LengthOfAddrFromAddrPort(addrPort netip.AddrPort) int {
    86  	if ip := addrPort.Addr(); ip.Is4() || ip.Is4In6() {
    87  		return 1 + 4 + 2
    88  	}
    89  	return 1 + 16 + 2
    90  }
    91  
    92  // AppendAddrFromConnAddr appends the address to the buffer in the SOCKS address format.
    93  //
    94  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
    95  func AppendAddrFromConnAddr(b []byte, addr conn.Addr) []byte {
    96  	if addr.IsIP() {
    97  		return AppendAddrFromAddrPort(b, addr.IPPort())
    98  	}
    99  
   100  	domain := addr.Domain()
   101  	ret, out := slices.Extend(b, 1+1+len(domain)+2)
   102  	out[0] = AtypDomainName
   103  	out[1] = byte(len(domain))
   104  	copy(out[2:], domain)
   105  
   106  	port := addr.Port()
   107  	binary.BigEndian.PutUint16(out[1+1+len(domain):], port)
   108  
   109  	return ret
   110  }
   111  
   112  // WriteAddrFromConnAddr writes the address to the buffer in the SOCKS address format
   113  // and returns the number of bytes written.
   114  //
   115  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
   116  //
   117  // This function does not check whether b has sufficient space for the address.
   118  // The caller may call [LengthOfAddrFromConnAddr] to get the required length.
   119  func WriteAddrFromConnAddr(b []byte, addr conn.Addr) int {
   120  	if addr.IsIP() {
   121  		return WriteAddrFromAddrPort(b, addr.IPPort())
   122  	}
   123  
   124  	domain := addr.Domain()
   125  	b[0] = AtypDomainName
   126  	b[1] = byte(len(domain))
   127  	copy(b[2:], domain)
   128  
   129  	port := addr.Port()
   130  	binary.BigEndian.PutUint16(b[1+1+len(domain):], port)
   131  
   132  	return 1 + 1 + len(domain) + 2
   133  }
   134  
   135  // LengthOfAddrFromConnAddr returns the length of a SOCKS address converted from the conn.Addr.
   136  func LengthOfAddrFromConnAddr(addr conn.Addr) int {
   137  	if addr.IsIP() {
   138  		return LengthOfAddrFromAddrPort(addr.IPPort())
   139  	}
   140  	domain := addr.Domain()
   141  	return 1 + 1 + len(domain) + 2
   142  }
   143  
   144  // AppendFromReader reads just enough bytes from r to get a valid Addr
   145  // and appends it to the buffer.
   146  func AppendFromReader(b []byte, r io.Reader) ([]byte, error) {
   147  	ret, out := slices.Extend(b, 2)
   148  
   149  	// Read ATYP and an extra byte.
   150  	_, err := io.ReadFull(r, out)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  
   155  	var addrLen int
   156  
   157  	switch out[0] {
   158  	case AtypDomainName:
   159  		addrLen = 1 + 1 + int(out[1]) + 2
   160  	case AtypIPv4:
   161  		addrLen = 1 + 4 + 2
   162  	case AtypIPv6:
   163  		addrLen = 1 + 16 + 2
   164  	default:
   165  		return nil, fmt.Errorf("unknown atyp %d", out[0])
   166  	}
   167  
   168  	ret, out = slices.Extend(ret[:len(b)+2], addrLen-2)
   169  	_, err = io.ReadFull(r, out)
   170  	return ret, err
   171  }
   172  
   173  // AddrFromReader allocates and reads a SOCKS address from an io.Reader.
   174  //
   175  // To avoid allocations, call AppendFromReader directly.
   176  func AddrFromReader(r io.Reader) ([]byte, error) {
   177  	b := make([]byte, 0, MaxAddrLen)
   178  	return AppendFromReader(b, r)
   179  }
   180  
   181  // ConnAddrFromReader reads a SOCKS address from r and returns the converted conn.Addr.
   182  func ConnAddrFromReader(r io.Reader) (conn.Addr, error) {
   183  	b := make([]byte, 2)
   184  
   185  	// Read ATYP and an extra byte.
   186  	_, err := io.ReadFull(r, b)
   187  	if err != nil {
   188  		return conn.Addr{}, err
   189  	}
   190  
   191  	switch b[0] {
   192  	case AtypDomainName:
   193  		b1 := make([]byte, int(b[1])+2)
   194  		_, err = io.ReadFull(r, b1)
   195  		if err != nil {
   196  			return conn.Addr{}, err
   197  		}
   198  		b2 := b1[:b[1]:b[1]]
   199  		domain := *(*string)(unsafe.Pointer(&b2))
   200  		port := binary.BigEndian.Uint16(b1[b[1]:])
   201  		return conn.AddrFromDomainPort(domain, port)
   202  
   203  	case AtypIPv4:
   204  		b1 := make([]byte, 4+2)
   205  		b1[0] = b[1]
   206  		_, err = io.ReadFull(r, b1[1:])
   207  		if err != nil {
   208  			return conn.Addr{}, err
   209  		}
   210  		ip := netip.AddrFrom4(*(*[4]byte)(b1))
   211  		port := binary.BigEndian.Uint16(b1[4:])
   212  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), nil
   213  
   214  	case AtypIPv6:
   215  		b1 := make([]byte, 16+2)
   216  		b1[0] = b[1]
   217  		_, err = io.ReadFull(r, b1[1:])
   218  		if err != nil {
   219  			return conn.Addr{}, err
   220  		}
   221  		ip := netip.AddrFrom16(*(*[16]byte)(b1))
   222  		port := binary.BigEndian.Uint16(b1[16:])
   223  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), nil
   224  
   225  	default:
   226  		return conn.Addr{}, fmt.Errorf("invalid ATYP: %d", b[0])
   227  	}
   228  }
   229  
   230  var errDomain = errors.New("addr is a domain")
   231  
   232  // AddrPortFromSlice slices a SOCKS address from the beginning of b and returns the converted netip.AddrPort
   233  // and the length of the SOCKS address.
   234  func AddrPortFromSlice(b []byte) (netip.AddrPort, int, error) {
   235  	if len(b) < 1+4+2 {
   236  		return netip.AddrPort{}, 0, fmt.Errorf("addr length too short: %d", len(b))
   237  	}
   238  
   239  	switch b[0] {
   240  	case AtypIPv4:
   241  		ip := netip.AddrFrom4(*(*[4]byte)(b[1:]))
   242  		port := binary.BigEndian.Uint16(b[1+4:])
   243  		return netip.AddrPortFrom(ip, port), 1 + 4 + 2, nil
   244  
   245  	case AtypIPv6:
   246  		if len(b) < 1+16+2 {
   247  			return netip.AddrPort{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   248  		}
   249  		ip := netip.AddrFrom16(*(*[16]byte)(b[1:]))
   250  		port := binary.BigEndian.Uint16(b[1+16:])
   251  		return netip.AddrPortFrom(ip, port), 1 + 16 + 2, nil
   252  
   253  	case AtypDomainName:
   254  		return netip.AddrPort{}, 0, errDomain
   255  
   256  	default:
   257  		return netip.AddrPort{}, 0, fmt.Errorf("invalid ATYP: %d", b[0])
   258  	}
   259  }
   260  
   261  // ConnAddrFromSlice slices a SOCKS address from the beginning of b and returns the converted conn.Addr
   262  // and the length of the SOCKS address.
   263  func ConnAddrFromSlice(b []byte) (conn.Addr, int, error) {
   264  	if len(b) < 2 {
   265  		return conn.Addr{}, 0, fmt.Errorf("addr length too short: %d", len(b))
   266  	}
   267  
   268  	switch b[0] {
   269  	case AtypDomainName:
   270  		if len(b) < 1+1+int(b[1])+2 {
   271  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   272  		}
   273  		domain := string(b[2 : 2+int(b[1])])
   274  		port := binary.BigEndian.Uint16(b[2+int(b[1]):])
   275  		addr, err := conn.AddrFromDomainPort(domain, port)
   276  		return addr, 2 + int(b[1]) + 2, err
   277  
   278  	case AtypIPv4:
   279  		if len(b) < 1+4+2 {
   280  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   281  		}
   282  		ip := netip.AddrFrom4(*(*[4]byte)(b[1:]))
   283  		port := binary.BigEndian.Uint16(b[1+4:])
   284  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 4 + 2, nil
   285  
   286  	case AtypIPv6:
   287  		if len(b) < 1+16+2 {
   288  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   289  		}
   290  		ip := netip.AddrFrom16(*(*[16]byte)(b[1:]))
   291  		port := binary.BigEndian.Uint16(b[1+16:])
   292  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 16 + 2, nil
   293  
   294  	default:
   295  		return conn.Addr{}, 0, fmt.Errorf("invalid ATYP: %d", b[0])
   296  	}
   297  }
   298  
   299  // ConnAddrFromSliceWithDomainCache is like [ConnAddrFromSlice] but uses a domain cache to minimize string allocations.
   300  // The returned string is the updated domain cache.
   301  func ConnAddrFromSliceWithDomainCache(b []byte, cachedDomain string) (conn.Addr, int, string, error) {
   302  	if len(b) < 2 {
   303  		return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length too short: %d", len(b))
   304  	}
   305  
   306  	switch b[0] {
   307  	case AtypDomainName:
   308  		if len(b) < 1+1+int(b[1])+2 {
   309  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   310  		}
   311  		domain := b[2 : 2+int(b[1])]
   312  		if cachedDomain != string(domain) { // Hopefully the compiler will optimize the string allocation away.
   313  			cachedDomain = string(domain)
   314  		}
   315  		port := binary.BigEndian.Uint16(b[2+int(b[1]):])
   316  		addr, err := conn.AddrFromDomainPort(cachedDomain, port)
   317  		return addr, 2 + int(b[1]) + 2, cachedDomain, err
   318  
   319  	case AtypIPv4:
   320  		if len(b) < 1+4+2 {
   321  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   322  		}
   323  		ip := netip.AddrFrom4(*(*[4]byte)(b[1 : 1+4]))
   324  		port := binary.BigEndian.Uint16(b[1+4:])
   325  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 4 + 2, cachedDomain, nil
   326  
   327  	case AtypIPv6:
   328  		if len(b) < 1+16+2 {
   329  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   330  		}
   331  		ip := netip.AddrFrom16(*(*[16]byte)(b[1 : 1+16]))
   332  		port := binary.BigEndian.Uint16(b[1+16:])
   333  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 16 + 2, cachedDomain, nil
   334  
   335  	default:
   336  		return conn.Addr{}, 0, cachedDomain, fmt.Errorf("invalid ATYP: %d", b[0])
   337  	}
   338  }