github.com/flowerwrong/netstack@v0.0.0-20191009141956-e5848263af28/tcpip/stack/transport_demuxer_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack_test
    16  
    17  import (
    18  	"math"
    19  	"math/rand"
    20  	"testing"
    21  
    22  	"github.com/FlowerWrong/netstack/tcpip"
    23  	"github.com/FlowerWrong/netstack/tcpip/buffer"
    24  	"github.com/FlowerWrong/netstack/tcpip/header"
    25  	"github.com/FlowerWrong/netstack/tcpip/link/channel"
    26  	"github.com/FlowerWrong/netstack/tcpip/network/ipv4"
    27  	"github.com/FlowerWrong/netstack/tcpip/network/ipv6"
    28  	"github.com/FlowerWrong/netstack/tcpip/stack"
    29  	"github.com/FlowerWrong/netstack/tcpip/transport/udp"
    30  	"github.com/FlowerWrong/netstack/waiter"
    31  )
    32  
    33  const (
    34  	stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
    35  	testV6Addr  = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
    36  
    37  	stackAddr = "\x0a\x00\x00\x01"
    38  	stackPort = 1234
    39  	testPort  = 4096
    40  )
    41  
    42  type testContext struct {
    43  	t       *testing.T
    44  	linkEPs map[string]*channel.Endpoint
    45  	s       *stack.Stack
    46  
    47  	ep tcpip.Endpoint
    48  	wq waiter.Queue
    49  }
    50  
    51  func (c *testContext) cleanup() {
    52  	if c.ep != nil {
    53  		c.ep.Close()
    54  	}
    55  }
    56  
    57  func (c *testContext) createV6Endpoint(v6only bool) {
    58  	var err *tcpip.Error
    59  	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
    60  	if err != nil {
    61  		c.t.Fatalf("NewEndpoint failed: %v", err)
    62  	}
    63  
    64  	var v tcpip.V6OnlyOption
    65  	if v6only {
    66  		v = 1
    67  	}
    68  	if err := c.ep.SetSockOpt(v); err != nil {
    69  		c.t.Fatalf("SetSockOpt failed: %v", err)
    70  	}
    71  }
    72  
    73  // newDualTestContextMultiNic creates the testing context and also linkEpNames
    74  // named NICs.
    75  func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
    76  	s := stack.New(stack.Options{
    77  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
    78  		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
    79  	linkEPs := make(map[string]*channel.Endpoint)
    80  	for i, linkEpName := range linkEpNames {
    81  		channelEP := channel.New(256, mtu, "")
    82  		nicid := tcpip.NICID(i + 1)
    83  		if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
    84  			t.Fatalf("CreateNIC failed: %v", err)
    85  		}
    86  		linkEPs[linkEpName] = channelEP
    87  
    88  		if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
    89  			t.Fatalf("AddAddress IPv4 failed: %v", err)
    90  		}
    91  
    92  		if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
    93  			t.Fatalf("AddAddress IPv6 failed: %v", err)
    94  		}
    95  	}
    96  
    97  	s.SetRouteTable([]tcpip.Route{
    98  		{
    99  			Destination: header.IPv4EmptySubnet,
   100  			NIC:         1,
   101  		},
   102  		{
   103  			Destination: header.IPv6EmptySubnet,
   104  			NIC:         1,
   105  		},
   106  	})
   107  
   108  	return &testContext{
   109  		t:       t,
   110  		s:       s,
   111  		linkEPs: linkEPs,
   112  	}
   113  }
   114  
   115  type headers struct {
   116  	srcPort uint16
   117  	dstPort uint16
   118  }
   119  
   120  func newPayload() []byte {
   121  	b := make([]byte, 30+rand.Intn(100))
   122  	for i := range b {
   123  		b[i] = byte(rand.Intn(256))
   124  	}
   125  	return b
   126  }
   127  
   128  func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
   129  	// Allocate a buffer for data and headers.
   130  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
   131  	copy(buf[len(buf)-len(payload):], payload)
   132  
   133  	// Initialize the IP header.
   134  	ip := header.IPv6(buf)
   135  	ip.Encode(&header.IPv6Fields{
   136  		PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
   137  		NextHeader:    uint8(udp.ProtocolNumber),
   138  		HopLimit:      65,
   139  		SrcAddr:       testV6Addr,
   140  		DstAddr:       stackV6Addr,
   141  	})
   142  
   143  	// Initialize the UDP header.
   144  	u := header.UDP(buf[header.IPv6MinimumSize:])
   145  	u.Encode(&header.UDPFields{
   146  		SrcPort: h.srcPort,
   147  		DstPort: h.dstPort,
   148  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   149  	})
   150  
   151  	// Calculate the UDP pseudo-header checksum.
   152  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
   153  
   154  	// Calculate the UDP checksum and set it.
   155  	xsum = header.Checksum(payload, xsum)
   156  	u.SetChecksum(^u.CalculateChecksum(xsum))
   157  
   158  	// Inject packet.
   159  	c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
   160  }
   161  
   162  func TestTransportDemuxerRegister(t *testing.T) {
   163  	for _, test := range []struct {
   164  		name  string
   165  		proto tcpip.NetworkProtocolNumber
   166  		want  *tcpip.Error
   167  	}{
   168  		{"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
   169  		{"success", ipv4.ProtocolNumber, nil},
   170  	} {
   171  		t.Run(test.name, func(t *testing.T) {
   172  			s := stack.New(stack.Options{
   173  				NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
   174  				TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
   175  			if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
   176  				t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
   177  			}
   178  		})
   179  	}
   180  }
   181  
   182  // TestReuseBindToDevice injects varied packets on input devices and checks that
   183  // the distribution of packets received matches expectations.
   184  func TestDistribution(t *testing.T) {
   185  	type endpointSockopts struct {
   186  		reuse        int
   187  		bindToDevice string
   188  	}
   189  	for _, test := range []struct {
   190  		name string
   191  		// endpoints will received the inject packets.
   192  		endpoints []endpointSockopts
   193  		// wantedDistribution is the wanted ratio of packets received on each
   194  		// endpoint for each NIC on which packets are injected.
   195  		wantedDistributions map[string][]float64
   196  	}{
   197  		{
   198  			"BindPortReuse",
   199  			// 5 endpoints that all have reuse set.
   200  			[]endpointSockopts{
   201  				endpointSockopts{1, ""},
   202  				endpointSockopts{1, ""},
   203  				endpointSockopts{1, ""},
   204  				endpointSockopts{1, ""},
   205  				endpointSockopts{1, ""},
   206  			},
   207  			map[string][]float64{
   208  				// Injected packets on dev0 get distributed evenly.
   209  				"dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
   210  			},
   211  		},
   212  		{
   213  			"BindToDevice",
   214  			// 3 endpoints with various bindings.
   215  			[]endpointSockopts{
   216  				endpointSockopts{0, "dev0"},
   217  				endpointSockopts{0, "dev1"},
   218  				endpointSockopts{0, "dev2"},
   219  			},
   220  			map[string][]float64{
   221  				// Injected packets on dev0 go only to the endpoint bound to dev0.
   222  				"dev0": []float64{1, 0, 0},
   223  				// Injected packets on dev1 go only to the endpoint bound to dev1.
   224  				"dev1": []float64{0, 1, 0},
   225  				// Injected packets on dev2 go only to the endpoint bound to dev2.
   226  				"dev2": []float64{0, 0, 1},
   227  			},
   228  		},
   229  		{
   230  			"ReuseAndBindToDevice",
   231  			// 6 endpoints with various bindings.
   232  			[]endpointSockopts{
   233  				endpointSockopts{1, "dev0"},
   234  				endpointSockopts{1, "dev0"},
   235  				endpointSockopts{1, "dev1"},
   236  				endpointSockopts{1, "dev1"},
   237  				endpointSockopts{1, "dev1"},
   238  				endpointSockopts{1, ""},
   239  			},
   240  			map[string][]float64{
   241  				// Injected packets on dev0 get distributed among endpoints bound to
   242  				// dev0.
   243  				"dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
   244  				// Injected packets on dev1 get distributed among endpoints bound to
   245  				// dev1 or unbound.
   246  				"dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
   247  				// Injected packets on dev999 go only to the unbound.
   248  				"dev999": []float64{0, 0, 0, 0, 0, 1},
   249  			},
   250  		},
   251  	} {
   252  		t.Run(test.name, func(t *testing.T) {
   253  			for device, wantedDistribution := range test.wantedDistributions {
   254  				t.Run(device, func(t *testing.T) {
   255  					var devices []string
   256  					for d := range test.wantedDistributions {
   257  						devices = append(devices, d)
   258  					}
   259  					c := newDualTestContextMultiNic(t, defaultMTU, devices)
   260  					defer c.cleanup()
   261  
   262  					c.createV6Endpoint(false)
   263  
   264  					eps := make(map[tcpip.Endpoint]int)
   265  
   266  					pollChannel := make(chan tcpip.Endpoint)
   267  					for i, endpoint := range test.endpoints {
   268  						// Try to receive the data.
   269  						wq := waiter.Queue{}
   270  						we, ch := waiter.NewChannelEntry(nil)
   271  						wq.EventRegister(&we, waiter.EventIn)
   272  						defer wq.EventUnregister(&we)
   273  						defer close(ch)
   274  
   275  						var err *tcpip.Error
   276  						ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
   277  						if err != nil {
   278  							c.t.Fatalf("NewEndpoint failed: %v", err)
   279  						}
   280  						eps[ep] = i
   281  
   282  						go func(ep tcpip.Endpoint) {
   283  							for range ch {
   284  								pollChannel <- ep
   285  							}
   286  						}(ep)
   287  
   288  						defer ep.Close()
   289  						reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
   290  						if err := ep.SetSockOpt(reusePortOption); err != nil {
   291  							c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
   292  						}
   293  						bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
   294  						if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
   295  							c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
   296  						}
   297  						if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
   298  							t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
   299  						}
   300  					}
   301  
   302  					npackets := 100000
   303  					nports := 10000
   304  					if got, want := len(test.endpoints), len(wantedDistribution); got != want {
   305  						t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
   306  					}
   307  					ports := make(map[uint16]tcpip.Endpoint)
   308  					stats := make(map[tcpip.Endpoint]int)
   309  					for i := 0; i < npackets; i++ {
   310  						// Send a packet.
   311  						port := uint16(i % nports)
   312  						payload := newPayload()
   313  						c.sendV6Packet(payload,
   314  							&headers{
   315  								srcPort: testPort + port,
   316  								dstPort: stackPort},
   317  							device)
   318  
   319  						var addr tcpip.FullAddress
   320  						ep := <-pollChannel
   321  						_, _, err := ep.Read(&addr)
   322  						if err != nil {
   323  							c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
   324  						}
   325  						stats[ep]++
   326  						if i < nports {
   327  							ports[uint16(i)] = ep
   328  						} else {
   329  							// Check that all packets from one client are handled by the same
   330  							// socket.
   331  							if want, got := ports[port], ep; want != got {
   332  								t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
   333  							}
   334  						}
   335  					}
   336  
   337  					// Check that a packet distribution is as expected.
   338  					for ep, i := range eps {
   339  						wantedRatio := wantedDistribution[i]
   340  						wantedRecv := wantedRatio * float64(npackets)
   341  						actualRecv := stats[ep]
   342  						actualRatio := float64(stats[ep]) / float64(npackets)
   343  						// The deviation is less than 10%.
   344  						if math.Abs(actualRatio-wantedRatio) > 0.05 {
   345  							t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
   346  						}
   347  					}
   348  				})
   349  			}
   350  		})
   351  	}
   352  }