github.com/sagernet/sing-box@v1.2.7/transport/dhcp/server.go (about)

     1  package dhcp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"net/url"
     8  	"os"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/sagernet/sing-box/adapter"
    14  	"github.com/sagernet/sing-box/common/dialer"
    15  	C "github.com/sagernet/sing-box/constant"
    16  	"github.com/sagernet/sing-box/option"
    17  	"github.com/sagernet/sing-dns"
    18  	"github.com/sagernet/sing-tun"
    19  	"github.com/sagernet/sing/common"
    20  	"github.com/sagernet/sing/common/buf"
    21  	"github.com/sagernet/sing/common/control"
    22  	E "github.com/sagernet/sing/common/exceptions"
    23  	"github.com/sagernet/sing/common/logger"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	"github.com/sagernet/sing/common/task"
    27  	"github.com/sagernet/sing/common/x/list"
    28  
    29  	"github.com/insomniacslk/dhcp/dhcpv4"
    30  	mDNS "github.com/miekg/dns"
    31  )
    32  
    33  func init() {
    34  	dns.RegisterTransport([]string{"dhcp"}, NewTransport)
    35  }
    36  
    37  type Transport struct {
    38  	name              string
    39  	ctx               context.Context
    40  	router            adapter.Router
    41  	logger            logger.Logger
    42  	interfaceName     string
    43  	autoInterface     bool
    44  	interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
    45  	transports        []dns.Transport
    46  	updateAccess      sync.Mutex
    47  	updatedAt         time.Time
    48  }
    49  
    50  func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
    51  	linkURL, err := url.Parse(link)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	if linkURL.Host == "" {
    56  		return nil, E.New("missing interface name for DHCP")
    57  	}
    58  	router := adapter.RouterFromContext(ctx)
    59  	if router == nil {
    60  		return nil, E.New("missing router in context")
    61  	}
    62  	transport := &Transport{
    63  		name:          name,
    64  		ctx:           ctx,
    65  		router:        router,
    66  		logger:        logger,
    67  		interfaceName: linkURL.Host,
    68  		autoInterface: linkURL.Host == "auto",
    69  	}
    70  	return transport, nil
    71  }
    72  
    73  func (t *Transport) Name() string {
    74  	return t.name
    75  }
    76  
    77  func (t *Transport) Start() error {
    78  	err := t.fetchServers()
    79  	if err != nil {
    80  		return err
    81  	}
    82  	if t.autoInterface {
    83  		t.interfaceCallback = t.router.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
    84  	}
    85  	return nil
    86  }
    87  
    88  func (t *Transport) Close() error {
    89  	if t.interfaceCallback != nil {
    90  		t.router.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
    91  	}
    92  	return nil
    93  }
    94  
    95  func (t *Transport) Raw() bool {
    96  	return true
    97  }
    98  
    99  func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
   100  	err := t.fetchServers()
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	if len(t.transports) == 0 {
   106  		return nil, E.New("dhcp: empty DNS servers from response")
   107  	}
   108  
   109  	var response *mDNS.Msg
   110  	for _, transport := range t.transports {
   111  		response, err = transport.Exchange(ctx, message)
   112  		if err == nil {
   113  			return response, nil
   114  		}
   115  	}
   116  	return nil, err
   117  }
   118  
   119  func (t *Transport) fetchInterface() (*net.Interface, error) {
   120  	interfaceName := t.interfaceName
   121  	if t.autoInterface {
   122  		if t.router.InterfaceMonitor() == nil {
   123  			return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
   124  		}
   125  		interfaceName = t.router.InterfaceMonitor().DefaultInterfaceName(netip.Addr{})
   126  	}
   127  	if interfaceName == "" {
   128  		return nil, E.New("missing default interface")
   129  	}
   130  	return net.InterfaceByName(interfaceName)
   131  }
   132  
   133  func (t *Transport) fetchServers() error {
   134  	if time.Since(t.updatedAt) < C.DHCPTTL {
   135  		return nil
   136  	}
   137  	t.updateAccess.Lock()
   138  	defer t.updateAccess.Unlock()
   139  	if time.Since(t.updatedAt) < C.DHCPTTL {
   140  		return nil
   141  	}
   142  	return t.updateServers()
   143  }
   144  
   145  func (t *Transport) updateServers() error {
   146  	iface, err := t.fetchInterface()
   147  	if err != nil {
   148  		return E.Cause(err, "dhcp: prepare interface")
   149  	}
   150  
   151  	t.logger.Info("dhcp: query DNS servers on ", iface.Name)
   152  	fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout)
   153  	err = t.fetchServers0(fetchCtx, iface)
   154  	cancel()
   155  	if err != nil {
   156  		return err
   157  	} else if len(t.transports) == 0 {
   158  		return E.New("dhcp: empty DNS servers response")
   159  	} else {
   160  		t.updatedAt = time.Now()
   161  		return nil
   162  	}
   163  }
   164  
   165  func (t *Transport) interfaceUpdated(int) error {
   166  	return t.updateServers()
   167  }
   168  
   169  func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) error {
   170  	var listener net.ListenConfig
   171  	listener.Control = control.Append(listener.Control, control.BindToInterfaceFunc(t.router.InterfaceFinder(), func(network string, address string) (interfaceName string, interfaceIndex int) {
   172  		return iface.Name, iface.Index
   173  	}))
   174  	listener.Control = control.Append(listener.Control, control.ReuseAddr())
   175  	packetConn, err := listener.ListenPacket(t.ctx, "udp4", "0.0.0.0:68")
   176  	if err != nil {
   177  		return err
   178  	}
   179  	defer packetConn.Close()
   180  
   181  	discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer))
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	_, err = packetConn.WriteTo(discovery.ToBytes(), &net.UDPAddr{IP: net.IPv4bcast, Port: 67})
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	var group task.Group
   192  	group.Append0(func(ctx context.Context) error {
   193  		return t.fetchServersResponse(iface, packetConn, discovery.TransactionID)
   194  	})
   195  	group.Cleanup(func() {
   196  		packetConn.Close()
   197  	})
   198  	return group.Run(ctx)
   199  }
   200  
   201  func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error {
   202  	_buffer := buf.StackNewSize(dhcpv4.MaxMessageSize)
   203  	defer common.KeepAlive(_buffer)
   204  	buffer := common.Dup(_buffer)
   205  	defer buffer.Release()
   206  
   207  	for {
   208  		_, _, err := buffer.ReadPacketFrom(packetConn)
   209  		if err != nil {
   210  			return err
   211  		}
   212  
   213  		dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
   214  		if err != nil {
   215  			t.logger.Trace("dhcp: parse DHCP response: ", err)
   216  			return err
   217  		}
   218  
   219  		if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
   220  			t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
   221  			continue
   222  		}
   223  
   224  		if dhcpPacket.TransactionID != transactionID {
   225  			t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
   226  			continue
   227  		}
   228  
   229  		dns := dhcpPacket.DNS()
   230  		if len(dns) == 0 {
   231  			return nil
   232  		}
   233  
   234  		var addrs []netip.Addr
   235  		for _, ip := range dns {
   236  			addr, _ := netip.AddrFromSlice(ip)
   237  			addrs = append(addrs, addr.Unmap())
   238  		}
   239  		return t.recreateServers(iface, addrs)
   240  	}
   241  }
   242  
   243  func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error {
   244  	if len(serverAddrs) > 0 {
   245  		t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
   246  			return it.String()
   247  		}), ","), "]")
   248  	}
   249  
   250  	serverDialer := dialer.NewDefault(t.router, option.DialerOptions{
   251  		BindInterface:      iface.Name,
   252  		UDPFragmentDefault: true,
   253  	})
   254  	var transports []dns.Transport
   255  	for _, serverAddr := range serverAddrs {
   256  		serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53})
   257  		if err != nil {
   258  			return err
   259  		}
   260  		transports = append(transports, serverTransport)
   261  	}
   262  	t.transports = transports
   263  	return nil
   264  }
   265  
   266  func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
   267  	return nil, os.ErrInvalid
   268  }