github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/packetimpact/testbench/connections.go (about) 1 // Copyright 2020 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package testbench 16 17 import ( 18 "fmt" 19 "math/rand" 20 "testing" 21 "time" 22 23 "github.com/mohae/deepcopy" 24 "go.uber.org/multierr" 25 "golang.org/x/sys/unix" 26 "github.com/SagerNet/gvisor/pkg/tcpip" 27 "github.com/SagerNet/gvisor/pkg/tcpip/header" 28 "github.com/SagerNet/gvisor/pkg/tcpip/seqnum" 29 ) 30 31 func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { 32 switch sa := sa.(type) { 33 case *unix.SockaddrInet4: 34 return uint16(sa.Port), nil 35 case *unix.SockaddrInet6: 36 return uint16(sa.Port), nil 37 } 38 return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) 39 } 40 41 // pickPort makes a new socket and returns the socket FD and port. The domain 42 // should be AF_INET or AF_INET6. The caller must close the FD when done with 43 // the port if there is no error. 44 func (n *DUTTestNet) pickPort(domain, typ int) (fd int, port uint16, err error) { 45 fd, err = unix.Socket(domain, typ, 0) 46 if err != nil { 47 return -1, 0, fmt.Errorf("creating socket: %w", err) 48 } 49 defer func() { 50 if err != nil { 51 if cerr := unix.Close(fd); cerr != nil { 52 err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr)) 53 } 54 } 55 }() 56 var sa unix.Sockaddr 57 switch domain { 58 case unix.AF_INET: 59 var sa4 unix.SockaddrInet4 60 copy(sa4.Addr[:], n.LocalIPv4) 61 sa = &sa4 62 case unix.AF_INET6: 63 sa6 := unix.SockaddrInet6{ZoneId: n.LocalDevID} 64 copy(sa6.Addr[:], n.LocalIPv6) 65 sa = &sa6 66 default: 67 return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) 68 } 69 if err = unix.Bind(fd, sa); err != nil { 70 return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err) 71 } 72 sa, err = unix.Getsockname(fd) 73 if err != nil { 74 return -1, 0, fmt.Errorf("unix.Getsocketname(%d): %w", fd, err) 75 } 76 port, err = portFromSockaddr(sa) 77 if err != nil { 78 return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err) 79 } 80 return fd, port, nil 81 } 82 83 // layerState stores the state of a layer of a connection. 84 type layerState interface { 85 // outgoing returns an outgoing layer to be sent in a frame. It should not 86 // update layerState, that is done in layerState.sent. 87 outgoing() Layer 88 89 // incoming creates an expected Layer for comparing against a received Layer. 90 // Because the expectation can depend on values in the received Layer, it is 91 // an input to incoming. For example, the ACK number needs to be checked in a 92 // TCP packet but only if the ACK flag is set in the received packet. It 93 // should not update layerState, that is done in layerState.received. The 94 // caller takes ownership of the returned Layer. 95 incoming(received Layer) Layer 96 97 // sent updates the layerState based on the Layer that was sent. The input is 98 // a Layer with all prev and next pointers populated so that the entire frame 99 // as it was sent is available. 100 sent(sent Layer) error 101 102 // received updates the layerState based on a Layer that is received. The 103 // input is a Layer with all prev and next pointers populated so that the 104 // entire frame as it was received is available. 105 received(received Layer) error 106 107 // close frees associated resources held by the LayerState. 108 close() error 109 } 110 111 // etherState maintains state about an Ethernet connection. 112 type etherState struct { 113 out, in Ether 114 } 115 116 var _ layerState = (*etherState)(nil) 117 118 // newEtherState creates a new etherState. 119 func (n *DUTTestNet) newEtherState(out, in Ether) (*etherState, error) { 120 lmac := tcpip.LinkAddress(n.LocalMAC) 121 rmac := tcpip.LinkAddress(n.RemoteMAC) 122 s := etherState{ 123 out: Ether{SrcAddr: &lmac, DstAddr: &rmac}, 124 in: Ether{SrcAddr: &rmac, DstAddr: &lmac}, 125 } 126 if err := s.out.merge(&out); err != nil { 127 return nil, err 128 } 129 if err := s.in.merge(&in); err != nil { 130 return nil, err 131 } 132 return &s, nil 133 } 134 135 func (s *etherState) outgoing() Layer { 136 return deepcopy.Copy(&s.out).(Layer) 137 } 138 139 // incoming implements layerState.incoming. 140 func (s *etherState) incoming(Layer) Layer { 141 return deepcopy.Copy(&s.in).(Layer) 142 } 143 144 func (*etherState) sent(Layer) error { 145 return nil 146 } 147 148 func (*etherState) received(Layer) error { 149 return nil 150 } 151 152 func (*etherState) close() error { 153 return nil 154 } 155 156 // ipv4State maintains state about an IPv4 connection. 157 type ipv4State struct { 158 out, in IPv4 159 } 160 161 var _ layerState = (*ipv4State)(nil) 162 163 // newIPv4State creates a new ipv4State. 164 func (n *DUTTestNet) newIPv4State(out, in IPv4) (*ipv4State, error) { 165 lIP := tcpip.Address(n.LocalIPv4) 166 rIP := tcpip.Address(n.RemoteIPv4) 167 s := ipv4State{ 168 out: IPv4{SrcAddr: &lIP, DstAddr: &rIP}, 169 in: IPv4{SrcAddr: &rIP, DstAddr: &lIP}, 170 } 171 if err := s.out.merge(&out); err != nil { 172 return nil, err 173 } 174 if err := s.in.merge(&in); err != nil { 175 return nil, err 176 } 177 return &s, nil 178 } 179 180 func (s *ipv4State) outgoing() Layer { 181 return deepcopy.Copy(&s.out).(Layer) 182 } 183 184 // incoming implements layerState.incoming. 185 func (s *ipv4State) incoming(Layer) Layer { 186 return deepcopy.Copy(&s.in).(Layer) 187 } 188 189 func (*ipv4State) sent(Layer) error { 190 return nil 191 } 192 193 func (*ipv4State) received(Layer) error { 194 return nil 195 } 196 197 func (*ipv4State) close() error { 198 return nil 199 } 200 201 // ipv6State maintains state about an IPv6 connection. 202 type ipv6State struct { 203 out, in IPv6 204 } 205 206 var _ layerState = (*ipv6State)(nil) 207 208 // newIPv6State creates a new ipv6State. 209 func (n *DUTTestNet) newIPv6State(out, in IPv6) (*ipv6State, error) { 210 lIP := tcpip.Address(n.LocalIPv6) 211 rIP := tcpip.Address(n.RemoteIPv6) 212 s := ipv6State{ 213 out: IPv6{SrcAddr: &lIP, DstAddr: &rIP}, 214 in: IPv6{SrcAddr: &rIP, DstAddr: &lIP}, 215 } 216 if err := s.out.merge(&out); err != nil { 217 return nil, err 218 } 219 if err := s.in.merge(&in); err != nil { 220 return nil, err 221 } 222 return &s, nil 223 } 224 225 // outgoing returns an outgoing layer to be sent in a frame. 226 func (s *ipv6State) outgoing() Layer { 227 return deepcopy.Copy(&s.out).(Layer) 228 } 229 230 func (s *ipv6State) incoming(Layer) Layer { 231 return deepcopy.Copy(&s.in).(Layer) 232 } 233 234 func (s *ipv6State) sent(Layer) error { 235 // Nothing to do. 236 return nil 237 } 238 239 func (s *ipv6State) received(Layer) error { 240 // Nothing to do. 241 return nil 242 } 243 244 // close cleans up any resources held. 245 func (s *ipv6State) close() error { 246 return nil 247 } 248 249 // tcpState maintains state about a TCP connection. 250 type tcpState struct { 251 out, in TCP 252 localSeqNum, remoteSeqNum *seqnum.Value 253 synAck *TCP 254 portPickerFD int 255 finSent bool 256 } 257 258 var _ layerState = (*tcpState)(nil) 259 260 // SeqNumValue is a helper routine that allocates a new seqnum.Value value to 261 // store v and returns a pointer to it. 262 func SeqNumValue(v seqnum.Value) *seqnum.Value { 263 return &v 264 } 265 266 // newTCPState creates a new TCPState. 267 func (n *DUTTestNet) newTCPState(domain int, out, in TCP) (*tcpState, error) { 268 portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_STREAM) 269 if err != nil { 270 return nil, err 271 } 272 s := tcpState{ 273 out: TCP{SrcPort: &localPort}, 274 in: TCP{DstPort: &localPort}, 275 localSeqNum: SeqNumValue(seqnum.Value(rand.Uint32())), 276 portPickerFD: portPickerFD, 277 finSent: false, 278 } 279 if err := s.out.merge(&out); err != nil { 280 return nil, err 281 } 282 if err := s.in.merge(&in); err != nil { 283 return nil, err 284 } 285 return &s, nil 286 } 287 288 func (s *tcpState) outgoing() Layer { 289 newOutgoing := deepcopy.Copy(s.out).(TCP) 290 if s.localSeqNum != nil { 291 newOutgoing.SeqNum = Uint32(uint32(*s.localSeqNum)) 292 } 293 if s.remoteSeqNum != nil { 294 newOutgoing.AckNum = Uint32(uint32(*s.remoteSeqNum)) 295 } 296 return &newOutgoing 297 } 298 299 // incoming implements layerState.incoming. 300 func (s *tcpState) incoming(received Layer) Layer { 301 tcpReceived, ok := received.(*TCP) 302 if !ok { 303 return nil 304 } 305 newIn := deepcopy.Copy(s.in).(TCP) 306 if s.remoteSeqNum != nil { 307 newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) 308 } 309 if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 { 310 // The caller didn't specify an AckNum so we'll expect the calculated one, 311 // but only if the ACK flag is set because the AckNum is not valid in a 312 // header if ACK is not set. 313 newIn.AckNum = Uint32(uint32(*seq)) 314 } 315 return &newIn 316 } 317 318 func (s *tcpState) sent(sent Layer) error { 319 tcp, ok := sent.(*TCP) 320 if !ok { 321 return fmt.Errorf("can't update tcpState with %T Layer", sent) 322 } 323 if !s.finSent { 324 // update localSeqNum by the payload only when FIN is not yet sent by us 325 for current := tcp.next(); current != nil; current = current.next() { 326 s.localSeqNum.UpdateForward(seqnum.Size(current.length())) 327 } 328 } 329 if tcp.Flags != nil && *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { 330 s.localSeqNum.UpdateForward(1) 331 } 332 if *tcp.Flags&(header.TCPFlagFin) != 0 { 333 s.finSent = true 334 } 335 return nil 336 } 337 338 func (s *tcpState) received(l Layer) error { 339 tcp, ok := l.(*TCP) 340 if !ok { 341 return fmt.Errorf("can't update tcpState with %T Layer", l) 342 } 343 s.remoteSeqNum = SeqNumValue(seqnum.Value(*tcp.SeqNum)) 344 if *tcp.Flags&(header.TCPFlagSyn|header.TCPFlagFin) != 0 { 345 s.remoteSeqNum.UpdateForward(1) 346 } 347 for current := tcp.next(); current != nil; current = current.next() { 348 s.remoteSeqNum.UpdateForward(seqnum.Size(current.length())) 349 } 350 return nil 351 } 352 353 // close frees the port associated with this connection. 354 func (s *tcpState) close() error { 355 if err := unix.Close(s.portPickerFD); err != nil { 356 return err 357 } 358 s.portPickerFD = -1 359 return nil 360 } 361 362 // udpState maintains state about a UDP connection. 363 type udpState struct { 364 out, in UDP 365 portPickerFD int 366 } 367 368 var _ layerState = (*udpState)(nil) 369 370 // newUDPState creates a new udpState. 371 func (n *DUTTestNet) newUDPState(domain int, out, in UDP) (*udpState, error) { 372 portPickerFD, localPort, err := n.pickPort(domain, unix.SOCK_DGRAM) 373 if err != nil { 374 return nil, fmt.Errorf("picking port: %w", err) 375 } 376 s := udpState{ 377 out: UDP{SrcPort: &localPort}, 378 in: UDP{DstPort: &localPort}, 379 portPickerFD: portPickerFD, 380 } 381 if err := s.out.merge(&out); err != nil { 382 return nil, err 383 } 384 if err := s.in.merge(&in); err != nil { 385 return nil, err 386 } 387 return &s, nil 388 } 389 390 func (s *udpState) outgoing() Layer { 391 return deepcopy.Copy(&s.out).(Layer) 392 } 393 394 // incoming implements layerState.incoming. 395 func (s *udpState) incoming(Layer) Layer { 396 return deepcopy.Copy(&s.in).(Layer) 397 } 398 399 func (*udpState) sent(l Layer) error { 400 return nil 401 } 402 403 func (*udpState) received(l Layer) error { 404 return nil 405 } 406 407 // close frees the port associated with this connection. 408 func (s *udpState) close() error { 409 if err := unix.Close(s.portPickerFD); err != nil { 410 return err 411 } 412 s.portPickerFD = -1 413 return nil 414 } 415 416 // Connection holds a collection of layer states for maintaining a connection 417 // along with sockets for sniffer and injecting packets. 418 type Connection struct { 419 layerStates []layerState 420 injector Injector 421 sniffer Sniffer 422 } 423 424 // Returns the default incoming frame against which to match. If received is 425 // longer than layerStates then that may still count as a match. The reverse is 426 // never a match and nil is returned. 427 func (conn *Connection) incoming(received Layers) Layers { 428 if len(received) < len(conn.layerStates) { 429 return nil 430 } 431 in := Layers{} 432 for i, s := range conn.layerStates { 433 toMatch := s.incoming(received[i]) 434 if toMatch == nil { 435 return nil 436 } 437 in = append(in, toMatch) 438 } 439 return in 440 } 441 442 func (conn *Connection) match(override, received Layers) bool { 443 toMatch := conn.incoming(received) 444 if toMatch == nil { 445 return false // Not enough layers in gotLayers for matching. 446 } 447 if err := toMatch.merge(override); err != nil { 448 return false // Failing to merge is not matching. 449 } 450 return toMatch.match(received) 451 } 452 453 // Close frees associated resources held by the Connection. 454 func (conn *Connection) Close(t *testing.T) { 455 t.Helper() 456 457 errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) 458 for _, s := range conn.layerStates { 459 if err := s.close(); err != nil { 460 errs = multierr.Append(errs, fmt.Errorf("unable to close %+v: %s", s, err)) 461 } 462 } 463 if errs != nil { 464 t.Fatalf("unable to close %+v: %s", conn, errs) 465 } 466 } 467 468 // CreateFrame builds a frame for the connection with defaults overridden 469 // from the innermost layer out, and additionalLayers added after it. 470 // 471 // Note that overrideLayers can have a length that is less than the number 472 // of layers in this connection, and in such cases the innermost layers are 473 // overridden first. As an example, valid values of overrideLayers for a TCP- 474 // over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and 475 // [Ethernet, IPv4, TCP]. 476 func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { 477 t.Helper() 478 479 var layersToSend Layers 480 for i, s := range conn.layerStates { 481 layer := s.outgoing() 482 // overrideLayers and conn.layerStates have their tails aligned, so 483 // to find the index we move backwards by the distance i is to the 484 // end. 485 if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { 486 if err := layer.merge(overrideLayers[j]); err != nil { 487 t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) 488 } 489 } 490 layersToSend = append(layersToSend, layer) 491 } 492 layersToSend = append(layersToSend, additionalLayers...) 493 return layersToSend 494 } 495 496 // SendFrameStateless sends a frame without updating any of the layer states. 497 // 498 // This method is useful for sending out-of-band control messages such as 499 // ICMP packets, where it would not make sense to update the transport layer's 500 // state using the ICMP header. 501 func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { 502 t.Helper() 503 504 outBytes, err := frame.ToBytes() 505 if err != nil { 506 t.Fatalf("can't build outgoing packet: %s", err) 507 } 508 conn.injector.Send(t, outBytes) 509 } 510 511 // SendFrame sends a frame on the wire and updates the state of all layers. 512 func (conn *Connection) SendFrame(t *testing.T, frame Layers) { 513 t.Helper() 514 515 outBytes, err := frame.ToBytes() 516 if err != nil { 517 t.Fatalf("can't build outgoing packet: %s", err) 518 } 519 conn.injector.Send(t, outBytes) 520 521 // frame might have nil values where the caller wanted to use default values. 522 // sentFrame will have no nil values in it because it comes from parsing the 523 // bytes that were actually sent. 524 sentFrame := parse(parseEther, outBytes) 525 // Update the state of each layer based on what was sent. 526 for i, s := range conn.layerStates { 527 if err := s.sent(sentFrame[i]); err != nil { 528 t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) 529 } 530 } 531 } 532 533 // send sends a packet, possibly with layers of this connection overridden and 534 // additional layers added. 535 // 536 // Types defined with Connection as the underlying type should expose 537 // type-safe versions of this method. 538 func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { 539 t.Helper() 540 541 conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) 542 } 543 544 // recvFrame gets the next successfully parsed frame (of type Layers) within the 545 // timeout provided. If no parsable frame arrives before the timeout, it returns 546 // nil. 547 func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { 548 t.Helper() 549 550 if timeout <= 0 { 551 return nil 552 } 553 b := conn.sniffer.Recv(t, timeout) 554 if b == nil { 555 return nil 556 } 557 return parse(parseEther, b) 558 } 559 560 // layersError stores the Layers that we got and the Layers that we wanted to 561 // match. 562 type layersError struct { 563 got, want Layers 564 } 565 566 func (e *layersError) Error() string { 567 return e.got.diff(e.want) 568 } 569 570 // Expect expects a frame with the final layerStates layer matching the 571 // provided Layer within the timeout specified. If it doesn't arrive in time, 572 // an error is returned. 573 func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { 574 t.Helper() 575 576 // Make a frame that will ignore all but the final layer. 577 layers := make([]Layer, len(conn.layerStates)) 578 layers[len(layers)-1] = layer 579 580 gotFrame, err := conn.ExpectFrame(t, layers, timeout) 581 if err != nil { 582 return nil, err 583 } 584 if len(conn.layerStates)-1 < len(gotFrame) { 585 return gotFrame[len(conn.layerStates)-1], nil 586 } 587 t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) 588 panic("unreachable") 589 } 590 591 // ExpectFrame expects a frame that matches the provided Layers within the 592 // timeout specified. If one arrives in time, the Layers is returned without an 593 // error. If it doesn't arrive in time, it returns nil and error is non-nil. 594 func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { 595 t.Helper() 596 597 frames, ok := conn.ListenForFrame(t, layers, timeout) 598 if ok { 599 return frames[len(frames)-1], nil 600 } 601 if len(frames) == 0 { 602 return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout) 603 } 604 605 var errs error 606 for _, got := range frames { 607 want := conn.incoming(layers) 608 if err := want.merge(layers); err != nil { 609 errs = multierr.Combine(errs, err) 610 } else { 611 errs = multierr.Combine(errs, &layersError{got: got, want: want}) 612 } 613 } 614 return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout) 615 } 616 617 // ListenForFrame captures all frames until a frame matches the provided Layers, 618 // or until the timeout specified. Returns all captured frames, including the 619 // matched frame, and true if the desired frame was found. 620 func (conn *Connection) ListenForFrame(t *testing.T, layers Layers, timeout time.Duration) ([]Layers, bool) { 621 t.Helper() 622 623 deadline := time.Now().Add(timeout) 624 var frames []Layers 625 for { 626 var got Layers 627 if timeout := time.Until(deadline); timeout > 0 { 628 got = conn.recvFrame(t, timeout) 629 } 630 if got == nil { 631 return frames, false 632 } 633 frames = append(frames, got) 634 if conn.match(layers, got) { 635 for i, s := range conn.layerStates { 636 if err := s.received(got[i]); err != nil { 637 t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) 638 } 639 } 640 return frames, true 641 } 642 } 643 } 644 645 // Drain drains the sniffer's receive buffer by receiving packets until there's 646 // nothing else to receive. 647 func (conn *Connection) Drain(t *testing.T) { 648 t.Helper() 649 650 conn.sniffer.Drain(t) 651 } 652 653 // TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. 654 type TCPIPv4 struct { 655 Connection 656 } 657 658 // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. 659 func (n *DUTTestNet) NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { 660 t.Helper() 661 662 etherState, err := n.newEtherState(Ether{}, Ether{}) 663 if err != nil { 664 t.Fatalf("can't make etherState: %s", err) 665 } 666 ipv4State, err := n.newIPv4State(IPv4{}, IPv4{}) 667 if err != nil { 668 t.Fatalf("can't make ipv4State: %s", err) 669 } 670 tcpState, err := n.newTCPState(unix.AF_INET, outgoingTCP, incomingTCP) 671 if err != nil { 672 t.Fatalf("can't make tcpState: %s", err) 673 } 674 injector, err := n.NewInjector(t) 675 if err != nil { 676 t.Fatalf("can't make injector: %s", err) 677 } 678 sniffer, err := n.NewSniffer(t) 679 if err != nil { 680 t.Fatalf("can't make sniffer: %s", err) 681 } 682 683 return TCPIPv4{ 684 Connection: Connection{ 685 layerStates: []layerState{etherState, ipv4State, tcpState}, 686 injector: injector, 687 sniffer: sniffer, 688 }, 689 } 690 } 691 692 // Connect performs a TCP 3-way handshake. The input Connection should have a 693 // final TCP Layer. 694 func (conn *TCPIPv4) Connect(t *testing.T) { 695 t.Helper() 696 697 // Send the SYN. 698 conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn)}) 699 700 // Wait for the SYN-ACK. 701 synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) 702 if err != nil { 703 t.Fatalf("didn't get synack during handshake: %s", err) 704 } 705 conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck 706 707 // Send an ACK. 708 conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)}) 709 } 710 711 // ConnectWithOptions performs a TCP 3-way handshake with given TCP options. 712 // The input Connection should have a final TCP Layer. 713 func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { 714 t.Helper() 715 716 // Send the SYN. 717 conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn), Options: options}) 718 719 // Wait for the SYN-ACK. 720 synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) 721 if err != nil { 722 t.Fatalf("didn't get synack during handshake: %s", err) 723 } 724 conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck 725 726 // Send an ACK. 727 conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)}) 728 } 729 730 // ExpectData is a convenient method that expects a Layer and the Layer after 731 // it. If it doesn't arrive in time, it returns nil. 732 func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { 733 t.Helper() 734 735 expected := make([]Layer, len(conn.layerStates)) 736 expected[len(expected)-1] = tcp 737 if payload != nil { 738 expected = append(expected, payload) 739 } 740 return conn.ExpectFrame(t, expected, timeout) 741 } 742 743 // ExpectNextData attempts to receive the next incoming segment for the 744 // connection and expects that to match the given layers. 745 // 746 // It differs from ExpectData() in that here we are only interested in the next 747 // received segment, while ExpectData() can receive multiple segments for the 748 // connection until there is a match with given layers or a timeout. 749 func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { 750 t.Helper() 751 752 // Receive the first incoming TCP segment for this connection. 753 got, err := conn.ExpectData(t, &TCP{}, nil, timeout) 754 if err != nil { 755 return nil, err 756 } 757 758 expected := make([]Layer, len(conn.layerStates)) 759 expected[len(expected)-1] = tcp 760 if payload != nil { 761 expected = append(expected, payload) 762 tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) 763 } 764 if !conn.match(expected, got) { 765 return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) 766 } 767 return got, nil 768 } 769 770 // Send a packet with reasonable defaults. Potentially override the TCP layer in 771 // the connection with the provided layer and add additionLayers. 772 func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { 773 t.Helper() 774 775 conn.send(t, Layers{&tcp}, additionalLayers...) 776 } 777 778 // Expect expects a frame with the TCP layer matching the provided TCP within 779 // the timeout specified. If it doesn't arrive in time, an error is returned. 780 func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { 781 t.Helper() 782 783 layer, err := conn.Connection.Expect(t, &tcp, timeout) 784 if layer == nil { 785 return nil, err 786 } 787 gotTCP, ok := layer.(*TCP) 788 if !ok { 789 t.Fatalf("expected %s to be TCP", layer) 790 } 791 return gotTCP, err 792 } 793 794 func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { 795 t.Helper() 796 797 state, ok := conn.layerStates[2].(*tcpState) 798 if !ok { 799 t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) 800 } 801 return state 802 } 803 804 func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { 805 t.Helper() 806 807 state, ok := conn.layerStates[1].(*ipv4State) 808 if !ok { 809 t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) 810 } 811 return state 812 } 813 814 // RemoteSeqNum returns the next expected sequence number from the DUT. 815 func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { 816 t.Helper() 817 818 return conn.tcpState(t).remoteSeqNum 819 } 820 821 // LocalSeqNum returns the next sequence number to send from the testbench. 822 func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { 823 t.Helper() 824 825 return conn.tcpState(t).localSeqNum 826 } 827 828 // SynAck returns the SynAck that was part of the handshake. 829 func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { 830 t.Helper() 831 832 return conn.tcpState(t).synAck 833 } 834 835 // LocalAddr gets the local socket address of this connection. 836 func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { 837 t.Helper() 838 839 sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} 840 copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) 841 return sa 842 } 843 844 // GenerateOTWSeqSegment generates a segment with 845 // seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only 846 // acceptable when seqNumOffset is 0, otherwise an ACK is expected from the 847 // receiver. 848 func GenerateOTWSeqSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP { 849 t.Helper() 850 lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) 851 otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) 852 return TCP{SeqNum: Uint32(otwSeq), Flags: TCPFlags(header.TCPFlagAck)} 853 } 854 855 // GenerateUnaccACKSegment generates a segment with 856 // acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable 857 // when seqNumOffset is 0, otherwise an ACK is expected from the receiver. 858 func GenerateUnaccACKSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP { 859 t.Helper() 860 lastAcceptable := conn.RemoteSeqNum(t) 861 unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) 862 return TCP{AckNum: Uint32(unaccAck), Flags: TCPFlags(header.TCPFlagAck)} 863 } 864 865 // IPv4Conn maintains the state for all the layers in a IPv4 connection. 866 type IPv4Conn struct { 867 Connection 868 } 869 870 // NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults. 871 func (n *DUTTestNet) NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn { 872 t.Helper() 873 874 etherState, err := n.newEtherState(Ether{}, Ether{}) 875 if err != nil { 876 t.Fatalf("can't make EtherState: %s", err) 877 } 878 ipv4State, err := n.newIPv4State(outgoingIPv4, incomingIPv4) 879 if err != nil { 880 t.Fatalf("can't make IPv4State: %s", err) 881 } 882 883 injector, err := n.NewInjector(t) 884 if err != nil { 885 t.Fatalf("can't make injector: %s", err) 886 } 887 sniffer, err := n.NewSniffer(t) 888 if err != nil { 889 t.Fatalf("can't make sniffer: %s", err) 890 } 891 892 return IPv4Conn{ 893 Connection: Connection{ 894 layerStates: []layerState{etherState, ipv4State}, 895 injector: injector, 896 sniffer: sniffer, 897 }, 898 } 899 } 900 901 // Send sends a frame with ipv4 overriding the IPv4 layer defaults and 902 // additionalLayers added after it. 903 func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) { 904 t.Helper() 905 906 c.send(t, Layers{&ipv4}, additionalLayers...) 907 } 908 909 // IPv6Conn maintains the state for all the layers in a IPv6 connection. 910 type IPv6Conn struct { 911 Connection 912 } 913 914 // NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. 915 func (n *DUTTestNet) NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { 916 t.Helper() 917 918 etherState, err := n.newEtherState(Ether{}, Ether{}) 919 if err != nil { 920 t.Fatalf("can't make EtherState: %s", err) 921 } 922 ipv6State, err := n.newIPv6State(outgoingIPv6, incomingIPv6) 923 if err != nil { 924 t.Fatalf("can't make IPv6State: %s", err) 925 } 926 927 injector, err := n.NewInjector(t) 928 if err != nil { 929 t.Fatalf("can't make injector: %s", err) 930 } 931 sniffer, err := n.NewSniffer(t) 932 if err != nil { 933 t.Fatalf("can't make sniffer: %s", err) 934 } 935 936 return IPv6Conn{ 937 Connection: Connection{ 938 layerStates: []layerState{etherState, ipv6State}, 939 injector: injector, 940 sniffer: sniffer, 941 }, 942 } 943 } 944 945 // Send sends a frame with ipv6 overriding the IPv6 layer defaults and 946 // additionalLayers added after it. 947 func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { 948 t.Helper() 949 950 conn.send(t, Layers{&ipv6}, additionalLayers...) 951 } 952 953 // UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. 954 type UDPIPv4 struct { 955 Connection 956 } 957 958 // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. 959 func (n *DUTTestNet) NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { 960 t.Helper() 961 962 etherState, err := n.newEtherState(Ether{}, Ether{}) 963 if err != nil { 964 t.Fatalf("can't make etherState: %s", err) 965 } 966 ipv4State, err := n.newIPv4State(IPv4{}, IPv4{}) 967 if err != nil { 968 t.Fatalf("can't make ipv4State: %s", err) 969 } 970 udpState, err := n.newUDPState(unix.AF_INET, outgoingUDP, incomingUDP) 971 if err != nil { 972 t.Fatalf("can't make udpState: %s", err) 973 } 974 injector, err := n.NewInjector(t) 975 if err != nil { 976 t.Fatalf("can't make injector: %s", err) 977 } 978 sniffer, err := n.NewSniffer(t) 979 if err != nil { 980 t.Fatalf("can't make sniffer: %s", err) 981 } 982 983 return UDPIPv4{ 984 Connection: Connection{ 985 layerStates: []layerState{etherState, ipv4State, udpState}, 986 injector: injector, 987 sniffer: sniffer, 988 }, 989 } 990 } 991 992 func (conn *UDPIPv4) udpState(t *testing.T) *udpState { 993 t.Helper() 994 995 state, ok := conn.layerStates[2].(*udpState) 996 if !ok { 997 t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) 998 } 999 return state 1000 } 1001 1002 func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { 1003 t.Helper() 1004 1005 state, ok := conn.layerStates[1].(*ipv4State) 1006 if !ok { 1007 t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) 1008 } 1009 return state 1010 } 1011 1012 // LocalAddr gets the local socket address of this connection. 1013 func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { 1014 t.Helper() 1015 1016 sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} 1017 copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) 1018 return sa 1019 } 1020 1021 // SrcPort returns the source port of this connection. 1022 func (conn *UDPIPv4) SrcPort(t *testing.T) uint16 { 1023 t.Helper() 1024 1025 return *conn.udpState(t).out.SrcPort 1026 } 1027 1028 // Send sends a packet with reasonable defaults, potentially overriding the UDP 1029 // layer and adding additionLayers. 1030 func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { 1031 t.Helper() 1032 1033 conn.send(t, Layers{&udp}, additionalLayers...) 1034 } 1035 1036 // SendIP sends a packet with reasonable defaults, potentially overriding the 1037 // UDP and IPv4 headers and adding additionLayers. 1038 func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { 1039 t.Helper() 1040 1041 conn.send(t, Layers{&ip, &udp}, additionalLayers...) 1042 } 1043 1044 // SendFrame sends a frame on the wire and updates the state of all layers. 1045 func (conn *UDPIPv4) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { 1046 t.Helper() 1047 1048 conn.send(t, overrideLayers, additionalLayers...) 1049 } 1050 1051 // Expect expects a frame with the UDP layer matching the provided UDP within 1052 // the timeout specified. If it doesn't arrive in time, an error is returned. 1053 func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { 1054 t.Helper() 1055 1056 layer, err := conn.Connection.Expect(t, &udp, timeout) 1057 if err != nil { 1058 return nil, err 1059 } 1060 gotUDP, ok := layer.(*UDP) 1061 if !ok { 1062 t.Fatalf("expected %s to be UDP", layer) 1063 } 1064 return gotUDP, nil 1065 } 1066 1067 // ExpectData is a convenient method that expects a Layer and the Layer after 1068 // it. If it doesn't arrive in time, it returns nil. 1069 func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { 1070 t.Helper() 1071 1072 expected := make([]Layer, len(conn.layerStates)) 1073 expected[len(expected)-1] = &udp 1074 if payload.length() != 0 { 1075 expected = append(expected, &payload) 1076 } 1077 return conn.ExpectFrame(t, expected, timeout) 1078 } 1079 1080 // UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. 1081 type UDPIPv6 struct { 1082 Connection 1083 } 1084 1085 // NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. 1086 func (n *DUTTestNet) NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { 1087 t.Helper() 1088 1089 etherState, err := n.newEtherState(Ether{}, Ether{}) 1090 if err != nil { 1091 t.Fatalf("can't make etherState: %s", err) 1092 } 1093 ipv6State, err := n.newIPv6State(IPv6{}, IPv6{}) 1094 if err != nil { 1095 t.Fatalf("can't make IPv6State: %s", err) 1096 } 1097 udpState, err := n.newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP) 1098 if err != nil { 1099 t.Fatalf("can't make udpState: %s", err) 1100 } 1101 injector, err := n.NewInjector(t) 1102 if err != nil { 1103 t.Fatalf("can't make injector: %s", err) 1104 } 1105 sniffer, err := n.NewSniffer(t) 1106 if err != nil { 1107 t.Fatalf("can't make sniffer: %s", err) 1108 } 1109 return UDPIPv6{ 1110 Connection: Connection{ 1111 layerStates: []layerState{etherState, ipv6State, udpState}, 1112 injector: injector, 1113 sniffer: sniffer, 1114 }, 1115 } 1116 } 1117 1118 func (conn *UDPIPv6) udpState(t *testing.T) *udpState { 1119 t.Helper() 1120 1121 state, ok := conn.layerStates[2].(*udpState) 1122 if !ok { 1123 t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) 1124 } 1125 return state 1126 } 1127 1128 func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { 1129 t.Helper() 1130 1131 state, ok := conn.layerStates[1].(*ipv6State) 1132 if !ok { 1133 t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) 1134 } 1135 return state 1136 } 1137 1138 // LocalAddr gets the local socket address of this connection. 1139 func (conn *UDPIPv6) LocalAddr(t *testing.T, zoneID uint32) *unix.SockaddrInet6 { 1140 t.Helper() 1141 1142 sa := &unix.SockaddrInet6{ 1143 Port: int(*conn.udpState(t).out.SrcPort), 1144 // Local address is in perspective to the remote host, so it's scoped to the 1145 // ID of the remote interface. 1146 ZoneId: zoneID, 1147 } 1148 copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) 1149 return sa 1150 } 1151 1152 // SrcPort returns the source port of this connection. 1153 func (conn *UDPIPv6) SrcPort(t *testing.T) uint16 { 1154 t.Helper() 1155 1156 return *conn.udpState(t).out.SrcPort 1157 } 1158 1159 // Send sends a packet with reasonable defaults, potentially overriding the UDP 1160 // layer and adding additionLayers. 1161 func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { 1162 t.Helper() 1163 1164 conn.send(t, Layers{&udp}, additionalLayers...) 1165 } 1166 1167 // SendIPv6 sends a packet with reasonable defaults, potentially overriding the 1168 // UDP and IPv6 headers and adding additionLayers. 1169 func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { 1170 t.Helper() 1171 1172 conn.send(t, Layers{&ip, &udp}, additionalLayers...) 1173 } 1174 1175 // SendFrame sends a frame on the wire and updates the state of all layers. 1176 func (conn *UDPIPv6) SendFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { 1177 conn.send(t, overrideLayers, additionalLayers...) 1178 } 1179 1180 // Expect expects a frame with the UDP layer matching the provided UDP within 1181 // the timeout specified. If it doesn't arrive in time, an error is returned. 1182 func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { 1183 t.Helper() 1184 1185 layer, err := conn.Connection.Expect(t, &udp, timeout) 1186 if err != nil { 1187 return nil, err 1188 } 1189 gotUDP, ok := layer.(*UDP) 1190 if !ok { 1191 t.Fatalf("expected %s to be UDP", layer) 1192 } 1193 return gotUDP, nil 1194 } 1195 1196 // ExpectData is a convenient method that expects a Layer and the Layer after 1197 // it. If it doesn't arrive in time, it returns nil. 1198 func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { 1199 t.Helper() 1200 1201 expected := make([]Layer, len(conn.layerStates)) 1202 expected[len(expected)-1] = &udp 1203 if payload.length() != 0 { 1204 expected = append(expected, &payload) 1205 } 1206 return conn.ExpectFrame(t, expected, timeout) 1207 } 1208 1209 // TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. 1210 type TCPIPv6 struct { 1211 Connection 1212 } 1213 1214 // NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. 1215 func (n *DUTTestNet) NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { 1216 etherState, err := n.newEtherState(Ether{}, Ether{}) 1217 if err != nil { 1218 t.Fatalf("can't make etherState: %s", err) 1219 } 1220 ipv6State, err := n.newIPv6State(IPv6{}, IPv6{}) 1221 if err != nil { 1222 t.Fatalf("can't make ipv6State: %s", err) 1223 } 1224 tcpState, err := n.newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) 1225 if err != nil { 1226 t.Fatalf("can't make tcpState: %s", err) 1227 } 1228 injector, err := n.NewInjector(t) 1229 if err != nil { 1230 t.Fatalf("can't make injector: %s", err) 1231 } 1232 sniffer, err := n.NewSniffer(t) 1233 if err != nil { 1234 t.Fatalf("can't make sniffer: %s", err) 1235 } 1236 1237 return TCPIPv6{ 1238 Connection: Connection{ 1239 layerStates: []layerState{etherState, ipv6State, tcpState}, 1240 injector: injector, 1241 sniffer: sniffer, 1242 }, 1243 } 1244 } 1245 1246 // SrcPort returns the source port from the given Connection. 1247 func (conn *TCPIPv6) SrcPort() uint16 { 1248 state := conn.layerStates[2].(*tcpState) 1249 return *state.out.SrcPort 1250 } 1251 1252 // ExpectData is a convenient method that expects a Layer and the Layer after 1253 // it. If it doesn't arrive in time, it returns nil. 1254 func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { 1255 t.Helper() 1256 1257 expected := make([]Layer, len(conn.layerStates)) 1258 expected[len(expected)-1] = tcp 1259 if payload != nil { 1260 expected = append(expected, payload) 1261 } 1262 return conn.ExpectFrame(t, expected, timeout) 1263 }