gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/stack/transport_demuxer_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack_test
    16  
    17  import (
    18  	"io/ioutil"
    19  	"math"
    20  	"math/rand"
    21  	"strconv"
    22  	"testing"
    23  
    24  	"gvisor.dev/gvisor/pkg/buffer"
    25  	"gvisor.dev/gvisor/pkg/tcpip"
    26  	"gvisor.dev/gvisor/pkg/tcpip/checksum"
    27  	"gvisor.dev/gvisor/pkg/tcpip/header"
    28  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    29  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    30  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    31  	"gvisor.dev/gvisor/pkg/tcpip/ports"
    32  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    33  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    34  	"gvisor.dev/gvisor/pkg/waiter"
    35  )
    36  
    37  const (
    38  	testDstPort = 1234
    39  	testSrcPort = 4096
    40  )
    41  
    42  var (
    43  	testDstAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"))
    44  	testSrcAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"))
    45  
    46  	testSrcAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01"))
    47  	testDstAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02"))
    48  )
    49  
// testContext bundles a stack used by the demuxer tests with the channel link
// endpoints backing its NICs, so tests can inject inbound packets per NIC.
type testContext struct {
	// linkEps maps each NIC ID to the channel endpoint created for it.
	linkEps map[tcpip.NICID]*channel.Endpoint
	// s is the stack under test.
	s *stack.Stack
	// NOTE(review): wq is not referenced by the code in this file; it may be
	// used elsewhere in the package — confirm before removing.
	wq waiter.Queue
}
    55  
    56  // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
    57  func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
    58  	s := stack.New(stack.Options{
    59  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    60  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
    61  	})
    62  	linkEps := make(map[tcpip.NICID]*channel.Endpoint)
    63  	for _, linkEpID := range linkEpIDs {
    64  		channelEp := channel.New(256, mtu, "")
    65  		if err := s.CreateNIC(linkEpID, channelEp); err != nil {
    66  			t.Fatalf("CreateNIC failed: %s", err)
    67  		}
    68  		linkEps[linkEpID] = channelEp
    69  
    70  		protocolAddrV4 := tcpip.ProtocolAddress{
    71  			Protocol:          ipv4.ProtocolNumber,
    72  			AddressWithPrefix: testDstAddrV4.WithPrefix(),
    73  		}
    74  		if err := s.AddProtocolAddress(linkEpID, protocolAddrV4, stack.AddressProperties{}); err != nil {
    75  			t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV4, err)
    76  		}
    77  
    78  		protocolAddrV6 := tcpip.ProtocolAddress{
    79  			Protocol:          ipv6.ProtocolNumber,
    80  			AddressWithPrefix: testDstAddrV6.WithPrefix(),
    81  		}
    82  		if err := s.AddProtocolAddress(linkEpID, protocolAddrV6, stack.AddressProperties{}); err != nil {
    83  			t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", linkEpID, protocolAddrV6, err)
    84  		}
    85  	}
    86  
    87  	s.SetRouteTable([]tcpip.Route{
    88  		{Destination: header.IPv4EmptySubnet, NIC: 1},
    89  		{Destination: header.IPv6EmptySubnet, NIC: 1},
    90  	})
    91  
    92  	return &testContext{
    93  		s:       s,
    94  		linkEps: linkEps,
    95  	}
    96  }
    97  
    98  type headers struct {
    99  	srcPort uint16
   100  	dstPort uint16
   101  }
   102  
   103  func newPayload() []byte {
   104  	b := make([]byte, 30+rand.Intn(100))
   105  	for i := range b {
   106  		b[i] = byte(rand.Intn(256))
   107  	}
   108  	return b
   109  }
   110  
   111  func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
   112  	buf := make([]byte, header.UDPMinimumSize+header.IPv4MinimumSize+len(payload))
   113  	payloadStart := len(buf) - len(payload)
   114  	copy(buf[payloadStart:], payload)
   115  
   116  	// Initialize the IP header.
   117  	ip := header.IPv4(buf)
   118  	ip.Encode(&header.IPv4Fields{
   119  		TOS:         0x80,
   120  		TotalLength: uint16(len(buf)),
   121  		TTL:         65,
   122  		Protocol:    uint8(udp.ProtocolNumber),
   123  		SrcAddr:     testSrcAddrV4,
   124  		DstAddr:     testDstAddrV4,
   125  	})
   126  	ip.SetChecksum(^ip.CalculateChecksum())
   127  
   128  	// Initialize the UDP header.
   129  	u := header.UDP(buf[header.IPv4MinimumSize:])
   130  	u.Encode(&header.UDPFields{
   131  		SrcPort: h.srcPort,
   132  		DstPort: h.dstPort,
   133  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   134  	})
   135  
   136  	// Calculate the UDP pseudo-header checksum.
   137  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
   138  
   139  	// Calculate the UDP checksum and set it.
   140  	xsum = checksum.Checksum(payload, xsum)
   141  	u.SetChecksum(^u.CalculateChecksum(xsum))
   142  
   143  	// Inject packet.
   144  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   145  		Payload: buffer.MakeWithData(buf),
   146  	})
   147  	c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
   148  }
   149  
   150  func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
   151  	// Allocate a buffer for data and headers.
   152  	buf := make([]byte, header.UDPMinimumSize+header.IPv6MinimumSize+len(payload))
   153  	copy(buf[len(buf)-len(payload):], payload)
   154  
   155  	// Initialize the IP header.
   156  	ip := header.IPv6(buf)
   157  	ip.Encode(&header.IPv6Fields{
   158  		PayloadLength:     uint16(header.UDPMinimumSize + len(payload)),
   159  		TransportProtocol: udp.ProtocolNumber,
   160  		HopLimit:          65,
   161  		SrcAddr:           testSrcAddrV6,
   162  		DstAddr:           testDstAddrV6,
   163  	})
   164  
   165  	// Initialize the UDP header.
   166  	u := header.UDP(buf[header.IPv6MinimumSize:])
   167  	u.Encode(&header.UDPFields{
   168  		SrcPort: h.srcPort,
   169  		DstPort: h.dstPort,
   170  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   171  	})
   172  
   173  	// Calculate the UDP pseudo-header checksum.
   174  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
   175  
   176  	// Calculate the UDP checksum and set it.
   177  	xsum = checksum.Checksum(payload, xsum)
   178  	u.SetChecksum(^u.CalculateChecksum(xsum))
   179  
   180  	// Inject packet.
   181  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   182  		Payload: buffer.MakeWithData(buf),
   183  	})
   184  	c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
   185  }
   186  
   187  func TestTransportDemuxerRegister(t *testing.T) {
   188  	for _, test := range []struct {
   189  		name  string
   190  		proto tcpip.NetworkProtocolNumber
   191  		want  tcpip.Error
   192  	}{
   193  		{"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
   194  		{"success", ipv4.ProtocolNumber, nil},
   195  	} {
   196  		t.Run(test.name, func(t *testing.T) {
   197  			s := stack.New(stack.Options{
   198  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   199  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   200  			})
   201  			var wq waiter.Queue
   202  			ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   203  			if err != nil {
   204  				t.Fatal(err)
   205  			}
   206  			tEP, ok := ep.(stack.TransportEndpoint)
   207  			if !ok {
   208  				t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
   209  			}
   210  			if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
   211  				t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
   212  			}
   213  		})
   214  	}
   215  }
   216  
   217  func TestTransportDemuxerRegisterMultiple(t *testing.T) {
   218  	type test struct {
   219  		flags ports.Flags
   220  		want  tcpip.Error
   221  	}
   222  	for _, subtest := range []struct {
   223  		name  string
   224  		tests []test
   225  	}{
   226  		{"zeroFlags", []test{
   227  			{ports.Flags{}, nil},
   228  			{ports.Flags{}, &tcpip.ErrPortInUse{}},
   229  		}},
   230  		{"multibindFlags", []test{
   231  			// Allow multiple registrations same TransportEndpointID with multibind flags.
   232  			{ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
   233  			{ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
   234  			// Disallow registration w/same ID for a non-multibindflag.
   235  			{ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
   236  		}},
   237  	} {
   238  		t.Run(subtest.name, func(t *testing.T) {
   239  			s := stack.New(stack.Options{
   240  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   241  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   242  			})
   243  			var eps []tcpip.Endpoint
   244  			for idx, test := range subtest.tests {
   245  				var wq waiter.Queue
   246  				ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   247  				if err != nil {
   248  					t.Fatal(err)
   249  				}
   250  				eps = append(eps, ep)
   251  				tEP, ok := ep.(stack.TransportEndpoint)
   252  				if !ok {
   253  					t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
   254  				}
   255  				id := stack.TransportEndpointID{LocalPort: 1}
   256  				if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
   257  					t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
   258  				}
   259  			}
   260  			for _, ep := range eps {
   261  				ep.Close()
   262  			}
   263  		})
   264  	}
   265  }
   266  
   267  // TestBindToDeviceDistribution injects varied packets on input devices and checks that
   268  // the distribution of packets received matches expectations.
   269  func TestBindToDeviceDistribution(t *testing.T) {
   270  	type endpointSockopts struct {
   271  		reuse        bool
   272  		bindToDevice tcpip.NICID
   273  	}
   274  	tcs := []struct {
   275  		name string
   276  		// endpoints will received the inject packets.
   277  		endpoints []endpointSockopts
   278  		// wantDistributions is the want ratio of packets received on each
   279  		// endpoint for each NIC on which packets are injected.
   280  		wantDistributions map[tcpip.NICID][]float64
   281  	}{
   282  		{
   283  			name: "BindPortReuse",
   284  			// 5 endpoints that all have reuse set.
   285  			endpoints: []endpointSockopts{
   286  				{reuse: true, bindToDevice: 0},
   287  				{reuse: true, bindToDevice: 0},
   288  				{reuse: true, bindToDevice: 0},
   289  				{reuse: true, bindToDevice: 0},
   290  				{reuse: true, bindToDevice: 0},
   291  			},
   292  			wantDistributions: map[tcpip.NICID][]float64{
   293  				// Injected packets on dev0 get distributed evenly.
   294  				1: {0.2, 0.2, 0.2, 0.2, 0.2},
   295  			},
   296  		},
   297  		{
   298  			name: "BindToDevice",
   299  			// 3 endpoints with various bindings.
   300  			endpoints: []endpointSockopts{
   301  				{reuse: false, bindToDevice: 1},
   302  				{reuse: false, bindToDevice: 2},
   303  				{reuse: false, bindToDevice: 3},
   304  			},
   305  			wantDistributions: map[tcpip.NICID][]float64{
   306  				// Injected packets on dev0 go only to the endpoint bound to dev0.
   307  				1: {1, 0, 0},
   308  				// Injected packets on dev1 go only to the endpoint bound to dev1.
   309  				2: {0, 1, 0},
   310  				// Injected packets on dev2 go only to the endpoint bound to dev2.
   311  				3: {0, 0, 1},
   312  			},
   313  		},
   314  		{
   315  			name: "ReuseAndBindToDevice",
   316  			// 6 endpoints with various bindings.
   317  			endpoints: []endpointSockopts{
   318  				{reuse: true, bindToDevice: 1},
   319  				{reuse: true, bindToDevice: 1},
   320  				{reuse: true, bindToDevice: 2},
   321  				{reuse: true, bindToDevice: 2},
   322  				{reuse: true, bindToDevice: 2},
   323  				{reuse: true, bindToDevice: 0},
   324  			},
   325  			wantDistributions: map[tcpip.NICID][]float64{
   326  				// Injected packets on dev0 get distributed among endpoints bound to
   327  				// dev0.
   328  				1: {0.5, 0.5, 0, 0, 0, 0},
   329  				// Injected packets on dev1 get distributed among endpoints bound to
   330  				// dev1 or unbound.
   331  				2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
   332  				// Injected packets on dev999 go only to the unbound.
   333  				1000: {0, 0, 0, 0, 0, 1},
   334  			},
   335  		},
   336  	}
   337  	protos := map[string]tcpip.NetworkProtocolNumber{
   338  		"IPv4": ipv4.ProtocolNumber,
   339  		"IPv6": ipv6.ProtocolNumber,
   340  	}
   341  
   342  	for _, test := range tcs {
   343  		for protoName, protoNum := range protos {
   344  			for device, wantDistribution := range test.wantDistributions {
   345  				t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) {
   346  					// Create the NICs.
   347  					var devices []tcpip.NICID
   348  					for d := range test.wantDistributions {
   349  						devices = append(devices, d)
   350  					}
   351  					c := newDualTestContextMultiNIC(t, defaultMTU, devices)
   352  
   353  					// Create endpoints and bind each to a NIC, sometimes reusing ports.
   354  					eps := make(map[tcpip.Endpoint]int)
   355  					pollChannel := make(chan tcpip.Endpoint)
   356  					for i, endpoint := range test.endpoints {
   357  						// Try to receive the data.
   358  						wq := waiter.Queue{}
   359  						we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   360  						wq.EventRegister(&we)
   361  						t.Cleanup(func() {
   362  							wq.EventUnregister(&we)
   363  							close(ch)
   364  						})
   365  
   366  						var err tcpip.Error
   367  						ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq)
   368  						if err != nil {
   369  							t.Fatalf("NewEndpoint failed: %s", err)
   370  						}
   371  						t.Cleanup(ep.Close)
   372  						eps[ep] = i
   373  
   374  						go func(ep tcpip.Endpoint) {
   375  							for range ch {
   376  								pollChannel <- ep
   377  							}
   378  						}(ep)
   379  
   380  						ep.SocketOptions().SetReusePort(endpoint.reuse)
   381  						if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
   382  							t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
   383  						}
   384  
   385  						var dstAddr tcpip.Address
   386  						switch protoNum {
   387  						case ipv4.ProtocolNumber:
   388  							dstAddr = testDstAddrV4
   389  						case ipv6.ProtocolNumber:
   390  							dstAddr = testDstAddrV6
   391  						default:
   392  							t.Fatalf("unexpected protocol number: %d", protoNum)
   393  						}
   394  						if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
   395  							t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
   396  						}
   397  					}
   398  
   399  					// Send packets across a range of ports, checking that packets from
   400  					// the same source port are always demultiplexed to the same
   401  					// destination endpoint.
   402  					npackets := 10_000
   403  					nports := 1_000
   404  					if got, want := len(test.endpoints), len(wantDistribution); got != want {
   405  						t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
   406  					}
   407  					endpoints := make(map[uint16]tcpip.Endpoint)
   408  					stats := make(map[tcpip.Endpoint]int)
   409  					for i := 0; i < npackets; i++ {
   410  						// Send a packet.
   411  						port := uint16(i % nports)
   412  						payload := newPayload()
   413  						hdrs := &headers{
   414  							srcPort: testSrcPort + port,
   415  							dstPort: testDstPort,
   416  						}
   417  						switch protoNum {
   418  						case ipv4.ProtocolNumber:
   419  							c.sendV4Packet(payload, hdrs, device)
   420  						case ipv6.ProtocolNumber:
   421  							c.sendV6Packet(payload, hdrs, device)
   422  						default:
   423  							t.Fatalf("unexpected protocol number: %d", protoNum)
   424  						}
   425  
   426  						ep := <-pollChannel
   427  						if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil {
   428  							t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
   429  						}
   430  						stats[ep]++
   431  						if i < nports {
   432  							endpoints[uint16(i)] = ep
   433  						} else {
   434  							// Check that all packets from one client are handled by the same
   435  							// socket.
   436  							if want, got := endpoints[port], ep; want != got {
   437  								t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
   438  							}
   439  						}
   440  					}
   441  
   442  					// Check that a packet distribution is as expected.
   443  					for ep, i := range eps {
   444  						wantRatio := wantDistribution[i]
   445  						wantRecv := wantRatio * float64(npackets)
   446  						actualRecv := stats[ep]
   447  						actualRatio := float64(stats[ep]) / float64(npackets)
   448  						// The deviation is less than 10%.
   449  						if math.Abs(actualRatio-wantRatio) > 0.05 {
   450  							t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
   451  						}
   452  					}
   453  				})
   454  			}
   455  		}
   456  	}
   457  }