github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/dhcp/server.go (about)

     1  package dhcp
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"net/url"
     8  	"os"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/inazumav/sing-box/adapter"
    14  	"github.com/inazumav/sing-box/common/dialer"
    15  	C "github.com/inazumav/sing-box/constant"
    16  	"github.com/inazumav/sing-box/option"
    17  	"github.com/sagernet/sing-dns"
    18  	"github.com/sagernet/sing-tun"
    19  	"github.com/sagernet/sing/common"
    20  	"github.com/sagernet/sing/common/buf"
    21  	"github.com/sagernet/sing/common/control"
    22  	E "github.com/sagernet/sing/common/exceptions"
    23  	"github.com/sagernet/sing/common/logger"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	"github.com/sagernet/sing/common/task"
    27  	"github.com/sagernet/sing/common/x/list"
    28  
    29  	"github.com/insomniacslk/dhcp/dhcpv4"
    30  	mDNS "github.com/miekg/dns"
    31  )
    32  
    33  func init() {
    34  	dns.RegisterTransport([]string{"dhcp"}, NewTransport)
    35  }
    36  
    37  type Transport struct {
    38  	name              string
    39  	ctx               context.Context
    40  	router            adapter.Router
    41  	logger            logger.Logger
    42  	interfaceName     string
    43  	autoInterface     bool
    44  	interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
    45  	transports        []dns.Transport
    46  	updateAccess      sync.Mutex
    47  	updatedAt         time.Time
    48  }
    49  
    50  func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) {
    51  	linkURL, err := url.Parse(link)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	if linkURL.Host == "" {
    56  		return nil, E.New("missing interface name for DHCP")
    57  	}
    58  	router := adapter.RouterFromContext(ctx)
    59  	if router == nil {
    60  		return nil, E.New("missing router in context")
    61  	}
    62  	transport := &Transport{
    63  		name:          name,
    64  		ctx:           ctx,
    65  		router:        router,
    66  		logger:        logger,
    67  		interfaceName: linkURL.Host,
    68  		autoInterface: linkURL.Host == "auto",
    69  	}
    70  	return transport, nil
    71  }
    72  
    73  func (t *Transport) Name() string {
    74  	return t.name
    75  }
    76  
    77  func (t *Transport) Start() error {
    78  	err := t.fetchServers()
    79  	if err != nil {
    80  		return err
    81  	}
    82  	if t.autoInterface {
    83  		t.interfaceCallback = t.router.InterfaceMonitor().RegisterCallback(t.interfaceUpdated)
    84  	}
    85  	return nil
    86  }
    87  
    88  func (t *Transport) Reset() {
    89  }
    90  
    91  func (t *Transport) Close() error {
    92  	if t.interfaceCallback != nil {
    93  		t.router.InterfaceMonitor().UnregisterCallback(t.interfaceCallback)
    94  	}
    95  	return nil
    96  }
    97  
    98  func (t *Transport) Raw() bool {
    99  	return true
   100  }
   101  
   102  func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
   103  	err := t.fetchServers()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	if len(t.transports) == 0 {
   109  		return nil, E.New("dhcp: empty DNS servers from response")
   110  	}
   111  
   112  	var response *mDNS.Msg
   113  	for _, transport := range t.transports {
   114  		response, err = transport.Exchange(ctx, message)
   115  		if err == nil {
   116  			return response, nil
   117  		}
   118  	}
   119  	return nil, err
   120  }
   121  
   122  func (t *Transport) fetchInterface() (*net.Interface, error) {
   123  	interfaceName := t.interfaceName
   124  	if t.autoInterface {
   125  		if t.router.InterfaceMonitor() == nil {
   126  			return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface")
   127  		}
   128  		interfaceName = t.router.InterfaceMonitor().DefaultInterfaceName(netip.Addr{})
   129  	}
   130  	if interfaceName == "" {
   131  		return nil, E.New("missing default interface")
   132  	}
   133  	return net.InterfaceByName(interfaceName)
   134  }
   135  
   136  func (t *Transport) fetchServers() error {
   137  	if time.Since(t.updatedAt) < C.DHCPTTL {
   138  		return nil
   139  	}
   140  	t.updateAccess.Lock()
   141  	defer t.updateAccess.Unlock()
   142  	if time.Since(t.updatedAt) < C.DHCPTTL {
   143  		return nil
   144  	}
   145  	return t.updateServers()
   146  }
   147  
   148  func (t *Transport) updateServers() error {
   149  	iface, err := t.fetchInterface()
   150  	if err != nil {
   151  		return E.Cause(err, "dhcp: prepare interface")
   152  	}
   153  
   154  	t.logger.Info("dhcp: query DNS servers on ", iface.Name)
   155  	fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout)
   156  	err = t.fetchServers0(fetchCtx, iface)
   157  	cancel()
   158  	if err != nil {
   159  		return err
   160  	} else if len(t.transports) == 0 {
   161  		return E.New("dhcp: empty DNS servers response")
   162  	} else {
   163  		t.updatedAt = time.Now()
   164  		return nil
   165  	}
   166  }
   167  
   168  func (t *Transport) interfaceUpdated(int) {
   169  	err := t.updateServers()
   170  	if err != nil {
   171  		t.logger.Error("update servers: ", err)
   172  	}
   173  }
   174  
   175  func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) error {
   176  	var listener net.ListenConfig
   177  	listener.Control = control.Append(listener.Control, control.BindToInterfaceFunc(t.router.InterfaceFinder(), func(network string, address string) (interfaceName string, interfaceIndex int) {
   178  		return iface.Name, iface.Index
   179  	}))
   180  	listener.Control = control.Append(listener.Control, control.ReuseAddr())
   181  	packetConn, err := listener.ListenPacket(t.ctx, "udp4", "0.0.0.0:68")
   182  	if err != nil {
   183  		return err
   184  	}
   185  	defer packetConn.Close()
   186  
   187  	discovery, err := dhcpv4.NewDiscovery(iface.HardwareAddr, dhcpv4.WithBroadcast(true), dhcpv4.WithRequestedOptions(dhcpv4.OptionDomainNameServer))
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	_, err = packetConn.WriteTo(discovery.ToBytes(), &net.UDPAddr{IP: net.IPv4bcast, Port: 67})
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	var group task.Group
   198  	group.Append0(func(ctx context.Context) error {
   199  		return t.fetchServersResponse(iface, packetConn, discovery.TransactionID)
   200  	})
   201  	group.Cleanup(func() {
   202  		packetConn.Close()
   203  	})
   204  	return group.Run(ctx)
   205  }
   206  
   207  func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.PacketConn, transactionID dhcpv4.TransactionID) error {
   208  	buffer := buf.NewSize(dhcpv4.MaxMessageSize)
   209  	defer buffer.Release()
   210  
   211  	for {
   212  		_, _, err := buffer.ReadPacketFrom(packetConn)
   213  		if err != nil {
   214  			return err
   215  		}
   216  
   217  		dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
   218  		if err != nil {
   219  			t.logger.Trace("dhcp: parse DHCP response: ", err)
   220  			return err
   221  		}
   222  
   223  		if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
   224  			t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
   225  			continue
   226  		}
   227  
   228  		if dhcpPacket.TransactionID != transactionID {
   229  			t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
   230  			continue
   231  		}
   232  
   233  		dns := dhcpPacket.DNS()
   234  		if len(dns) == 0 {
   235  			return nil
   236  		}
   237  
   238  		var addrs []netip.Addr
   239  		for _, ip := range dns {
   240  			addr, _ := netip.AddrFromSlice(ip)
   241  			addrs = append(addrs, addr.Unmap())
   242  		}
   243  		return t.recreateServers(iface, addrs)
   244  	}
   245  }
   246  
   247  func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error {
   248  	if len(serverAddrs) > 0 {
   249  		t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
   250  			return it.String()
   251  		}), ","), "]")
   252  	}
   253  
   254  	serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{
   255  		BindInterface:      iface.Name,
   256  		UDPFragmentDefault: true,
   257  	}))
   258  	var transports []dns.Transport
   259  	for _, serverAddr := range serverAddrs {
   260  		serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53})
   261  		if err != nil {
   262  			return err
   263  		}
   264  		transports = append(transports, serverTransport)
   265  	}
   266  	t.transports = transports
   267  	return nil
   268  }
   269  
   270  func (t *Transport) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
   271  	return nil, os.ErrInvalid
   272  }