code.flowtr.dev/mirrors/u-root@v1.0.0/pkg/dhcp6client/client_test.go (about)

     1  // Copyright 2017-2018 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dhcp6client
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"net"
    12  	"syscall"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/mdlayher/dhcp6"
    17  )
    18  
    19  type timeoutErr struct{}
    20  
    21  func (timeoutErr) Error() string {
    22  	return "i/o timeout"
    23  }
    24  
    25  func (timeoutErr) Timeout() bool {
    26  	return true
    27  }
    28  
    29  type udpPacket struct {
    30  	source  *net.UDPAddr
    31  	dest    *net.UDPAddr
    32  	payload []byte
    33  }
    34  
    35  // mockUDPConn implements net.PacketConn.
    36  type mockUDPConn struct {
    37  	// This'll just be nil for all the methods we don't implement.
    38  
    39  	// in is the queue of packets ReadFromUDP reads from.
    40  	//
    41  	// ReadFromUDP returns io.EOF when in is closed.
    42  	in chan udpPacket
    43  
    44  	inTimer *time.Timer
    45  
    46  	// out is the queue of packets WriteTo writes to.
    47  	out chan<- udpPacket
    48  
    49  	closed bool
    50  }
    51  
    52  func newMockUDPConn(in chan udpPacket, out chan<- udpPacket) *mockUDPConn {
    53  	return &mockUDPConn{
    54  		in:  in,
    55  		out: out,
    56  	}
    57  }
    58  
    59  // SetReadDeadline implements PacketConn.SetReadDeadline.
    60  func (m *mockUDPConn) SetReadDeadline(t time.Time) error {
    61  	duration := t.Sub(time.Now())
    62  	if duration < 0 {
    63  		return fmt.Errorf("deadline must be in the future")
    64  	}
    65  	m.inTimer = time.NewTimer(duration)
    66  	return nil
    67  }
    68  
    69  func (m *mockUDPConn) LocalAddr() net.Addr {
    70  	panic("unused")
    71  }
    72  
    73  func (m *mockUDPConn) SetWriteDeadline(t time.Time) error {
    74  	panic("unused")
    75  }
    76  
    77  func (m *mockUDPConn) SetDeadline(t time.Time) error {
    78  	panic("unused")
    79  }
    80  
    81  // Close implements PacketConn.Close.
    82  func (m *mockUDPConn) Close() error {
    83  	m.closed = true
    84  	close(m.out)
    85  	return nil
    86  }
    87  
    88  // ReadFrom is a mock for PacketConn.ReadFromUDP.
    89  func (m *mockUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
    90  	// Make sure we don't have data waiting.
    91  	select {
    92  	case p, ok := <-m.in:
    93  		if !ok {
    94  			// Connection was closed.
    95  			return 0, nil, nil
    96  		}
    97  		return copy(b, p.payload), p.source, nil
    98  	default:
    99  	}
   100  
   101  	select {
   102  	case p, ok := <-m.in:
   103  		if !ok {
   104  			return 0, nil, nil
   105  		}
   106  		return copy(b, p.payload), p.source, nil
   107  	case <-m.inTimer.C:
   108  		// This net.OpError will return true for Timeout().
   109  		return 0, nil, &net.OpError{Err: timeoutErr{}}
   110  	}
   111  }
   112  
   113  // WriteTo is a mock for PacketConn.WriteTo.
   114  func (m *mockUDPConn) WriteTo(b []byte, dest net.Addr) (int, error) {
   115  	if m.closed {
   116  		return 0, syscall.EBADF
   117  	}
   118  
   119  	m.out <- udpPacket{
   120  		dest:    dest.(*net.UDPAddr),
   121  		payload: b,
   122  	}
   123  	return len(b), nil
   124  }
   125  
   126  type server struct {
   127  	in  chan udpPacket
   128  	out chan udpPacket
   129  
   130  	received []*dhcp6.Packet
   131  
   132  	// Each received packet can have more than one response (in theory,
   133  	// from different servers sending different Advertise, for example).
   134  	responses [][]*dhcp6.Packet
   135  }
   136  
   137  func (s *server) serve(ctx context.Context) {
   138  	go func() {
   139  		select {
   140  		case udpPkt, ok := <-s.in:
   141  			if !ok {
   142  				break
   143  			}
   144  
   145  			// What did we get?
   146  			var pkt dhcp6.Packet
   147  			if err := (&pkt).UnmarshalBinary(udpPkt.payload); err != nil {
   148  				panic(fmt.Sprintf("invalid dhcp6 packet %q: %v", udpPkt.payload, err))
   149  			}
   150  			s.received = append(s.received, &pkt)
   151  
   152  			if len(s.responses) > 0 {
   153  				resps := s.responses[0]
   154  				// What should we send in response?
   155  				for _, resp := range resps {
   156  					bin, err := resp.MarshalBinary()
   157  					if err != nil {
   158  						panic(fmt.Sprintf("failed to serialize dhcp6 packet %v: %v", resp, err))
   159  					}
   160  					s.out <- udpPacket{
   161  						source:  udpPkt.dest,
   162  						payload: bin,
   163  					}
   164  				}
   165  				s.responses = s.responses[1:]
   166  			}
   167  
   168  		case <-ctx.Done():
   169  			break
   170  		}
   171  
   172  		// We're done sending stuff.
   173  		close(s.out)
   174  	}()
   175  
   176  }
   177  
   178  func ComparePacket(got *dhcp6.Packet, want *dhcp6.Packet) error {
   179  	aa, err := got.MarshalBinary()
   180  	if err != nil {
   181  		panic(err)
   182  	}
   183  	bb, err := want.MarshalBinary()
   184  	if err != nil {
   185  		panic(err)
   186  	}
   187  	if bytes.Compare(aa, bb) != 0 {
   188  		return fmt.Errorf("packet got %v, want %v", got, want)
   189  	}
   190  	return nil
   191  }
   192  
   193  func pktsExpected(got []*dhcp6.Packet, want []*dhcp6.Packet) error {
   194  	if len(got) != len(want) {
   195  		return fmt.Errorf("got %d packets, want %d packets", len(got), len(want))
   196  	}
   197  
   198  	for i := range got {
   199  		if err := ComparePacket(got[i], want[i]); err != nil {
   200  			return err
   201  		}
   202  	}
   203  	return nil
   204  }
   205  
   206  func serveAndClient(ctx context.Context, responses [][]*dhcp6.Packet) (*Client, *mockUDPConn) {
   207  	// These are the client's channels.
   208  	in := make(chan udpPacket, 100)
   209  	out := make(chan udpPacket, 100)
   210  
   211  	mockConn := &mockUDPConn{
   212  		in:  in,
   213  		out: out,
   214  	}
   215  
   216  	mc := &Client{
   217  		conn:    mockConn,
   218  		retry:   1,
   219  		timeout: time.Second,
   220  	}
   221  
   222  	// Of course, for the server they are reversed.
   223  	s := &server{
   224  		in:        out,
   225  		out:       in,
   226  		responses: responses,
   227  	}
   228  	go s.serve(ctx)
   229  
   230  	return mc, mockConn
   231  }
   232  
   233  func TestSimpleSendAndRead(t *testing.T) {
   234  	for _, tt := range []struct {
   235  		desc   string
   236  		send   *dhcp6.Packet
   237  		server []*dhcp6.Packet
   238  
   239  		// If want is nil, we assume server contains what is wanted.
   240  		want    []*dhcp6.Packet
   241  		wantErr error
   242  	}{
   243  		{
   244  			desc: "two response packets",
   245  			send: &dhcp6.Packet{
   246  				MessageType:   dhcp6.MessageTypeSolicit,
   247  				TransactionID: [3]byte{0x33, 0x33, 0x33},
   248  			},
   249  			server: []*dhcp6.Packet{
   250  				{
   251  					MessageType:   dhcp6.MessageTypeAdvertise,
   252  					TransactionID: [3]byte{0x33, 0x33, 0x33},
   253  				},
   254  				{
   255  					MessageType:   dhcp6.MessageTypeAdvertise,
   256  					TransactionID: [3]byte{0x33, 0x33, 0x33},
   257  				},
   258  			},
   259  		},
   260  		{
   261  			desc: "one response packet",
   262  			send: &dhcp6.Packet{
   263  				MessageType:   dhcp6.MessageTypeSolicit,
   264  				TransactionID: [3]byte{0x33, 0x33, 0x33},
   265  			},
   266  			server: []*dhcp6.Packet{
   267  				{
   268  					MessageType:   dhcp6.MessageTypeAdvertise,
   269  					TransactionID: [3]byte{0x33, 0x33, 0x33},
   270  				},
   271  			},
   272  		},
   273  		{
   274  			desc: "one response packet, one invalid XID",
   275  			send: &dhcp6.Packet{
   276  				MessageType:   dhcp6.MessageTypeSolicit,
   277  				TransactionID: [3]byte{0x33, 0x33, 0x33},
   278  			},
   279  			server: []*dhcp6.Packet{
   280  				{
   281  					MessageType:   dhcp6.MessageTypeAdvertise,
   282  					TransactionID: [3]byte{0x33, 0x33, 0x33},
   283  				},
   284  				{
   285  					MessageType:   dhcp6.MessageTypeAdvertise,
   286  					TransactionID: [3]byte{0x77, 0x77, 0x77},
   287  				},
   288  			},
   289  			want: []*dhcp6.Packet{
   290  				{
   291  					MessageType:   dhcp6.MessageTypeAdvertise,
   292  					TransactionID: [3]byte{0x33, 0x33, 0x33},
   293  				},
   294  			},
   295  		},
   296  		{
   297  			desc: "discard wrong XID",
   298  			send: &dhcp6.Packet{
   299  				MessageType:   dhcp6.MessageTypeSolicit,
   300  				TransactionID: [3]byte{0x33, 0x33, 0x33},
   301  			},
   302  			server: []*dhcp6.Packet{
   303  				{
   304  					MessageType:   dhcp6.MessageTypeAdvertise,
   305  					TransactionID: [3]byte{0x00, 0x00, 0x00},
   306  				},
   307  			},
   308  			want:    []*dhcp6.Packet{}, // Explicitly empty.
   309  			wantErr: context.DeadlineExceeded,
   310  		},
   311  		{
   312  			desc: "no response, timeout",
   313  			send: &dhcp6.Packet{
   314  				MessageType:   dhcp6.MessageTypeSolicit,
   315  				TransactionID: [3]byte{0x33, 0x33, 0x33},
   316  			},
   317  			wantErr: context.DeadlineExceeded,
   318  		},
   319  	} {
   320  		// Both server and client only get 2 seconds.
   321  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   322  		defer cancel()
   323  
   324  		mc, _ := serveAndClient(ctx, [][]*dhcp6.Packet{tt.server})
   325  		defer mc.conn.Close()
   326  
   327  		wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, tt.send)
   328  
   329  		var rcvd []*dhcp6.Packet
   330  		for packet := range out {
   331  			rcvd = append(rcvd, packet.Packet)
   332  		}
   333  
   334  		wg.Wait()
   335  		if err, ok := <-errCh; ok && err.Err != tt.wantErr {
   336  			t.Errorf("SimpleSendAndRead(%v): got %v, want %v", tt.send, err.Err, tt.wantErr)
   337  		} else if !ok && tt.wantErr != nil {
   338  			t.Errorf("got no error, want %v", tt.wantErr)
   339  		}
   340  
   341  		want := tt.want
   342  		if want == nil {
   343  			want = tt.server
   344  		}
   345  		if err := pktsExpected(rcvd, want); err != nil {
   346  			t.Errorf("got unexpected packets: %v", err)
   347  		}
   348  	}
   349  }
   350  
   351  func TestSimpleSendAndReadHandleCancel(t *testing.T) {
   352  	pkt := &dhcp6.Packet{
   353  		MessageType:   dhcp6.MessageTypeSolicit,
   354  		TransactionID: [3]byte{0x33, 0x33, 0x33},
   355  	}
   356  
   357  	responses := []*dhcp6.Packet{
   358  		{
   359  			MessageType:   dhcp6.MessageTypeAdvertise,
   360  			TransactionID: [3]byte{0x33, 0x33, 0x33},
   361  		},
   362  		{
   363  			MessageType:   dhcp6.MessageTypeRelayRepl,
   364  			TransactionID: [3]byte{0x33, 0x33, 0x33},
   365  		},
   366  		{
   367  			MessageType:   dhcp6.MessageTypeInformationRequest,
   368  			TransactionID: [3]byte{0x33, 0x33, 0x33},
   369  		},
   370  		{
   371  			MessageType:   dhcp6.MessageTypeReply,
   372  			TransactionID: [3]byte{0x33, 0x33, 0x33},
   373  		},
   374  	}
   375  
   376  	// Both the server and client only get 2 seconds.
   377  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   378  	defer cancel()
   379  
   380  	mc, udpConn := serveAndClient(ctx, [][]*dhcp6.Packet{responses})
   381  	defer mc.conn.Close()
   382  
   383  	wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt)
   384  
   385  	var counter int
   386  	for range out {
   387  		counter++
   388  		if counter == 2 {
   389  			cancel()
   390  		}
   391  	}
   392  
   393  	wg.Wait()
   394  	if err, ok := <-errCh; ok {
   395  		t.Errorf("got %v, want nil error", err)
   396  	}
   397  
   398  	// Make sure that two packets are still in the queue to be read.
   399  	for packet := range udpConn.in {
   400  		bin, err := responses[counter].MarshalBinary()
   401  		if err != nil {
   402  			panic(err)
   403  		}
   404  		if bytes.Compare(packet.payload, bin) != 0 {
   405  			t.Errorf("SimpleSendAndRead read more packets than expected!")
   406  		}
   407  		counter++
   408  	}
   409  }
   410  
   411  func TestSimpleSendAndReadDiscardGarbage(t *testing.T) {
   412  	pkt := &dhcp6.Packet{
   413  		MessageType:   dhcp6.MessageTypeSolicit,
   414  		TransactionID: [3]byte{0x33, 0x33, 0x33},
   415  	}
   416  
   417  	responses := []*dhcp6.Packet{
   418  		{
   419  			MessageType:   dhcp6.MessageTypeAdvertise,
   420  			TransactionID: [3]byte{0x33, 0x33, 0x33},
   421  		},
   422  	}
   423  
   424  	// Both the server and client only get 2 seconds.
   425  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   426  	defer cancel()
   427  
   428  	mc, udpConn := serveAndClient(ctx, [][]*dhcp6.Packet{responses})
   429  	defer mc.conn.Close()
   430  
   431  	udpConn.in <- udpPacket{
   432  		payload: []byte{0x01}, // Too short for valid DHCPv6 packet.
   433  	}
   434  
   435  	wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt)
   436  
   437  	var i int
   438  	for recvd := range out {
   439  		if err := ComparePacket(recvd.Packet, responses[i]); err != nil {
   440  			t.Error(err)
   441  		}
   442  		i++
   443  	}
   444  
   445  	wg.Wait()
   446  	if err, ok := <-errCh; ok {
   447  		t.Errorf("SimpleSendAndRead(%v): got %v %v, want %v", pkt, ok, err, nil)
   448  	}
   449  	if i != len(responses) {
   450  		t.Errorf("should have received %d valid packet, counter is %d", len(responses), i)
   451  	}
   452  }
   453  
   454  func TestSimpleSendAndReadDiscardGarbageTimeout(t *testing.T) {
   455  	pkt := &dhcp6.Packet{
   456  		MessageType:   dhcp6.MessageTypeSolicit,
   457  		TransactionID: [3]byte{0x33, 0x33, 0x33},
   458  	}
   459  
   460  	// Both the server and client only get 2 seconds.
   461  	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   462  	defer cancel()
   463  
   464  	mc, udpConn := serveAndClient(ctx, nil)
   465  	defer mc.conn.Close()
   466  
   467  	udpConn.in <- udpPacket{
   468  		payload: []byte{0x01}, // Too short for valid DHCPv6 packet.
   469  	}
   470  
   471  	wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt)
   472  
   473  	var counter int
   474  	for range out {
   475  		counter++
   476  	}
   477  
   478  	wg.Wait()
   479  	if err, ok := <-errCh; !ok || err == nil || err.Err != context.DeadlineExceeded {
   480  		t.Errorf("SimpleSendAndRead(%v): got %v %v, want %v", pkt, ok, err, context.DeadlineExceeded)
   481  	}
   482  	if counter != 0 {
   483  		t.Errorf("should not have received a valid packet, counter is %d", counter)
   484  	}
   485  }