github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/socks5/addr.go (about)

     1  package socks5
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/netip"
     9  	"unsafe"
    10  
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  	"github.com/database64128/shadowsocks-go/slicehelper"
    13  )
    14  
    15  // SOCKS address types as defined in RFC 1928 section 5.
    16  const (
    17  	AtypIPv4       = 1
    18  	AtypDomainName = 3
    19  	AtypIPv6       = 4
    20  )
    21  
    22  const (
    23  	// IPv4AddrLen is the size of an IPv4 SOCKS address in bytes.
    24  	IPv4AddrLen = 1 + 4 + 2
    25  
    26  	// IPv6AddrLen is the size of an IPv6 SOCKS address in bytes.
    27  	IPv6AddrLen = 1 + 16 + 2
    28  
    29  	// MaxAddrLen is the maximum size of a SOCKS address in bytes.
    30  	MaxAddrLen = 1 + 1 + 255 + 2
    31  )
    32  
    33  var (
    34  	// IPv4UnspecifiedAddr represents 0.0.0.0:0.
    35  	IPv4UnspecifiedAddr = [IPv4AddrLen]byte{AtypIPv4}
    36  
    37  	// IPv6UnspecifiedAddr represents [::]:0.
    38  	IPv6UnspecifiedAddr = [IPv6AddrLen]byte{AtypIPv6}
    39  )
    40  
    41  // AppendAddrFromAddrPort appends the netip.AddrPort to the buffer in the SOCKS address format.
    42  //
    43  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
    44  func AppendAddrFromAddrPort(b []byte, addrPort netip.AddrPort) []byte {
    45  	var ret, out []byte
    46  	ip := addrPort.Addr()
    47  	switch {
    48  	case ip.Is4() || ip.Is4In6():
    49  		ret, out = slicehelper.Extend(b, 1+4+2)
    50  		out[0] = AtypIPv4
    51  		*(*[4]byte)(out[1:]) = ip.As4()
    52  	default:
    53  		ret, out = slicehelper.Extend(b, 1+16+2)
    54  		out[0] = AtypIPv6
    55  		*(*[16]byte)(out[1:]) = ip.As16()
    56  	}
    57  	binary.BigEndian.PutUint16(out[len(out)-2:], addrPort.Port())
    58  	return ret
    59  }
    60  
    61  // WriteAddrFromAddrPort writes the netip.AddrPort to the buffer in the SOCKS address format
    62  // and returns the number of bytes written.
    63  //
    64  // If the address is an IPv4-mapped IPv6 address, it is converted to an IPv4 address.
    65  //
    66  // This function does not check whether b has sufficient space for the address.
    67  // The caller may call [LengthOfAddrFromAddrPort] to get the required length.
    68  func WriteAddrFromAddrPort(b []byte, addrPort netip.AddrPort) (n int) {
    69  	ip := addrPort.Addr()
    70  	switch {
    71  	case ip.Is4() || ip.Is4In6():
    72  		b[0] = AtypIPv4
    73  		*(*[4]byte)(b[1:]) = ip.As4()
    74  		n = 1 + 4 + 2
    75  	default:
    76  		b[0] = AtypIPv6
    77  		*(*[16]byte)(b[1:]) = ip.As16()
    78  		n = 1 + 16 + 2
    79  	}
    80  	binary.BigEndian.PutUint16(b[n-2:], addrPort.Port())
    81  	return
    82  }
    83  
    84  // LengthOfAddrFromAddrPort returns the length of a SOCKS address converted from the netip.AddrPort.
    85  func LengthOfAddrFromAddrPort(addrPort netip.AddrPort) int {
    86  	if ip := addrPort.Addr(); ip.Is4() || ip.Is4In6() {
    87  		return 1 + 4 + 2
    88  	}
    89  	return 1 + 16 + 2
    90  }
    91  
    92  // AppendAddrFromConnAddr appends the address to the buffer in the SOCKS address format.
    93  //
    94  // - Zero value address is treated as 0.0.0.0:0.
    95  // - IPv4-mapped IPv6 address is converted to the equivalent IPv4 address.
    96  func AppendAddrFromConnAddr(b []byte, addr conn.Addr) []byte {
    97  	if !addr.IsValid() {
    98  		return AppendAddrFromAddrPort(b, netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
    99  	}
   100  
   101  	if addr.IsIP() {
   102  		return AppendAddrFromAddrPort(b, addr.IPPort())
   103  	}
   104  
   105  	domain := addr.Domain()
   106  	ret, out := slicehelper.Extend(b, 1+1+len(domain)+2)
   107  	out[0] = AtypDomainName
   108  	out[1] = byte(len(domain))
   109  	copy(out[2:], domain)
   110  
   111  	port := addr.Port()
   112  	binary.BigEndian.PutUint16(out[1+1+len(domain):], port)
   113  
   114  	return ret
   115  }
   116  
   117  // WriteAddrFromConnAddr writes the address to the buffer in the SOCKS address format
   118  // and returns the number of bytes written.
   119  //
   120  // - Zero value address is treated as 0.0.0.0:0.
   121  // - IPv4-mapped IPv6 address is converted to the equivalent IPv4 address.
   122  //
   123  // This function does not check whether b has sufficient space for the address.
   124  // The caller may call [LengthOfAddrFromConnAddr] to get the required length.
   125  func WriteAddrFromConnAddr(b []byte, addr conn.Addr) int {
   126  	if !addr.IsValid() {
   127  		return WriteAddrFromAddrPort(b, netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
   128  	}
   129  
   130  	if addr.IsIP() {
   131  		return WriteAddrFromAddrPort(b, addr.IPPort())
   132  	}
   133  
   134  	domain := addr.Domain()
   135  	b[0] = AtypDomainName
   136  	b[1] = byte(len(domain))
   137  	copy(b[2:], domain)
   138  
   139  	port := addr.Port()
   140  	binary.BigEndian.PutUint16(b[1+1+len(domain):], port)
   141  
   142  	return 1 + 1 + len(domain) + 2
   143  }
   144  
   145  // LengthOfAddrFromConnAddr returns the length of a SOCKS address converted from the conn.Addr.
   146  //
   147  // - Zero value address is treated as 0.0.0.0:0.
   148  // - IPv4-mapped IPv6 address is treated as the equivalent IPv4 address.
   149  func LengthOfAddrFromConnAddr(addr conn.Addr) int {
   150  	if !addr.IsValid() {
   151  		return 1 + 4 + 2
   152  	}
   153  	if addr.IsIP() {
   154  		return LengthOfAddrFromAddrPort(addr.IPPort())
   155  	}
   156  	domain := addr.Domain()
   157  	return 1 + 1 + len(domain) + 2
   158  }
   159  
   160  // AppendFromReader reads just enough bytes from r to get a valid Addr
   161  // and appends it to the buffer.
   162  func AppendFromReader(b []byte, r io.Reader) ([]byte, error) {
   163  	ret, out := slicehelper.Extend(b, 2)
   164  
   165  	// Read ATYP and an extra byte.
   166  	_, err := io.ReadFull(r, out)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	var addrLen int
   172  
   173  	switch out[0] {
   174  	case AtypDomainName:
   175  		addrLen = 1 + 1 + int(out[1]) + 2
   176  	case AtypIPv4:
   177  		addrLen = 1 + 4 + 2
   178  	case AtypIPv6:
   179  		addrLen = 1 + 16 + 2
   180  	default:
   181  		return nil, fmt.Errorf("unknown atyp %d", out[0])
   182  	}
   183  
   184  	ret, out = slicehelper.Extend(ret, addrLen-2)
   185  	_, err = io.ReadFull(r, out)
   186  	return ret, err
   187  }
   188  
   189  // AddrFromReader allocates and reads a SOCKS address from an io.Reader.
   190  //
   191  // To avoid allocations, call AppendFromReader directly.
   192  func AddrFromReader(r io.Reader) ([]byte, error) {
   193  	b := make([]byte, 0, MaxAddrLen)
   194  	return AppendFromReader(b, r)
   195  }
   196  
   197  // ConnAddrFromReader reads a SOCKS address from r and returns the converted conn.Addr.
   198  func ConnAddrFromReader(r io.Reader) (conn.Addr, error) {
   199  	b := make([]byte, 2)
   200  
   201  	// Read ATYP and an extra byte.
   202  	_, err := io.ReadFull(r, b)
   203  	if err != nil {
   204  		return conn.Addr{}, err
   205  	}
   206  
   207  	switch b[0] {
   208  	case AtypDomainName:
   209  		b1 := make([]byte, int(b[1])+2)
   210  		_, err = io.ReadFull(r, b1)
   211  		if err != nil {
   212  			return conn.Addr{}, err
   213  		}
   214  		domain := unsafe.String(unsafe.SliceData(b1), b[1])
   215  		port := binary.BigEndian.Uint16(b1[b[1]:])
   216  		return conn.AddrFromDomainPort(domain, port)
   217  
   218  	case AtypIPv4:
   219  		b1 := make([]byte, 4+2)
   220  		b1[0] = b[1]
   221  		_, err = io.ReadFull(r, b1[1:])
   222  		if err != nil {
   223  			return conn.Addr{}, err
   224  		}
   225  		ip := netip.AddrFrom4(*(*[4]byte)(b1))
   226  		port := binary.BigEndian.Uint16(b1[4:])
   227  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), nil
   228  
   229  	case AtypIPv6:
   230  		b1 := make([]byte, 16+2)
   231  		b1[0] = b[1]
   232  		_, err = io.ReadFull(r, b1[1:])
   233  		if err != nil {
   234  			return conn.Addr{}, err
   235  		}
   236  		ip := netip.AddrFrom16(*(*[16]byte)(b1))
   237  		port := binary.BigEndian.Uint16(b1[16:])
   238  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), nil
   239  
   240  	default:
   241  		return conn.Addr{}, fmt.Errorf("invalid ATYP: %d", b[0])
   242  	}
   243  }
   244  
   245  var errDomain = errors.New("addr is a domain")
   246  
   247  // AddrPortFromSlice slices a SOCKS address from the beginning of b and returns the converted netip.AddrPort
   248  // and the length of the SOCKS address.
   249  func AddrPortFromSlice(b []byte) (netip.AddrPort, int, error) {
   250  	if len(b) < 1+4+2 {
   251  		return netip.AddrPort{}, 0, fmt.Errorf("addr length too short: %d", len(b))
   252  	}
   253  
   254  	switch b[0] {
   255  	case AtypIPv4:
   256  		ip := netip.AddrFrom4(*(*[4]byte)(b[1:]))
   257  		port := binary.BigEndian.Uint16(b[1+4:])
   258  		return netip.AddrPortFrom(ip, port), 1 + 4 + 2, nil
   259  
   260  	case AtypIPv6:
   261  		if len(b) < 1+16+2 {
   262  			return netip.AddrPort{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   263  		}
   264  		ip := netip.AddrFrom16(*(*[16]byte)(b[1:]))
   265  		port := binary.BigEndian.Uint16(b[1+16:])
   266  		return netip.AddrPortFrom(ip, port), 1 + 16 + 2, nil
   267  
   268  	case AtypDomainName:
   269  		return netip.AddrPort{}, 0, errDomain
   270  
   271  	default:
   272  		return netip.AddrPort{}, 0, fmt.Errorf("invalid ATYP: %d", b[0])
   273  	}
   274  }
   275  
   276  // ConnAddrFromSlice slices a SOCKS address from the beginning of b and returns the converted conn.Addr
   277  // and the length of the SOCKS address.
   278  func ConnAddrFromSlice(b []byte) (conn.Addr, int, error) {
   279  	if len(b) < 2 {
   280  		return conn.Addr{}, 0, fmt.Errorf("addr length too short: %d", len(b))
   281  	}
   282  
   283  	switch b[0] {
   284  	case AtypDomainName:
   285  		if len(b) < 1+1+int(b[1])+2 {
   286  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   287  		}
   288  		domain := string(b[2 : 2+int(b[1])])
   289  		port := binary.BigEndian.Uint16(b[2+int(b[1]):])
   290  		addr, err := conn.AddrFromDomainPort(domain, port)
   291  		return addr, 2 + int(b[1]) + 2, err
   292  
   293  	case AtypIPv4:
   294  		if len(b) < 1+4+2 {
   295  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   296  		}
   297  		ip := netip.AddrFrom4(*(*[4]byte)(b[1:]))
   298  		port := binary.BigEndian.Uint16(b[1+4:])
   299  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 4 + 2, nil
   300  
   301  	case AtypIPv6:
   302  		if len(b) < 1+16+2 {
   303  			return conn.Addr{}, 0, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   304  		}
   305  		ip := netip.AddrFrom16(*(*[16]byte)(b[1:]))
   306  		port := binary.BigEndian.Uint16(b[1+16:])
   307  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 16 + 2, nil
   308  
   309  	default:
   310  		return conn.Addr{}, 0, fmt.Errorf("invalid ATYP: %d", b[0])
   311  	}
   312  }
   313  
   314  // ConnAddrFromSliceWithDomainCache is like [ConnAddrFromSlice] but uses a domain cache to minimize string allocations.
   315  // The returned string is the updated domain cache.
   316  func ConnAddrFromSliceWithDomainCache(b []byte, cachedDomain string) (conn.Addr, int, string, error) {
   317  	if len(b) < 2 {
   318  		return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length too short: %d", len(b))
   319  	}
   320  
   321  	switch b[0] {
   322  	case AtypDomainName:
   323  		if len(b) < 1+1+int(b[1])+2 {
   324  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   325  		}
   326  		domain := b[2 : 2+int(b[1])]
   327  		if cachedDomain != string(domain) { // Hopefully the compiler will optimize the string allocation away.
   328  			cachedDomain = string(domain)
   329  		}
   330  		port := binary.BigEndian.Uint16(b[2+int(b[1]):])
   331  		addr, err := conn.AddrFromDomainPort(cachedDomain, port)
   332  		return addr, 2 + int(b[1]) + 2, cachedDomain, err
   333  
   334  	case AtypIPv4:
   335  		if len(b) < 1+4+2 {
   336  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   337  		}
   338  		ip := netip.AddrFrom4(*(*[4]byte)(b[1 : 1+4]))
   339  		port := binary.BigEndian.Uint16(b[1+4:])
   340  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 4 + 2, cachedDomain, nil
   341  
   342  	case AtypIPv6:
   343  		if len(b) < 1+16+2 {
   344  			return conn.Addr{}, 0, cachedDomain, fmt.Errorf("addr length %d is too short for ATYP %d", len(b), b[0])
   345  		}
   346  		ip := netip.AddrFrom16(*(*[16]byte)(b[1 : 1+16]))
   347  		port := binary.BigEndian.Uint16(b[1+16:])
   348  		return conn.AddrFromIPPort(netip.AddrPortFrom(ip, port)), 1 + 16 + 2, cachedDomain, nil
   349  
   350  	default:
   351  		return conn.Addr{}, 0, cachedDomain, fmt.Errorf("invalid ATYP: %d", b[0])
   352  	}
   353  }