github.com/sagernet/sing-box@v1.9.0-rc.20/transport/dhcp/server.go (about)

     1  package dhcp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"net/url"
     8  	"os"
     9  	"runtime"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/sagernet/sing-box/adapter"
    15  	"github.com/sagernet/sing-box/common/dialer"
    16  	C "github.com/sagernet/sing-box/constant"
    17  	"github.com/sagernet/sing-box/option"
    18  	"github.com/sagernet/sing-dns"
    19  	"github.com/sagernet/sing-tun"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/buf"
    22  	"github.com/sagernet/sing/common/control"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	"github.com/sagernet/sing/common/task"
    25  	"github.com/sagernet/sing/common/x/list"
    26  
    27  	"github.com/insomniacslk/dhcp/dhcpv4"
    28  	mDNS "github.com/miekg/dns"
    29  )
    30  
    31  func init() {
    32  	dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
    33  		return NewTransport(options)
    34  	})
    35  }
    36  
    37  type Transport struct {
    38  	options           dns.TransportOptions
    39  	router            adapter.Router
    40  	interfaceName     string
    41  	autoInterface     bool
    42  	interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
    43  	transports        []dns.Transport
    44  	updateAccess      sync.Mutex
    45  	updatedAt         time.Time
    46  }
    47  
    48  func NewTransport(options dns.TransportOptions) (*Transport, error) {
    49  	linkURL, err := url.Parse(options.Address)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	if linkURL.Host == "" {
    54  		return nil, E.New("missing interface name for DHCP")
    55  	}
    56  	router := adapter.RouterFromContext(options.Context)
    57  	if router == nil {
    58  		return nil, E.New("missing router in context")
    59  	}
    60  	transport := &Transport{
    61  		options:       options,
    62  		router:        router,
    63  		interfaceName: linkURL.Host,
    64  		autoInterface: linkURL.Host == "auto",
    65  	}
    66  	return transport, nil
    67  }
    68  
    69  func (t *Transport) Name() string {
    70  	return t.options.Name
    71  }
    72  
    73  func (t *Transport) Start() error {
    74  	err := t.fetchServers()
    75  	if err != nil {
    76  		return err
    77  	}
    78  	if t.autoInterface {
    79  		t.interfaceCallback = t.router.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
    80  	}
    81  	return nil
    82  }
    83  
    84  func (t *Transport) Reset() {
    85  	for _, transport := range t.transports {
    86  		transport.Reset()
    87  	}
    88  }
    89  
    90  func (t *Transport) Close() error {
    91  	for _, transport := range t.transports {
    92  		transport.Close()
    93  	}
    94  	if t.interfaceCallback != nil {
    95  		t.router.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
    96  	}
    97  	return nil
    98  }
    99  
   100  func (t *Transport) Raw() bool {
   101  	return true
   102  }
   103  
   104  func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
   105  	err := t.fetchServers()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	if len(t.transports) == 0 {
   111  		return nil, E.New("dhcp: empty DNS servers from response")
   112  	}
   113  
   114  	var response *mDNS.Msg
   115  	for _, transport := range t.transports {
   116  		response, err = transport.Exchange(ctx, message)
   117  		if err == nil {
   118  			return response, nil
   119  		}
   120  	}
   121  	return nil, err
   122  }
   123  
   124  func (t *Transport) fetchInterface() (*net.Interface, error) {
   125  	interfaceName := t.interfaceName
   126  	if t.autoInterface {
   127  		if t.router.InterfaceMonitor() == nil {
   128  			return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
   129  		}
   130  		interfaceName = t.router.InterfaceMonitor().DefaultInterfaceName(netip.Addr{})
   131  	}
   132  	if interfaceName == "" {
   133  		return nil, E.New("missing default interface")
   134  	}
   135  	return net.InterfaceByName(interfaceName)
   136  }
   137  
   138  func (t *Transport) fetchServers() error {
   139  	if time.Since(t.updatedAt) < C.DHCPTTL {
   140  		return nil
   141  	}
   142  	t.updateAccess.Lock()
   143  	defer t.updateAccess.Unlock()
   144  	if time.Since(t.updatedAt) < C.DHCPTTL {
   145  		return nil
   146  	}
   147  	return t.updateServers()
   148  }
   149  
   150  func (t *Transport) updateServers() error {
   151  	iface, err := t.fetchInterface()
   152  	if err != nil {
   153  		return E.Cause(err, "dhcp: prepare interface")
   154  	}
   155  
   156  	t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name)
   157  	fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout)
   158  	err = t.fetchServers0(fetchCtx, iface)
   159  	cancel()
   160  	if err != nil {
   161  		return err
   162  	} else if len(t.transports) == 0 {
   163  		return E.New("dhcp: empty DNS servers response")
   164  	} else {
   165  		t.updatedAt = time.Now()
   166  		return nil
   167  	}
   168  }
   169  
   170  func (t *Transport) interfaceUpdated(int) {
   171  	err := t.updateServers()
   172  	if err != nil {
   173  		t.options.Logger.Error("update servers: ", err)
   174  	}
   175  }
   176  
   177  func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) error {
   178  	var listener net.ListenConfig
   179  	listener.Control = control.Append(listener.Control, control.BindToInterface(t.router.InterfaceFinder(), iface.Name, iface.Index))
   180  	listener.Control = control.Append(listener.Control, control.ReuseAddr())
   181  	listenAddr := "0.0.0.0:68"
   182  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
   183  		listenAddr = "255.255.255.255:68"
   184  	}
   185  	packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	defer packetConn.Close()
   190  
   191  	discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer))
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	_, err = packetConn.WriteTo(discovery.ToBytes(), &net.UDPAddr{IP: net.IPv4bcast, Port: 67})
   197  	if err != nil {
   198  		return err
   199  	}
   200  
   201  	var group task.Group
   202  	group.Append0(func(ctx context.Context) error {
   203  		return t.fetchServersResponse(iface, packetConn, discovery.TransactionID)
   204  	})
   205  	group.Cleanup(func() {
   206  		packetConn.Close()
   207  	})
   208  	return group.Run(ctx)
   209  }
   210  
   211  func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error {
   212  	buffer := buf.NewSize(dhcpv4.MaxMessageSize)
   213  	defer buffer.Release()
   214  
   215  	for {
   216  		_, _, err := buffer.ReadPacketFrom(packetConn)
   217  		if err != nil {
   218  			return err
   219  		}
   220  
   221  		dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
   222  		if err != nil {
   223  			t.options.Logger.Trace("dhcp: parse DHCP response: ", err)
   224  			return err
   225  		}
   226  
   227  		if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
   228  			t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
   229  			continue
   230  		}
   231  
   232  		if dhcpPacket.TransactionID != transactionID {
   233  			t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
   234  			continue
   235  		}
   236  
   237  		dns := dhcpPacket.DNS()
   238  		if len(dns) == 0 {
   239  			return nil
   240  		}
   241  
   242  		var addrs []netip.Addr
   243  		for _, ip := range dns {
   244  			addr, _ := netip.AddrFromSlice(ip)
   245  			addrs = append(addrs, addr.Unmap())
   246  		}
   247  		return t.recreateServers(iface, addrs)
   248  	}
   249  }
   250  
   251  func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error {
   252  	if len(serverAddrs) > 0 {
   253  		t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
   254  			return it.String()
   255  		}), ","), "]")
   256  	}
   257  	serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{
   258  		BindInterface:      iface.Name,
   259  		UDPFragmentDefault: true,
   260  	}))
   261  	var transports []dns.Transport
   262  	for _, serverAddr := range serverAddrs {
   263  		newOptions := t.options
   264  		newOptions.Address = serverAddr.String()
   265  		newOptions.Dialer = serverDialer
   266  		serverTransport, err := dns.NewUDPTransport(newOptions)
   267  		if err != nil {
   268  			return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr)
   269  		}
   270  		transports = append(transports, serverTransport)
   271  	}
   272  	for _, transport := range t.transports {
   273  		transport.Close()
   274  	}
   275  	t.transports = transports
   276  	return nil
   277  }
   278  
   279  func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
   280  	return nil, os.ErrInvalid
   281  }