github.com/icexin/eggos@v0.4.2-0.20220216025428-78b167e4f349/inet/dhcp/client.go (about)

     1  // Copyright 2016 The Netstack Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dhcp
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"fmt"
    12  	"math/rand"
    13  	"net"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/icexin/eggos/log"
    18  
    19  	"gvisor.dev/gvisor/pkg/tcpip"
    20  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    21  	nheader "gvisor.dev/gvisor/pkg/tcpip/header"
    22  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    23  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    24  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    25  	"gvisor.dev/gvisor/pkg/waiter"
    26  )
    27  
    28  // Client is a DHCP client.
    29  type Client struct {
    30  	stack    *stack.Stack
    31  	nicid    tcpip.NICID
    32  	linkAddr tcpip.LinkAddress
    33  
    34  	mu          sync.Mutex
    35  	addr        tcpip.Address
    36  	cfg         Config
    37  	lease       time.Duration
    38  	cancelRenew func()
    39  }
    40  
    41  // NewClient creates a DHCP client.
    42  //
    43  // TODO(crawshaw): add s.LinkAddr(nicid) to *stack.Stack.
    44  func NewClient(s *stack.Stack, nicid tcpip.NICID, linkAddr tcpip.LinkAddress) *Client {
    45  	return &Client{
    46  		stack:    s,
    47  		nicid:    nicid,
    48  		linkAddr: linkAddr,
    49  	}
    50  }
    51  
    52  // Start starts the DHCP client.
    53  // It will periodically search for an IP address using the Request method.
    54  func (c *Client) Start() {
    55  	go func() {
    56  		for {
    57  			ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
    58  			err := c.Request(ctx, "")
    59  			cancel()
    60  			if err == nil {
    61  				break
    62  			}
    63  		}
    64  	}()
    65  }
    66  
    67  // Address reports the IP address acquired by the DHCP client.
    68  func (c *Client) Address() tcpip.Address {
    69  	c.mu.Lock()
    70  	defer c.mu.Unlock()
    71  	return c.addr
    72  }
    73  
    74  // Config reports the DHCP configuration acquired with the IP address lease.
    75  func (c *Client) Config() Config {
    76  	c.mu.Lock()
    77  	defer c.mu.Unlock()
    78  	return c.cfg
    79  }
    80  
    81  // Shutdown relinquishes any lease and ends any outstanding renewal timers.
    82  func (c *Client) Shutdown() {
    83  	c.mu.Lock()
    84  	defer c.mu.Unlock()
    85  	if c.addr != "" {
    86  		c.stack.RemoveAddress(c.nicid, c.addr)
    87  	}
    88  	if c.cancelRenew != nil {
    89  		c.cancelRenew()
    90  	}
    91  }
    92  
    93  func e(err tcpip.Error) error {
    94  	return errors.New(err.String())
    95  }
    96  
    97  // Request executes a DHCP request session.
    98  //
    99  // On success, it adds a new address to this client's TCPIP stack.
   100  // If the server sets a lease limit a timer is set to automatically
   101  // renew it.
   102  func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) error {
   103  	tcperr := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, nheader.IPv4Any)
   104  	if tcperr != nil {
   105  		return e(tcperr)
   106  	}
   107  	defer c.stack.RemoveAddress(c.nicid, nheader.IPv4Any)
   108  
   109  	clientAddr := tcpip.FullAddress{
   110  		Addr: nheader.IPv4Broadcast,
   111  		Port: clientPort,
   112  		NIC:  c.nicid,
   113  	}
   114  
   115  	serverAddr := &net.UDPAddr{
   116  		IP:   net.IPv4(255, 255, 255, 255),
   117  		Port: serverPort,
   118  	}
   119  
   120  	conn, err := DialUDP(c.stack, &clientAddr, nil, ipv4.ProtocolNumber)
   121  	if err != nil {
   122  		return err
   123  	}
   124  	defer conn.Close()
   125  
   126  	var xid [4]byte
   127  	rand.Read(xid[:])
   128  
   129  	// DHCPDISCOVERY
   130  	options := options{
   131  		{optDHCPMsgType, []byte{byte(dhcpDISCOVER)}},
   132  		{optParamReq, []byte{
   133  			1,  // request subnet mask
   134  			3,  // request router
   135  			15, // domain name
   136  			6,  // domain name server
   137  		}},
   138  	}
   139  	if requestedAddr != "" {
   140  		options = append(options, option{optReqIPAddr, []byte(requestedAddr)})
   141  	}
   142  	h := make(header, headerBaseSize+options.len())
   143  	h.init()
   144  	h.setOp(opRequest)
   145  	copy(h.xidbytes(), xid[:])
   146  	h.setBroadcast()
   147  	copy(h.chaddr(), c.linkAddr)
   148  	h.setOptions(options)
   149  
   150  	_, err = conn.WriteTo(h, serverAddr)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	v := make([]byte, 1024)
   156  	// DHCPOFFER
   157  	for {
   158  		n, err := conn.Read(v)
   159  		if err != nil {
   160  			return err
   161  		}
   162  		h = header(v[:n])
   163  		if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
   164  			break
   165  		}
   166  	}
   167  	opts, err := h.options()
   168  	if err != nil {
   169  		return fmt.Errorf("dhcp offer: %v", err)
   170  	}
   171  	log.Infof("[dhcp] offer done")
   172  
   173  	var ack bool
   174  	var cfg Config
   175  
   176  	err = cfg.decode(opts)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	// DHCPREQUEST
   182  	addr := tcpip.Address(h.yiaddr())
   183  	if err := c.stack.AddAddress(c.nicid, ipv4.ProtocolNumber, addr); err != nil {
   184  		if _, ok := err.(*tcpip.ErrDuplicateAddress); ok {
   185  			return e(err)
   186  		}
   187  	}
   188  	defer func() {
   189  		if ack {
   190  			c.mu.Lock()
   191  			c.addr = addr
   192  			c.cfg = cfg
   193  			c.mu.Unlock()
   194  		} else {
   195  			c.stack.RemoveAddress(c.nicid, addr)
   196  		}
   197  	}()
   198  	h.setOp(opRequest)
   199  	for i, b := 0, h.yiaddr(); i < len(b); i++ {
   200  		b[i] = 0
   201  	}
   202  	h.setOptions([]option{
   203  		{optDHCPMsgType, []byte{byte(dhcpREQUEST)}},
   204  		{optReqIPAddr, []byte(addr)},
   205  		{optDHCPServer, []byte(cfg.ServerAddress)},
   206  	})
   207  	log.Infof("[dhcp] offer ip:%s server:%s", addr, cfg.ServerAddress)
   208  	_, err = conn.WriteTo(h, serverAddr)
   209  	if err != nil {
   210  		return err
   211  	}
   212  
   213  	// DHCPACK
   214  	for {
   215  		n, err := conn.Read(v)
   216  		if err != nil {
   217  			return err
   218  		}
   219  		h = header(v[:n])
   220  		if h.isValid() && h.op() == opReply && bytes.Equal(h.xidbytes(), xid[:]) {
   221  			break
   222  		}
   223  	}
   224  	opts, err = h.options()
   225  	if err != nil {
   226  		return fmt.Errorf("dhcp ack: %v", err)
   227  	}
   228  	if err := cfg.decode(opts); err != nil {
   229  		return fmt.Errorf("dhcp ack bad options: %v", err)
   230  	}
   231  	msgtype, err := opts.dhcpMsgType()
   232  	if err != nil {
   233  		return fmt.Errorf("dhcp ack: %v", err)
   234  	}
   235  	ack = msgtype == dhcpACK
   236  	if !ack {
   237  		return fmt.Errorf("dhcp: request not acknowledged")
   238  	}
   239  	log.Infof("[dhcp] lease:%s", cfg.LeaseLength)
   240  	if cfg.LeaseLength != 0 {
   241  		go c.renewAfter(cfg.LeaseLength)
   242  	}
   243  	return nil
   244  }
   245  
   246  func (c *Client) renewAfter(d time.Duration) {
   247  	c.mu.Lock()
   248  	defer c.mu.Unlock()
   249  	if c.cancelRenew != nil {
   250  		c.cancelRenew()
   251  	}
   252  	ctx, cancel := context.WithCancel(context.Background())
   253  	c.cancelRenew = cancel
   254  	go func() {
   255  		timer := time.NewTimer(d)
   256  		defer timer.Stop()
   257  		select {
   258  		case <-ctx.Done():
   259  		case <-timer.C:
   260  			if err := c.Request(ctx, c.addr); err != nil {
   261  				log.Errorf("address renewal failed: %v", err)
   262  				go c.renewAfter(1 * time.Minute)
   263  			}
   264  		}
   265  	}()
   266  }
   267  
   268  func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.UDPConn, error) {
   269  	var wq waiter.Queue
   270  	ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
   271  	if err != nil {
   272  		return nil, errors.New(err.String())
   273  	}
   274  	ep.SocketOptions().SetBroadcast(true)
   275  
   276  	if laddr != nil {
   277  		if err := ep.Bind(*laddr); err != nil {
   278  			ep.Close()
   279  			return nil, &net.OpError{
   280  				Op:   "bind",
   281  				Net:  "udp",
   282  				Addr: fullToUDPAddr(*laddr),
   283  				Err:  errors.New(err.String()),
   284  			}
   285  		}
   286  	}
   287  
   288  	c := gonet.NewUDPConn(s, &wq, ep)
   289  
   290  	if raddr != nil {
   291  		if err := ep.Connect(*raddr); err != nil {
   292  			ep.Close()
   293  			return nil, &net.OpError{
   294  				Op:   "connect",
   295  				Net:  "udp",
   296  				Addr: fullToUDPAddr(*raddr),
   297  				Err:  errors.New(err.String()),
   298  			}
   299  		}
   300  	}
   301  
   302  	return c, nil
   303  }
   304  
   305  func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
   306  	return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
   307  }