github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/netapi/address.go (about)

     1  package netapi
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"math/rand/v2"
     9  	"net"
    10  	"net/netip"
    11  	"strconv"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/log"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
    16  	"golang.org/x/exp/constraints"
    17  )
    18  
    19  func PaseNetwork(s string) statistic.Type { return statistic.Type(statistic.Type_value[s]) }
    20  
    21  func ParseAddress(network statistic.Type, addr string) (ad Address, _ error) {
    22  	hostname, portstr, err := net.SplitHostPort(addr)
    23  	if err != nil {
    24  		log.Error("split host port failed", "err", err, "addr", addr)
    25  		hostname = addr
    26  		portstr = "0"
    27  	}
    28  
    29  	port, err := ParsePortStr(portstr)
    30  	if err != nil {
    31  		return nil, fmt.Errorf("parse port failed: %w", err)
    32  	}
    33  
    34  	return ParseAddressPort(network, hostname, port), nil
    35  }
    36  
    37  func ParseDomainPort(network statistic.Type, addr string, port Port) (ad Address) {
    38  	return &DomainAddr{
    39  		hostname: addr,
    40  		port:     port,
    41  		addr:     newAddr(network),
    42  	}
    43  }
    44  
    45  func ParseAddressPort(network statistic.Type, addr string, port Port) (ad Address) {
    46  	if addr, err := netip.ParseAddr(addr); err == nil {
    47  		return &IPAddrPort{
    48  			addr:     newAddr(network),
    49  			addrPort: netip.AddrPortFrom(addr.Unmap(), port.Port()),
    50  		}
    51  	}
    52  
    53  	return ParseDomainPort(network, addr, port)
    54  }
    55  
    56  func ParseTCPAddress(ad *net.TCPAddr) Address {
    57  	return &IPAddr{
    58  		addr: newAddr(statistic.Type_tcp),
    59  		ip:   ad.IP,
    60  		port: ad.Port,
    61  		zone: ad.Zone,
    62  	}
    63  }
    64  
    65  func ParseUDPAddr(ad *net.UDPAddr) Address {
    66  	return &IPAddr{
    67  		addr: newAddr(statistic.Type_udp),
    68  		ip:   ad.IP,
    69  		port: ad.Port,
    70  		zone: ad.Zone,
    71  	}
    72  }
    73  
    74  func ParseIPAddr(ad *net.IPAddr) Address {
    75  	return &IPAddr{
    76  		addr: newAddr(statistic.Type_ip),
    77  		ip:   ad.IP,
    78  		port: 0,
    79  		zone: ad.Zone,
    80  	}
    81  }
    82  
    83  func ParseUnixAddr(ad *net.UnixAddr) Address {
    84  	return &DomainAddr{
    85  		hostname: ad.Name,
    86  		port:     EmptyPort,
    87  		addr:     newAddr(statistic.Type_unix),
    88  	}
    89  }
    90  
    91  func ParseIPAddrPort(net statistic.Type, ip net.IP, port int) Address {
    92  	return &IPAddr{
    93  		addr: newAddr(net),
    94  		ip:   ip,
    95  		port: port,
    96  	}
    97  }
    98  
    99  func ParseAddrPort(net statistic.Type, addrPort netip.AddrPort) Address {
   100  	return &IPAddrPort{
   101  		addrPort: addrPort,
   102  		addr:     newAddr(net),
   103  	}
   104  }
   105  
   106  func ParseSysAddr(ad net.Addr) (Address, error) {
   107  	switch ad := ad.(type) {
   108  	case Address:
   109  		return ad, nil
   110  	case *net.TCPAddr:
   111  		return ParseTCPAddress(ad), nil
   112  	case *net.UDPAddr:
   113  		return ParseUDPAddr(ad), nil
   114  	case *net.IPAddr:
   115  		return ParseIPAddr(ad), nil
   116  	case *net.UnixAddr:
   117  		return ParseUnixAddr(ad), nil
   118  	}
   119  	return ParseAddress(PaseNetwork(ad.Network()), ad.String())
   120  }
   121  
// addr holds the metadata shared by the Address implementations in this file;
// DomainAddr, IPAddr and IPAddrPort each embed a *addr.
type addr struct {
	// network is the value reported by Network and NetworkType.
	network    statistic.Type
	// src is the AddressSrc tag stored by SetSrc; SetResolver keeps an
	// already-set resolver while src equals AddressSrcDNS.
	src        AddressSrc
	// preferIPv6 and preferIPv4 are set through PreferIPv6 / PreferIPv4 and
	// are read by DomainAddr.lookupIP to restrict the first lookup attempt.
	preferIPv6 bool
	preferIPv4 bool
	// resolver is what Resolver returns; when nil, Resolver returns Bootstrap.
	resolver   Resolver
}
   129  
   130  func newAddr(net statistic.Type) *addr {
   131  	return &addr{
   132  		network: net,
   133  	}
   134  }
   135  
   136  func (d *addr) SetSrc(src AddressSrc) {
   137  	d.src = src
   138  }
   139  
   140  func (d *addr) SetResolver(resolver Resolver) {
   141  	if resolver == nil {
   142  		return
   143  	}
   144  
   145  	if d.resolver != nil && d.src == AddressSrcDNS {
   146  		return
   147  	}
   148  
   149  	d.resolver = resolver
   150  }
   151  
   152  func (d *addr) Resolver() Resolver {
   153  	if d.resolver != nil {
   154  		return d.resolver
   155  	}
   156  
   157  	return Bootstrap
   158  }
   159  
   160  func (d *addr) Network() string             { return d.network.String() }
   161  func (d *addr) NetworkType() statistic.Type { return d.network }
   162  func (d *addr) PreferIPv6(b bool)           { d.preferIPv6 = b }
   163  func (d *addr) PreferIPv4(b bool)           { d.preferIPv4 = b }
   164  func (d *addr) overrideHostname(s string, port Port) Address {
   165  	if addr, err := netip.ParseAddr(s); err == nil {
   166  		return &IPAddrPort{
   167  			addr:     d,
   168  			addrPort: netip.AddrPortFrom(addr, port.Port()),
   169  		}
   170  	}
   171  
   172  	return &DomainAddr{
   173  		hostname: s,
   174  		addr:     d,
   175  		port:     port,
   176  	}
   177  }
   178  
// Compile-time check that *DomainAddr implements Address.
var _ Address = (*DomainAddr)(nil)

// DomainAddr is an Address that stores a hostname string and a port. Its IPs
// are looked up through the resolver returned by Resolver (see lookupIP).
// ParseUnixAddr also uses it to hold a unix socket name.
type DomainAddr struct {
	*addr
	// port is the value returned by Port.
	port     Port
	// hostname is the value returned by Hostname and passed to the resolver.
	hostname string
}
   186  
   187  func (d *DomainAddr) String() string   { return net.JoinHostPort(d.Hostname(), d.Port().String()) }
   188  func (d *DomainAddr) Hostname() string { return d.hostname }
   189  func (d *DomainAddr) IPs(ctx context.Context) ([]net.IP, error) {
   190  	return d.lookupIP(ctx)
   191  }
   192  
   193  func (d *DomainAddr) IP(ctx context.Context) (net.IP, error) {
   194  	ip, err := d.lookupIP(ctx)
   195  	if err != nil {
   196  		return nil, fmt.Errorf("resolve address %s failed: %w", d.hostname, err)
   197  	}
   198  
   199  	return ip[rand.IntN(len(ip))], nil
   200  }
   201  
   202  func (d *DomainAddr) AddrPort(ctx context.Context) Result[netip.AddrPort] {
   203  	ip, err := d.IP(ctx)
   204  	if err != nil {
   205  		return NewErrResult[netip.AddrPort](fmt.Errorf("resolve address %s failed: %w", d.hostname, err))
   206  	}
   207  
   208  	addr, _ := netip.AddrFromSlice(ip)
   209  
   210  	return NewResult(netip.AddrPortFrom(addr, d.port.Port()))
   211  }
   212  
   213  func (d *DomainAddr) Port() Port   { return d.port }
   214  func (d *DomainAddr) Type() Type   { return FQDN }
   215  func (d *DomainAddr) IsFqdn() bool { return true }
   216  func (d *DomainAddr) lookupIP(ctx context.Context) ([]net.IP, error) {
   217  	if d.preferIPv6 || d.preferIPv4 {
   218  		ips, err := d.Resolver().LookupIP(ctx, d.hostname, func(li *LookupIPOption) {
   219  			if d.preferIPv6 {
   220  				li.AAAA = true
   221  				li.A = false
   222  			}
   223  			if d.preferIPv4 {
   224  				li.A = true
   225  				li.AAAA = false
   226  			}
   227  		})
   228  		if err == nil {
   229  			return ips, nil
   230  		} else {
   231  			log.Warn("resolve ipv6 failed, fallback to ipv4", slog.String("domain", d.hostname), slog.Any("err", err))
   232  		}
   233  	}
   234  
   235  	ips, err := d.Resolver().LookupIP(ctx, d.hostname)
   236  	if err != nil {
   237  		return nil, fmt.Errorf("resolve address failed: %w", err)
   238  	}
   239  
   240  	return ips, nil
   241  }
   242  
   243  func (d *DomainAddr) UDPAddr(ctx context.Context) Result[*net.UDPAddr] {
   244  	ip, err := d.lookupIP(ctx)
   245  	if err != nil {
   246  		return NewErrResult[*net.UDPAddr](fmt.Errorf("resolve udp address %s failed: %w", d.hostname, err))
   247  	}
   248  
   249  	return NewResult(&net.UDPAddr{IP: ip[rand.IntN(len(ip))], Port: int(d.port.Port())})
   250  }
   251  
   252  func (d *DomainAddr) TCPAddr(ctx context.Context) Result[*net.TCPAddr] {
   253  	ip, err := d.lookupIP(ctx)
   254  	if err != nil {
   255  		return NewErrResult[*net.TCPAddr](fmt.Errorf("resolve tcp address %s failed: %w", d.hostname, err))
   256  	}
   257  
   258  	return NewResult(&net.TCPAddr{IP: ip[rand.IntN(len(ip))], Port: int(d.port.Port())})
   259  }
   260  
   261  func (d *DomainAddr) OverrideHostname(s string) Address {
   262  	return d.addr.overrideHostname(s, d.port)
   263  }
   264  
   265  func (d *DomainAddr) OverridePort(p Port) Address {
   266  	return &DomainAddr{
   267  		hostname: d.Hostname(),
   268  		addr:     d.addr,
   269  		port:     p,
   270  	}
   271  }
   272  
// IPAddr is an Address backed by a net.IP, an int port and a zone string. It
// has the same three fields as net.TCPAddr and net.UDPAddr, from which
// ParseTCPAddress and ParseUDPAddr build it.
type IPAddr struct {
	// ip is the IP returned by IP and IPs.
	ip   net.IP
	// port is the port number; Port converts it with ParsePort.
	port int
	// zone is copied into the Zone field of the addresses built by UDPAddr
	// and TCPAddr.
	zone string
	*addr
}
   279  
   280  func (d *IPAddr) String() string   { return net.JoinHostPort(d.ip.String(), strconv.Itoa(d.port)) }
   281  func (d *IPAddr) Hostname() string { return d.ip.String() }
   282  func (d *IPAddr) AddrPort(context.Context) Result[netip.AddrPort] {
   283  	addr, _ := netip.AddrFromSlice(d.ip)
   284  	return NewResult(netip.AddrPortFrom(addr, uint16(d.port)))
   285  }
   286  func (d *IPAddr) IPs(context.Context) ([]net.IP, error) {
   287  	return []net.IP{d.ip}, nil
   288  }
   289  func (d *IPAddr) IP(context.Context) (net.IP, error) { return d.ip, nil }
   290  func (d *IPAddr) Port() Port                         { return ParsePort(d.port) }
   291  func (d *IPAddr) Type() Type                         { return IP }
   292  func (d *IPAddr) IsFqdn() bool                       { return false }
   293  func (d *IPAddr) UDPAddr(context.Context) Result[*net.UDPAddr] {
   294  	return NewResult(&net.UDPAddr{IP: d.ip, Port: d.port, Zone: d.zone})
   295  }
   296  func (d *IPAddr) TCPAddr(context.Context) Result[*net.TCPAddr] {
   297  	return NewResult(&net.TCPAddr{IP: d.ip, Port: d.port, Zone: d.zone})
   298  }
   299  func (d *IPAddr) OverrideHostname(s string) Address { return d.overrideHostname(s, d.Port()) }
   300  func (d *IPAddr) OverridePort(p Port) Address {
   301  	return &IPAddr{
   302  		addr: d.addr,
   303  		ip:   d.ip,
   304  		port: int(p.Port()),
   305  	}
   306  }
   307  
// Compile-time check that *IPAddrPort implements Address.
var _ Address = (*IPAddrPort)(nil)

// IPAddrPort is an Address backed by a netip.AddrPort, so no resolution is
// needed to obtain its IP.
type IPAddrPort struct {
	// addrPort holds the IP (with any zone) and the port.
	addrPort netip.AddrPort
	*addr
}
   314  
   315  func (d *IPAddrPort) String() string                                  { return d.addrPort.String() }
   316  func (d *IPAddrPort) Hostname() string                                { return d.addrPort.Addr().String() }
   317  func (d *IPAddrPort) AddrPort(context.Context) Result[netip.AddrPort] { return NewResult(d.addrPort) }
   318  func (d *IPAddrPort) IPs(context.Context) ([]net.IP, error) {
   319  	return []net.IP{d.addrPort.Addr().AsSlice()}, nil
   320  }
   321  func (d *IPAddrPort) IP(context.Context) (net.IP, error) { return d.addrPort.Addr().AsSlice(), nil }
   322  func (d *IPAddrPort) Port() Port                         { return ParsePort(d.addrPort.Port()) }
   323  func (d *IPAddrPort) Type() Type                         { return IP }
   324  func (d *IPAddrPort) IsFqdn() bool                       { return false }
   325  func (d *IPAddrPort) UDPAddr(context.Context) Result[*net.UDPAddr] {
   326  	return NewResult(&net.UDPAddr{IP: d.addrPort.Addr().AsSlice(), Port: int(d.addrPort.Port()), Zone: d.addrPort.Addr().Zone()})
   327  }
   328  func (d *IPAddrPort) TCPAddr(context.Context) Result[*net.TCPAddr] {
   329  	return NewResult(&net.TCPAddr{IP: d.addrPort.Addr().AsSlice(), Port: int(d.addrPort.Port()), Zone: d.addrPort.Addr().Zone()})
   330  }
   331  func (d *IPAddrPort) OverrideHostname(s string) Address { return d.overrideHostname(s, d.Port()) }
   332  func (d *IPAddrPort) OverridePort(p Port) Address {
   333  	return &IPAddrPort{
   334  		addr:     d.addr,
   335  		addrPort: netip.AddrPortFrom(d.addrPort.Addr(), p.Port()),
   336  	}
   337  }
   338  
   339  var EmptyAddr Address = &emptyAddr{}
   340  
   341  type emptyAddr struct{}
   342  
   343  func (d emptyAddr) String() string                        { return "" }
   344  func (d emptyAddr) Hostname() string                      { return "" }
   345  func (d emptyAddr) IPs(context.Context) ([]net.IP, error) { return nil, errors.New("empty") }
   346  func (d emptyAddr) IP(context.Context) (net.IP, error)    { return nil, errors.New("empty") }
   347  func (d emptyAddr) AddrPort(context.Context) Result[netip.AddrPort] {
   348  	return NewErrResult[netip.AddrPort](errors.New("empty"))
   349  }
   350  func (d emptyAddr) Port() Port                  { return EmptyPort }
   351  func (d emptyAddr) Network() string             { return "" }
   352  func (d emptyAddr) NetworkType() statistic.Type { return 0 }
   353  func (d emptyAddr) Type() Type                  { return EMPTY }
   354  func (d emptyAddr) IsFqdn() bool                { return false }
   355  func (d emptyAddr) SetSrc(AddressSrc)           {}
   356  func (d emptyAddr) SetResolver(Resolver)        {}
   357  func (d emptyAddr) PreferIPv6(bool)             {}
   358  func (d emptyAddr) PreferIPv4(bool)             {}
   359  func (d emptyAddr) UDPAddr(context.Context) Result[*net.UDPAddr] {
   360  	return NewErrResult[*net.UDPAddr](errors.New("empty"))
   361  }
   362  func (d emptyAddr) TCPAddr(context.Context) Result[*net.TCPAddr] {
   363  	return NewErrResult[*net.TCPAddr](errors.New("empty"))
   364  }
   365  func (d emptyAddr) IPHost(context.Context) (string, error) { return "", errors.New("empty") }
   366  func (d emptyAddr) WithValue(any, any)                     {}
   367  func (d emptyAddr) Value(any) (any, bool)                  { return nil, false }
   368  func (d emptyAddr) RangeValue(func(any, any) bool)         {}
   369  func (d emptyAddr) OverrideHostname(string) Address        { return d }
   370  func (d emptyAddr) OverridePort(Port) Address              { return d }
   371  
   372  type PortUint16 uint16
   373  
   374  func (p PortUint16) Port() uint16   { return uint16(p) }
   375  func (p PortUint16) String() string { return strconv.FormatUint(uint64(p), 10) }
   376  
// EmptyPort is the Port with value 0. In this file it is used by ParseUnixAddr
// and returned by emptyAddr.Port.
var EmptyPort Port = PortUint16(0)
   378  
   379  func ParsePort[T constraints.Integer](p T) Port { return PortUint16(p) }
   380  
   381  func ParsePortStr(p string) (Port, error) {
   382  	pt, err := strconv.ParseUint(p, 10, 16)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  
   387  	return PortUint16(pt), nil
   388  }
   389  
   390  func DialHappyEyeballs(ctx context.Context, addr Address) (net.Conn, error) {
   391  	if !addr.IsFqdn() {
   392  		return dialer.DialContext(ctx, "tcp", addr.String())
   393  	}
   394  
   395  	ips, err := addr.IPs(ctx)
   396  	if err != nil {
   397  		return nil, err
   398  	}
   399  
   400  	tcpAddress := make([]*net.TCPAddr, 0, len(ips))
   401  	for _, i := range rand.Perm(len(ips)) {
   402  		tcpAddress = append(tcpAddress, &net.TCPAddr{IP: ips[i], Port: int(addr.Port().Port())})
   403  	}
   404  
   405  	return dialer.DialHappyEyeballs(ctx, tcpAddress)
   406  }