code.flowtr.dev/mirrors/u-root@v1.0.0/pkg/dhcp6client/client_test.go (about) 1 // Copyright 2017-2018 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package dhcp6client 6 7 import ( 8 "bytes" 9 "context" 10 "fmt" 11 "net" 12 "syscall" 13 "testing" 14 "time" 15 16 "github.com/mdlayher/dhcp6" 17 ) 18 19 type timeoutErr struct{} 20 21 func (timeoutErr) Error() string { 22 return "i/o timeout" 23 } 24 25 func (timeoutErr) Timeout() bool { 26 return true 27 } 28 29 type udpPacket struct { 30 source *net.UDPAddr 31 dest *net.UDPAddr 32 payload []byte 33 } 34 35 // mockUDPConn implements net.PacketConn. 36 type mockUDPConn struct { 37 // This'll just be nil for all the methods we don't implement. 38 39 // in is the queue of packets ReadFromUDP reads from. 40 // 41 // ReadFromUDP returns io.EOF when in is closed. 42 in chan udpPacket 43 44 inTimer *time.Timer 45 46 // out is the queue of packets WriteTo writes to. 47 out chan<- udpPacket 48 49 closed bool 50 } 51 52 func newMockUDPConn(in chan udpPacket, out chan<- udpPacket) *mockUDPConn { 53 return &mockUDPConn{ 54 in: in, 55 out: out, 56 } 57 } 58 59 // SetReadDeadline implements PacketConn.SetReadDeadline. 60 func (m *mockUDPConn) SetReadDeadline(t time.Time) error { 61 duration := t.Sub(time.Now()) 62 if duration < 0 { 63 return fmt.Errorf("deadline must be in the future") 64 } 65 m.inTimer = time.NewTimer(duration) 66 return nil 67 } 68 69 func (m *mockUDPConn) LocalAddr() net.Addr { 70 panic("unused") 71 } 72 73 func (m *mockUDPConn) SetWriteDeadline(t time.Time) error { 74 panic("unused") 75 } 76 77 func (m *mockUDPConn) SetDeadline(t time.Time) error { 78 panic("unused") 79 } 80 81 // Close implements PacketConn.Close. 82 func (m *mockUDPConn) Close() error { 83 m.closed = true 84 close(m.out) 85 return nil 86 } 87 88 // ReadFrom is a mock for PacketConn.ReadFromUDP. 89 func (m *mockUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { 90 // Make sure we don't have data waiting. 91 select { 92 case p, ok := <-m.in: 93 if !ok { 94 // Connection was closed. 95 return 0, nil, nil 96 } 97 return copy(b, p.payload), p.source, nil 98 default: 99 } 100 101 select { 102 case p, ok := <-m.in: 103 if !ok { 104 return 0, nil, nil 105 } 106 return copy(b, p.payload), p.source, nil 107 case <-m.inTimer.C: 108 // This net.OpError will return true for Timeout(). 109 return 0, nil, &net.OpError{Err: timeoutErr{}} 110 } 111 } 112 113 // WriteTo is a mock for PacketConn.WriteTo. 114 func (m *mockUDPConn) WriteTo(b []byte, dest net.Addr) (int, error) { 115 if m.closed { 116 return 0, syscall.EBADF 117 } 118 119 m.out <- udpPacket{ 120 dest: dest.(*net.UDPAddr), 121 payload: b, 122 } 123 return len(b), nil 124 } 125 126 type server struct { 127 in chan udpPacket 128 out chan udpPacket 129 130 received []*dhcp6.Packet 131 132 // Each received packet can have more than one response (in theory, 133 // from different servers sending different Advertise, for example). 134 responses [][]*dhcp6.Packet 135 } 136 137 func (s *server) serve(ctx context.Context) { 138 go func() { 139 select { 140 case udpPkt, ok := <-s.in: 141 if !ok { 142 break 143 } 144 145 // What did we get? 146 var pkt dhcp6.Packet 147 if err := (&pkt).UnmarshalBinary(udpPkt.payload); err != nil { 148 panic(fmt.Sprintf("invalid dhcp6 packet %q: %v", udpPkt.payload, err)) 149 } 150 s.received = append(s.received, &pkt) 151 152 if len(s.responses) > 0 { 153 resps := s.responses[0] 154 // What should we send in response? 155 for _, resp := range resps { 156 bin, err := resp.MarshalBinary() 157 if err != nil { 158 panic(fmt.Sprintf("failed to serialize dhcp6 packet %v: %v", resp, err)) 159 } 160 s.out <- udpPacket{ 161 source: udpPkt.dest, 162 payload: bin, 163 } 164 } 165 s.responses = s.responses[1:] 166 } 167 168 case <-ctx.Done(): 169 break 170 } 171 172 // We're done sending stuff. 173 close(s.out) 174 }() 175 176 } 177 178 func ComparePacket(got *dhcp6.Packet, want *dhcp6.Packet) error { 179 aa, err := got.MarshalBinary() 180 if err != nil { 181 panic(err) 182 } 183 bb, err := want.MarshalBinary() 184 if err != nil { 185 panic(err) 186 } 187 if bytes.Compare(aa, bb) != 0 { 188 return fmt.Errorf("packet got %v, want %v", got, want) 189 } 190 return nil 191 } 192 193 func pktsExpected(got []*dhcp6.Packet, want []*dhcp6.Packet) error { 194 if len(got) != len(want) { 195 return fmt.Errorf("got %d packets, want %d packets", len(got), len(want)) 196 } 197 198 for i := range got { 199 if err := ComparePacket(got[i], want[i]); err != nil { 200 return err 201 } 202 } 203 return nil 204 } 205 206 func serveAndClient(ctx context.Context, responses [][]*dhcp6.Packet) (*Client, *mockUDPConn) { 207 // These are the client's channels. 208 in := make(chan udpPacket, 100) 209 out := make(chan udpPacket, 100) 210 211 mockConn := &mockUDPConn{ 212 in: in, 213 out: out, 214 } 215 216 mc := &Client{ 217 conn: mockConn, 218 retry: 1, 219 timeout: time.Second, 220 } 221 222 // Of course, for the server they are reversed. 223 s := &server{ 224 in: out, 225 out: in, 226 responses: responses, 227 } 228 go s.serve(ctx) 229 230 return mc, mockConn 231 } 232 233 func TestSimpleSendAndRead(t *testing.T) { 234 for _, tt := range []struct { 235 desc string 236 send *dhcp6.Packet 237 server []*dhcp6.Packet 238 239 // If want is nil, we assume server contains what is wanted. 240 want []*dhcp6.Packet 241 wantErr error 242 }{ 243 { 244 desc: "two response packets", 245 send: &dhcp6.Packet{ 246 MessageType: dhcp6.MessageTypeSolicit, 247 TransactionID: [3]byte{0x33, 0x33, 0x33}, 248 }, 249 server: []*dhcp6.Packet{ 250 { 251 MessageType: dhcp6.MessageTypeAdvertise, 252 TransactionID: [3]byte{0x33, 0x33, 0x33}, 253 }, 254 { 255 MessageType: dhcp6.MessageTypeAdvertise, 256 TransactionID: [3]byte{0x33, 0x33, 0x33}, 257 }, 258 }, 259 }, 260 { 261 desc: "one response packet", 262 send: &dhcp6.Packet{ 263 MessageType: dhcp6.MessageTypeSolicit, 264 TransactionID: [3]byte{0x33, 0x33, 0x33}, 265 }, 266 server: []*dhcp6.Packet{ 267 { 268 MessageType: dhcp6.MessageTypeAdvertise, 269 TransactionID: [3]byte{0x33, 0x33, 0x33}, 270 }, 271 }, 272 }, 273 { 274 desc: "one response packet, one invalid XID", 275 send: &dhcp6.Packet{ 276 MessageType: dhcp6.MessageTypeSolicit, 277 TransactionID: [3]byte{0x33, 0x33, 0x33}, 278 }, 279 server: []*dhcp6.Packet{ 280 { 281 MessageType: dhcp6.MessageTypeAdvertise, 282 TransactionID: [3]byte{0x33, 0x33, 0x33}, 283 }, 284 { 285 MessageType: dhcp6.MessageTypeAdvertise, 286 TransactionID: [3]byte{0x77, 0x77, 0x77}, 287 }, 288 }, 289 want: []*dhcp6.Packet{ 290 { 291 MessageType: dhcp6.MessageTypeAdvertise, 292 TransactionID: [3]byte{0x33, 0x33, 0x33}, 293 }, 294 }, 295 }, 296 { 297 desc: "discard wrong XID", 298 send: &dhcp6.Packet{ 299 MessageType: dhcp6.MessageTypeSolicit, 300 TransactionID: [3]byte{0x33, 0x33, 0x33}, 301 }, 302 server: []*dhcp6.Packet{ 303 { 304 MessageType: dhcp6.MessageTypeAdvertise, 305 TransactionID: [3]byte{0x00, 0x00, 0x00}, 306 }, 307 }, 308 want: []*dhcp6.Packet{}, // Explicitly empty. 309 wantErr: context.DeadlineExceeded, 310 }, 311 { 312 desc: "no response, timeout", 313 send: &dhcp6.Packet{ 314 MessageType: dhcp6.MessageTypeSolicit, 315 TransactionID: [3]byte{0x33, 0x33, 0x33}, 316 }, 317 wantErr: context.DeadlineExceeded, 318 }, 319 } { 320 // Both server and client only get 2 seconds. 321 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 322 defer cancel() 323 324 mc, _ := serveAndClient(ctx, [][]*dhcp6.Packet{tt.server}) 325 defer mc.conn.Close() 326 327 wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, tt.send) 328 329 var rcvd []*dhcp6.Packet 330 for packet := range out { 331 rcvd = append(rcvd, packet.Packet) 332 } 333 334 wg.Wait() 335 if err, ok := <-errCh; ok && err.Err != tt.wantErr { 336 t.Errorf("SimpleSendAndRead(%v): got %v, want %v", tt.send, err.Err, tt.wantErr) 337 } else if !ok && tt.wantErr != nil { 338 t.Errorf("got no error, want %v", tt.wantErr) 339 } 340 341 want := tt.want 342 if want == nil { 343 want = tt.server 344 } 345 if err := pktsExpected(rcvd, want); err != nil { 346 t.Errorf("got unexpected packets: %v", err) 347 } 348 } 349 } 350 351 func TestSimpleSendAndReadHandleCancel(t *testing.T) { 352 pkt := &dhcp6.Packet{ 353 MessageType: dhcp6.MessageTypeSolicit, 354 TransactionID: [3]byte{0x33, 0x33, 0x33}, 355 } 356 357 responses := []*dhcp6.Packet{ 358 { 359 MessageType: dhcp6.MessageTypeAdvertise, 360 TransactionID: [3]byte{0x33, 0x33, 0x33}, 361 }, 362 { 363 MessageType: dhcp6.MessageTypeRelayRepl, 364 TransactionID: [3]byte{0x33, 0x33, 0x33}, 365 }, 366 { 367 MessageType: dhcp6.MessageTypeInformationRequest, 368 TransactionID: [3]byte{0x33, 0x33, 0x33}, 369 }, 370 { 371 MessageType: dhcp6.MessageTypeReply, 372 TransactionID: [3]byte{0x33, 0x33, 0x33}, 373 }, 374 } 375 376 // Both the server and client only get 2 seconds. 377 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 378 defer cancel() 379 380 mc, udpConn := serveAndClient(ctx, [][]*dhcp6.Packet{responses}) 381 defer mc.conn.Close() 382 383 wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt) 384 385 var counter int 386 for range out { 387 counter++ 388 if counter == 2 { 389 cancel() 390 } 391 } 392 393 wg.Wait() 394 if err, ok := <-errCh; ok { 395 t.Errorf("got %v, want nil error", err) 396 } 397 398 // Make sure that two packets are still in the queue to be read. 399 for packet := range udpConn.in { 400 bin, err := responses[counter].MarshalBinary() 401 if err != nil { 402 panic(err) 403 } 404 if bytes.Compare(packet.payload, bin) != 0 { 405 t.Errorf("SimpleSendAndRead read more packets than expected!") 406 } 407 counter++ 408 } 409 } 410 411 func TestSimpleSendAndReadDiscardGarbage(t *testing.T) { 412 pkt := &dhcp6.Packet{ 413 MessageType: dhcp6.MessageTypeSolicit, 414 TransactionID: [3]byte{0x33, 0x33, 0x33}, 415 } 416 417 responses := []*dhcp6.Packet{ 418 { 419 MessageType: dhcp6.MessageTypeAdvertise, 420 TransactionID: [3]byte{0x33, 0x33, 0x33}, 421 }, 422 } 423 424 // Both the server and client only get 2 seconds. 425 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 426 defer cancel() 427 428 mc, udpConn := serveAndClient(ctx, [][]*dhcp6.Packet{responses}) 429 defer mc.conn.Close() 430 431 udpConn.in <- udpPacket{ 432 payload: []byte{0x01}, // Too short for valid DHCPv6 packet. 433 } 434 435 wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt) 436 437 var i int 438 for recvd := range out { 439 if err := ComparePacket(recvd.Packet, responses[i]); err != nil { 440 t.Error(err) 441 } 442 i++ 443 } 444 445 wg.Wait() 446 if err, ok := <-errCh; ok { 447 t.Errorf("SimpleSendAndRead(%v): got %v %v, want %v", pkt, ok, err, nil) 448 } 449 if i != len(responses) { 450 t.Errorf("should have received %d valid packet, counter is %d", len(responses), i) 451 } 452 } 453 454 func TestSimpleSendAndReadDiscardGarbageTimeout(t *testing.T) { 455 pkt := &dhcp6.Packet{ 456 MessageType: dhcp6.MessageTypeSolicit, 457 TransactionID: [3]byte{0x33, 0x33, 0x33}, 458 } 459 460 // Both the server and client only get 2 seconds. 461 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) 462 defer cancel() 463 464 mc, udpConn := serveAndClient(ctx, nil) 465 defer mc.conn.Close() 466 467 udpConn.in <- udpPacket{ 468 payload: []byte{0x01}, // Too short for valid DHCPv6 packet. 469 } 470 471 wg, out, errCh := mc.SimpleSendAndRead(ctx, DefaultServers, pkt) 472 473 var counter int 474 for range out { 475 counter++ 476 } 477 478 wg.Wait() 479 if err, ok := <-errCh; !ok || err == nil || err.Err != context.DeadlineExceeded { 480 t.Errorf("SimpleSendAndRead(%v): got %v %v, want %v", pkt, ok, err, context.DeadlineExceeded) 481 } 482 if counter != 0 { 483 t.Errorf("should not have received a valid packet, counter is %d", counter) 484 } 485 }