github.com/database64128/shadowsocks-go@v1.7.0/conn/addr.go (about)

     1  package conn
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/netip"
     8  	"strconv"
     9  	"unsafe"
    10  )
    11  
// addressFamily identifies which variant of the Addr "enum" is stored.
type addressFamily byte

const (
	// addressFamilyNone is the zero value: the Addr was never initialized.
	addressFamilyNone addressFamily = iota
	// addressFamilyNetip means the Addr holds a netip.Addr in its original layout.
	addressFamilyNetip
	// addressFamilyDomain means the Addr holds a domain name (length + data pointer).
	addressFamilyDomain
)
    19  
// netipAddrHeader mirrors the in-memory layout of netip.Addr (a 128-bit address
// split into two words, followed by one pointer-sized field). Values are copied
// between the two types with unsafe pointer conversions, so the fields and their
// order must not change.
//
// NOTE(review): this depends on netip.Addr's unexported layout; re-verify it when
// upgrading Go.
type netipAddrHeader struct {
	hi uint64
	lo uint64
	z  unsafe.Pointer
}
    25  
// stringHeader mirrors the runtime representation of a string: a pointer to the
// bytes followed by the byte length. It is used to split a string into its parts
// and reassemble it without copying.
type stringHeader struct {
	data unsafe.Pointer
	len  int
}
    30  
// Addr is the base address type used throughout the package.
//
// An Addr is a port number combined with either an IP address or a domain name.
//
// For space efficiency, the IP address and the domain string share the same space.
// The [netip.Addr] is stored in its original layout.
// The domain string's data pointer is stored in the ip.z field.
// Its length is stored at the beginning of the structure.
// This is essentially an unsafe "enum".
//
// Layout is load-bearing: addr must be the first non-zero-size field and port must
// immediately follow it, because a *Addr is reinterpreted as a *netip.Addr and as a
// *netip.AddrPort.
type Addr struct {
	// Zero-size array of a non-comparable type: it makes Addr unusable with ==,
	// presumably to steer callers to Equals, which understands the domain encoding.
	_ [0]func()
	// IP variant: the netip.Addr bytes. Domain variant: the string length in the
	// first word and the string data pointer in z.
	addr netipAddrHeader
	port uint16
	// af selects which interpretation of addr is valid.
	af addressFamily
}
    46  
// ip reinterprets the leading bytes of a as a netip.Addr.
// The result is only meaningful when a.af == addressFamilyNetip.
func (a Addr) ip() netip.Addr {
	return *(*netip.Addr)(unsafe.Pointer(&a))
}

// ipPort reinterprets the leading bytes of a as a netip.AddrPort. This works
// because Addr's addr and port fields line up with netip.AddrPort's fields.
// The result is only meaningful when a.af == addressFamilyNetip.
func (a Addr) ipPort() netip.AddrPort {
	return *(*netip.AddrPort)(unsafe.Pointer(&a))
}

// domain rebuilds the stored domain string without copying. The data pointer is
// taken from a.addr.z and the length from the first word of a, which is where
// AddrFromDomainPort writes them.
// The result is only meaningful when a.af == addressFamilyDomain.
func (a Addr) domain() (domain string) {
	dp := (*stringHeader)(unsafe.Pointer(&domain))
	dp.data = a.addr.z
	dp.len = *(*int)(unsafe.Pointer(&a))
	return
}
    61  
    62  // Equals returns whether two addresses are the same.
    63  func (a Addr) Equals(b Addr) bool {
    64  	if a.af != b.af || a.port != b.port {
    65  		return false
    66  	}
    67  
    68  	switch a.af {
    69  	case addressFamilyNetip:
    70  		return a.addr == b.addr
    71  	case addressFamilyDomain:
    72  		return a.domain() == b.domain()
    73  	default:
    74  		return true
    75  	}
    76  }
    77  
    78  // IsValid returns whether the address is an initialized address (not a zero value).
    79  func (a Addr) IsValid() bool {
    80  	return a.af != addressFamilyNone
    81  }
    82  
    83  // IsIP returns whether the address is an IP address.
    84  func (a Addr) IsIP() bool {
    85  	return a.af == addressFamilyNetip
    86  }
    87  
    88  // IsDomain returns whether the address is a domain name.
    89  func (a Addr) IsDomain() bool {
    90  	return a.af == addressFamilyDomain
    91  }
    92  
    93  // IP returns the IP address.
    94  //
    95  // If the address is a domain name or zero value, this method panics.
    96  func (a Addr) IP() netip.Addr {
    97  	if a.af != addressFamilyNetip {
    98  		panic("IP() called on non-IP address")
    99  	}
   100  	return a.ip()
   101  }
   102  
   103  // Domain returns the domain name.
   104  //
   105  // If the address is an IP address or zero value, this method panics.
   106  func (a Addr) Domain() string {
   107  	if a.af != addressFamilyDomain {
   108  		panic("Domain() called on non-domain address")
   109  	}
   110  	return a.domain()
   111  }
   112  
   113  // Port returns the port number.
   114  func (a Addr) Port() uint16 {
   115  	return a.port
   116  }
   117  
   118  // IPPort returns a netip.AddrPort.
   119  //
   120  // If the address is a domain name or zero value, this method panics.
   121  func (a Addr) IPPort() netip.AddrPort {
   122  	if a.af != addressFamilyNetip {
   123  		panic("IPPort() called on non-IP address")
   124  	}
   125  	return a.ipPort()
   126  }
   127  
   128  // ResolveIP resolves a domain name string into an IP address.
   129  //
   130  // This function always returns the first IP address returned by the resolver,
   131  // because the resolver takes care of sorting the IP addresses by address family
   132  // availability and preference.
   133  //
   134  // String representations of IP addresses are not supported.
   135  func ResolveIP(host string) (netip.Addr, error) {
   136  	ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host)
   137  	if err != nil {
   138  		return netip.Addr{}, err
   139  	}
   140  	return ips[0], nil
   141  }
   142  
   143  // ResolveIP returns the IP address itself or the resolved IP address of the domain name.
   144  //
   145  // If the address is zero value, this method panics.
   146  func (a Addr) ResolveIP() (netip.Addr, error) {
   147  	switch a.af {
   148  	case addressFamilyNetip:
   149  		return a.ip(), nil
   150  	case addressFamilyDomain:
   151  		return ResolveIP(a.domain())
   152  	default:
   153  		panic("ResolveIP() called on zero value")
   154  	}
   155  }
   156  
   157  // ResolveIPPort returns the IP address itself or the resolved IP address of the domain name
   158  // and the port number as a [netip.AddrPort].
   159  //
   160  // If the address is zero value, this method panics.
   161  func (a Addr) ResolveIPPort() (netip.AddrPort, error) {
   162  	switch a.af {
   163  	case addressFamilyNetip:
   164  		return a.ipPort(), nil
   165  	case addressFamilyDomain:
   166  		ip, err := ResolveIP(a.domain())
   167  		if err != nil {
   168  			return netip.AddrPort{}, err
   169  		}
   170  		return netip.AddrPortFrom(ip, a.port), nil
   171  	default:
   172  		panic("ResolveIPPort() called on zero value")
   173  	}
   174  }
   175  
   176  // Host returns the string representation of the IP address or the domain name.
   177  //
   178  // If the address is zero value, this method panics.
   179  func (a Addr) Host() string {
   180  	switch a.af {
   181  	case addressFamilyNetip:
   182  		return a.ip().String()
   183  	case addressFamilyDomain:
   184  		return a.domain()
   185  	default:
   186  		panic("Host() called on zero value")
   187  	}
   188  }
   189  
   190  // String returns the string representation of the address.
   191  //
   192  // If the address is zero value, an empty string is returned.
   193  func (a Addr) String() string {
   194  	switch a.af {
   195  	case addressFamilyNetip:
   196  		return a.ipPort().String()
   197  	case addressFamilyDomain:
   198  		return fmt.Sprintf("%s:%d", a.domain(), a.port)
   199  	default:
   200  		return ""
   201  	}
   202  }
   203  
   204  // AppendTo appends the string representation of the address to the provided buffer.
   205  //
   206  // If the address is zero value, nothing is appended.
   207  func (a Addr) AppendTo(b []byte) []byte {
   208  	switch a.af {
   209  	case addressFamilyNetip:
   210  		return a.ipPort().AppendTo(b)
   211  	case addressFamilyDomain:
   212  		return fmt.Appendf(b, "%s:%d", a.domain(), a.port)
   213  	default:
   214  		return b
   215  	}
   216  }
   217  
   218  // MarshalText implements the encoding.TextMarshaler MarshalText method.
   219  func (a Addr) MarshalText() ([]byte, error) {
   220  	switch a.af {
   221  	case addressFamilyNetip:
   222  		return a.ipPort().MarshalText()
   223  	case addressFamilyDomain:
   224  		return fmt.Appendf(nil, "%s:%d", a.domain(), a.port), nil
   225  	default:
   226  		return nil, nil
   227  	}
   228  }
   229  
   230  // UnmarshalText implements the encoding.TextUnmarshaler UnmarshalText method.
   231  func (a *Addr) UnmarshalText(text []byte) error {
   232  	addr, err := ParseAddr(text)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	*a = addr
   237  	return nil
   238  }
   239  
   240  // AddrFromIPPort returns an Addr from the provided netip.AddrPort.
   241  func AddrFromIPPort(addrPort netip.AddrPort) (addr Addr) {
   242  	*(*netip.AddrPort)(unsafe.Pointer(&addr)) = addrPort
   243  	addr.af = addressFamilyNetip
   244  	return
   245  }
   246  
   247  // AddrFromDomainPort returns an Addr from the provided domain name and port number.
   248  func AddrFromDomainPort(domain string, port uint16) (Addr, error) {
   249  	if len(domain) == 0 || len(domain) > 255 {
   250  		return Addr{}, fmt.Errorf("length of domain %s out of range [1, 255]", domain)
   251  	}
   252  	dp := (*stringHeader)(unsafe.Pointer(&domain))
   253  	addr := Addr{addr: netipAddrHeader{z: dp.data}, port: port, af: addressFamilyDomain}
   254  	*(*int)(unsafe.Pointer(&addr)) = dp.len
   255  	return addr, nil
   256  }
   257  
   258  // MustAddrFromDomainPort calls [AddrFromDomainPort] and panics on error.
   259  func MustAddrFromDomainPort(domain string, port uint16) Addr {
   260  	addr, err := AddrFromDomainPort(domain, port)
   261  	if err != nil {
   262  		panic(err)
   263  	}
   264  	return addr
   265  }
   266  
   267  // AddrFromHostPort returns an Addr from the provided host string and port number.
   268  // The host string may be a string representation of an IP address or a domain name.
   269  func AddrFromHostPort(host string, port uint16) (Addr, error) {
   270  	if host == "" {
   271  		host = "::"
   272  	}
   273  
   274  	if ip, err := netip.ParseAddr(host); err == nil {
   275  		return Addr{addr: *(*netipAddrHeader)(unsafe.Pointer(&ip)), port: port, af: addressFamilyNetip}, nil
   276  	}
   277  
   278  	return AddrFromDomainPort(host, port)
   279  }
   280  
   281  // ParseAddr parses the provided string representation of an address
   282  // and returns the parsed address or an error.
   283  func ParseAddr[T ~[]byte | ~string](s T) (Addr, error) {
   284  	host, portString, err := net.SplitHostPort(*(*string)(unsafe.Pointer(&s)))
   285  	if err != nil {
   286  		return Addr{}, err
   287  	}
   288  
   289  	portNumber, err := strconv.ParseUint(portString, 10, 16)
   290  	if err != nil {
   291  		return Addr{}, fmt.Errorf("failed to parse port string: %w", err)
   292  	}
   293  	port := uint16(portNumber)
   294  
   295  	return AddrFromHostPort(host, port)
   296  }
   297  
// addrPortHeader mirrors the in-memory layout of netip.AddrPort: the 16 address
// bytes, the pointer-sized field that follows them, then the port.
//
// NOTE(review): this depends on netip.AddrPort's unexported layout; re-verify it
// when upgrading Go.
type addrPortHeader struct {
	ip   [16]byte
	z    unsafe.Pointer
	port uint16
}
   303  
   304  // AddrPortMappedEqual returns whether the two addresses point to the same endpoint.
   305  // An IPv4 address and an IPv4-mapped IPv6 address pointing to the same endpoint are considered equal.
   306  // For example, 1.1.1.1:53 and [::ffff:1.1.1.1]:53 are considered equal.
   307  func AddrPortMappedEqual(l, r netip.AddrPort) bool {
   308  	lp := (*addrPortHeader)(unsafe.Pointer(&l))
   309  	rp := (*addrPortHeader)(unsafe.Pointer(&r))
   310  	return lp.ip == rp.ip && lp.port == rp.port
   311  }