github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/conn/addr.go (about)

     1  package conn
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  	"strconv"
     9  	"unsafe"
    10  )
    11  
    12  type addressFamily byte
    13  
    14  const (
    15  	addressFamilyNone addressFamily = iota
    16  	addressFamilyNetip
    17  	addressFamilyDomain
    18  )
    19  
// netipAddrHeader is assumed to mirror the memory layout of [netip.Addr]:
// the 128-bit address as two 64-bit halves followed by a pointer-sized z field.
// The package converts between this type and netip.Addr with unsafe.Pointer
// casts, so its field order and sizes must not change.
// NOTE(review): this relies on netip's unexported internals; re-verify when upgrading Go.
//
// When an Addr holds a domain name, the same space is reused:
// hi holds the string length, z points at the string data, and lo stays zero.
type netipAddrHeader struct {
	hi uint64
	lo uint64
	z  *byte
}
    25  
// Addr is the base address type used throughout the package.
//
// An Addr is a port number combined with either an IP address or a domain name.
//
// For space efficiency, the IP address and the domain string share the same space.
// The [netip.Addr] is stored in its original layout in the addr field, and addr
// followed by port is reinterpreted as a [netip.AddrPort] by unsafe casts.
// The domain string's data pointer is stored in the addr.z field.
// Its length is stored in the addr.hi field, at the beginning of the structure.
// This is essentially an unsafe "enum"; af says which variant is live.
type Addr struct {
	// Zero-size field of a non-comparable type. It makes Addr itself
	// non-comparable with ==, so callers must use Equals, which compares
	// domain names by content rather than by pointer.
	_ [0]func()
	// IP address in netip.Addr layout, or domain string (hi = length, z = data).
	addr netipAddrHeader
	// Port number. Must directly follow addr: ipPort and AddrFromIPPort treat
	// addr+port as a netip.AddrPort, so this field order must not change.
	port uint16
	// Selects how addr is interpreted; addressFamilyNone for the zero value.
	af addressFamily
}
    41  
    42  func (a Addr) ip() netip.Addr {
    43  	return *(*netip.Addr)(unsafe.Pointer(&a))
    44  }
    45  
    46  func (a Addr) ipPort() netip.AddrPort {
    47  	return *(*netip.AddrPort)(unsafe.Pointer(&a))
    48  }
    49  
    50  func (a Addr) domain() string {
    51  	return unsafe.String(a.addr.z, a.addr.hi)
    52  }
    53  
    54  // Equals returns whether two addresses are the same.
    55  func (a Addr) Equals(b Addr) bool {
    56  	if a.af != b.af || a.port != b.port {
    57  		return false
    58  	}
    59  
    60  	switch a.af {
    61  	case addressFamilyNetip:
    62  		return a.addr == b.addr
    63  	case addressFamilyDomain:
    64  		return a.domain() == b.domain()
    65  	default:
    66  		return true
    67  	}
    68  }
    69  
    70  // IsValid returns whether the address is an initialized address (not a zero value).
    71  func (a Addr) IsValid() bool {
    72  	return a.af != addressFamilyNone
    73  }
    74  
    75  // IsIP returns whether the address is an IP address.
    76  func (a Addr) IsIP() bool {
    77  	return a.af == addressFamilyNetip
    78  }
    79  
    80  // IsDomain returns whether the address is a domain name.
    81  func (a Addr) IsDomain() bool {
    82  	return a.af == addressFamilyDomain
    83  }
    84  
    85  // IP returns the IP address.
    86  //
    87  // If the address is a domain name or zero value, this method panics.
    88  func (a Addr) IP() netip.Addr {
    89  	if a.af != addressFamilyNetip {
    90  		panic("IP() called on non-IP address")
    91  	}
    92  	return a.ip()
    93  }
    94  
    95  // Domain returns the domain name.
    96  //
    97  // If the address is an IP address or zero value, this method panics.
    98  func (a Addr) Domain() string {
    99  	if a.af != addressFamilyDomain {
   100  		panic("Domain() called on non-domain address")
   101  	}
   102  	return a.domain()
   103  }
   104  
   105  // Port returns the port number.
   106  func (a Addr) Port() uint16 {
   107  	return a.port
   108  }
   109  
   110  // IPPort returns a netip.AddrPort.
   111  //
   112  // If the address is a domain name or zero value, this method panics.
   113  func (a Addr) IPPort() netip.AddrPort {
   114  	if a.af != addressFamilyNetip {
   115  		panic("IPPort() called on non-IP address")
   116  	}
   117  	return a.ipPort()
   118  }
   119  
   120  // ResolveIP resolves a domain name string into an IP address.
   121  //
   122  // The network must be one of "ip", "ip4" or "ip6".
   123  // String representations of IP addresses are not supported.
   124  //
   125  // This function always returns the first IP address returned by the resolver,
   126  // because the resolver takes care of sorting the IP addresses by address family
   127  // availability and preference.
   128  func ResolveIP(ctx context.Context, network, host string) (netip.Addr, error) {
   129  	ips, err := net.DefaultResolver.LookupNetIP(ctx, network, host)
   130  	if err != nil {
   131  		return netip.Addr{}, err
   132  	}
   133  	return ips[0], nil
   134  }
   135  
   136  // ResolveIP returns the IP address itself or the resolved IP address of the domain name.
   137  //
   138  // The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6".
   139  //
   140  // If the address is zero value, this method panics.
   141  func (a Addr) ResolveIP(ctx context.Context, network string) (netip.Addr, error) {
   142  	switch a.af {
   143  	case addressFamilyNetip:
   144  		return a.ip(), nil
   145  	case addressFamilyDomain:
   146  		return ResolveIP(ctx, network, a.domain())
   147  	default:
   148  		panic("ResolveIP() called on zero value")
   149  	}
   150  }
   151  
   152  // ResolveIPPort returns the IP address itself or the resolved IP address of the domain name
   153  // and the port number as a [netip.AddrPort].
   154  //
   155  // The network is only used for domain name resolution and must be one of "ip", "ip4" or "ip6".
   156  //
   157  // If the address is zero value, this method panics.
   158  func (a Addr) ResolveIPPort(ctx context.Context, network string) (netip.AddrPort, error) {
   159  	switch a.af {
   160  	case addressFamilyNetip:
   161  		return a.ipPort(), nil
   162  	case addressFamilyDomain:
   163  		ip, err := ResolveIP(ctx, network, a.domain())
   164  		if err != nil {
   165  			return netip.AddrPort{}, err
   166  		}
   167  		return netip.AddrPortFrom(ip, a.port), nil
   168  	default:
   169  		panic("ResolveIPPort() called on zero value")
   170  	}
   171  }
   172  
   173  // Host returns the string representation of the IP address or the domain name.
   174  //
   175  // If the address is zero value, this method panics.
   176  func (a Addr) Host() string {
   177  	switch a.af {
   178  	case addressFamilyNetip:
   179  		return a.ip().String()
   180  	case addressFamilyDomain:
   181  		return a.domain()
   182  	default:
   183  		panic("Host() called on zero value")
   184  	}
   185  }
   186  
   187  // String returns the string representation of the address.
   188  //
   189  // If the address is zero value, an empty string is returned.
   190  func (a Addr) String() string {
   191  	switch a.af {
   192  	case addressFamilyNetip:
   193  		return a.ipPort().String()
   194  	case addressFamilyDomain:
   195  		return fmt.Sprintf("%s:%d", a.domain(), a.port)
   196  	default:
   197  		return ""
   198  	}
   199  }
   200  
   201  // AppendTo appends the string representation of the address to the provided buffer.
   202  //
   203  // If the address is zero value, nothing is appended.
   204  func (a Addr) AppendTo(b []byte) []byte {
   205  	switch a.af {
   206  	case addressFamilyNetip:
   207  		return a.ipPort().AppendTo(b)
   208  	case addressFamilyDomain:
   209  		return fmt.Appendf(b, "%s:%d", a.domain(), a.port)
   210  	default:
   211  		return b
   212  	}
   213  }
   214  
   215  // MarshalText implements the encoding.TextMarshaler MarshalText method.
   216  func (a Addr) MarshalText() ([]byte, error) {
   217  	switch a.af {
   218  	case addressFamilyNetip:
   219  		return a.ipPort().MarshalText()
   220  	case addressFamilyDomain:
   221  		return fmt.Appendf(nil, "%s:%d", a.domain(), a.port), nil
   222  	default:
   223  		return nil, nil
   224  	}
   225  }
   226  
   227  // UnmarshalText implements the encoding.TextUnmarshaler UnmarshalText method.
   228  func (a *Addr) UnmarshalText(text []byte) error {
   229  	addr, err := ParseAddr(text)
   230  	if err != nil {
   231  		return err
   232  	}
   233  	*a = addr
   234  	return nil
   235  }
   236  
   237  // AddrFromIPPort returns an Addr from the provided netip.AddrPort.
   238  func AddrFromIPPort(addrPort netip.AddrPort) (addr Addr) {
   239  	*(*netip.AddrPort)(unsafe.Pointer(&addr)) = addrPort
   240  	addr.af = addressFamilyNetip
   241  	return
   242  }
   243  
   244  // AddrFromDomainPort returns an Addr from the provided domain name and port number.
   245  func AddrFromDomainPort(domain string, port uint16) (Addr, error) {
   246  	if len(domain) == 0 || len(domain) > 255 {
   247  		return Addr{}, fmt.Errorf("length of domain %s out of range [1, 255]", domain)
   248  	}
   249  	return Addr{
   250  		addr: netipAddrHeader{
   251  			hi: uint64(len(domain)),
   252  			z:  unsafe.StringData(domain),
   253  		},
   254  		port: port,
   255  		af:   addressFamilyDomain,
   256  	}, nil
   257  }
   258  
   259  // MustAddrFromDomainPort calls [AddrFromDomainPort] and panics on error.
   260  func MustAddrFromDomainPort(domain string, port uint16) Addr {
   261  	addr, err := AddrFromDomainPort(domain, port)
   262  	if err != nil {
   263  		panic(err)
   264  	}
   265  	return addr
   266  }
   267  
   268  // AddrFromHostPort returns an Addr from the provided host string and port number.
   269  // The host string may be a string representation of an IP address or a domain name.
   270  func AddrFromHostPort(host string, port uint16) (Addr, error) {
   271  	if host == "" {
   272  		host = "::"
   273  	}
   274  
   275  	if ip, err := netip.ParseAddr(host); err == nil {
   276  		return Addr{addr: *(*netipAddrHeader)(unsafe.Pointer(&ip)), port: port, af: addressFamilyNetip}, nil
   277  	}
   278  
   279  	return AddrFromDomainPort(host, port)
   280  }
   281  
   282  // ParseAddr parses the provided string representation of an address
   283  // and returns the parsed address or an error.
   284  func ParseAddr[T ~[]byte | ~string](s T) (Addr, error) {
   285  	host, portString, err := net.SplitHostPort(*(*string)(unsafe.Pointer(&s)))
   286  	if err != nil {
   287  		return Addr{}, err
   288  	}
   289  
   290  	portNumber, err := strconv.ParseUint(portString, 10, 16)
   291  	if err != nil {
   292  		return Addr{}, fmt.Errorf("failed to parse port string: %w", err)
   293  	}
   294  	port := uint16(portNumber)
   295  
   296  	return AddrFromHostPort(host, port)
   297  }
   298  
// addrPortHeader is assumed to mirror the memory layout of [netip.AddrPort]:
// the 16 address bytes, the pointer-sized z field of netip.Addr, then the port.
// It is used to view a netip.AddrPort through unsafe.Pointer.
// NOTE(review): this depends on netip's unexported internals; re-verify when upgrading Go.
type addrPortHeader struct {
	ip   [16]byte
	z    unsafe.Pointer
	port uint16
}
   304  
   305  // AddrPortMappedEqual returns whether the two addresses point to the same endpoint.
   306  // An IPv4 address and an IPv4-mapped IPv6 address pointing to the same endpoint are considered equal.
   307  // For example, 1.1.1.1:53 and [::ffff:1.1.1.1]:53 are considered equal.
   308  func AddrPortMappedEqual(l, r netip.AddrPort) bool {
   309  	lp := (*addrPortHeader)(unsafe.Pointer(&l))
   310  	rp := (*addrPortHeader)(unsafe.Pointer(&r))
   311  	return lp.ip == rp.ip && lp.port == rp.port
   312  }