github.com/polevpn/netstack@v1.10.9/tcpip/stack/transport_demuxer_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack_test
    16  
    17  import (
    18  	"math"
    19  	"math/rand"
    20  	"testing"
    21  
    22  	"github.com/polevpn/netstack/tcpip"
    23  	"github.com/polevpn/netstack/tcpip/buffer"
    24  	"github.com/polevpn/netstack/tcpip/header"
    25  	"github.com/polevpn/netstack/tcpip/link/channel"
    26  	"github.com/polevpn/netstack/tcpip/network/ipv4"
    27  	"github.com/polevpn/netstack/tcpip/network/ipv6"
    28  	"github.com/polevpn/netstack/tcpip/stack"
    29  	"github.com/polevpn/netstack/tcpip/transport/udp"
    30  	"github.com/polevpn/netstack/waiter"
    31  )
    32  
    33  const (
    34  	stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
    35  	testV6Addr  = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
    36  
    37  	stackAddr = "\x0a\x00\x00\x01"
    38  	stackPort = 1234
    39  	testPort  = 4096
    40  )
    41  
// testContext bundles a stack under test with the channel link endpoints of
// its NICs and an optional UDP endpoint that cleanup closes.
type testContext struct {
	// t is the test that owns this context; setup and send failures are
	// reported through it.
	t *testing.T
	// linkEPs maps each NIC name to the channel endpoint backing that NIC.
	linkEPs map[string]*channel.Endpoint
	// s is the network stack under test.
	s *stack.Stack

	// ep is the endpoint created by createV6Endpoint (nil until then), and wq
	// is the waiter queue it was created with.
	ep tcpip.Endpoint
	wq waiter.Queue
}
    50  
    51  func (c *testContext) cleanup() {
    52  	if c.ep != nil {
    53  		c.ep.Close()
    54  	}
    55  }
    56  
    57  func (c *testContext) createV6Endpoint(v6only bool) {
    58  	var err *tcpip.Error
    59  	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
    60  	if err != nil {
    61  		c.t.Fatalf("NewEndpoint failed: %v", err)
    62  	}
    63  
    64  	var v tcpip.V6OnlyOption
    65  	if v6only {
    66  		v = 1
    67  	}
    68  	if err := c.ep.SetSockOpt(v); err != nil {
    69  		c.t.Fatalf("SetSockOpt failed: %v", err)
    70  	}
    71  }
    72  
    73  // newDualTestContextMultiNic creates the testing context and also linkEpNames
    74  // named NICs.
    75  func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
    76  	s := stack.New(stack.Options{
    77  		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
    78  		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
    79  	linkEPs := make(map[string]*channel.Endpoint)
    80  	for i, linkEpName := range linkEpNames {
    81  		channelEP := channel.New(256, mtu, "")
    82  		nicID := tcpip.NICID(i + 1)
    83  		if err := s.CreateNamedNIC(nicID, linkEpName, channelEP); err != nil {
    84  			t.Fatalf("CreateNIC failed: %v", err)
    85  		}
    86  		linkEPs[linkEpName] = channelEP
    87  
    88  		if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
    89  			t.Fatalf("AddAddress IPv4 failed: %v", err)
    90  		}
    91  
    92  		if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
    93  			t.Fatalf("AddAddress IPv6 failed: %v", err)
    94  		}
    95  	}
    96  
    97  	s.SetRouteTable([]tcpip.Route{
    98  		{
    99  			Destination: header.IPv4EmptySubnet,
   100  			NIC:         1,
   101  		},
   102  		{
   103  			Destination: header.IPv6EmptySubnet,
   104  			NIC:         1,
   105  		},
   106  	})
   107  
   108  	return &testContext{
   109  		t:       t,
   110  		s:       s,
   111  		linkEPs: linkEPs,
   112  	}
   113  }
   114  
   115  type headers struct {
   116  	srcPort uint16
   117  	dstPort uint16
   118  }
   119  
   120  func newPayload() []byte {
   121  	b := make([]byte, 30+rand.Intn(100))
   122  	for i := range b {
   123  		b[i] = byte(rand.Intn(256))
   124  	}
   125  	return b
   126  }
   127  
   128  func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
   129  	// Allocate a buffer for data and headers.
   130  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
   131  	copy(buf[len(buf)-len(payload):], payload)
   132  
   133  	// Initialize the IP header.
   134  	ip := header.IPv6(buf)
   135  	ip.Encode(&header.IPv6Fields{
   136  		PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
   137  		NextHeader:    uint8(udp.ProtocolNumber),
   138  		HopLimit:      65,
   139  		SrcAddr:       testV6Addr,
   140  		DstAddr:       stackV6Addr,
   141  	})
   142  
   143  	// Initialize the UDP header.
   144  	u := header.UDP(buf[header.IPv6MinimumSize:])
   145  	u.Encode(&header.UDPFields{
   146  		SrcPort: h.srcPort,
   147  		DstPort: h.dstPort,
   148  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   149  	})
   150  
   151  	// Calculate the UDP pseudo-header checksum.
   152  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
   153  
   154  	// Calculate the UDP checksum and set it.
   155  	xsum = header.Checksum(payload, xsum)
   156  	u.SetChecksum(^u.CalculateChecksum(xsum))
   157  
   158  	// Inject packet.
   159  	c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
   160  		Data: buf.ToVectorisedView(),
   161  	})
   162  }
   163  
   164  func TestTransportDemuxerRegister(t *testing.T) {
   165  	for _, test := range []struct {
   166  		name  string
   167  		proto tcpip.NetworkProtocolNumber
   168  		want  *tcpip.Error
   169  	}{
   170  		{"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
   171  		{"success", ipv4.ProtocolNumber, nil},
   172  	} {
   173  		t.Run(test.name, func(t *testing.T) {
   174  			s := stack.New(stack.Options{
   175  				NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
   176  				TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
   177  			if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
   178  				t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
   179  			}
   180  		})
   181  	}
   182  }
   183  
   184  // TestReuseBindToDevice injects varied packets on input devices and checks that
   185  // the distribution of packets received matches expectations.
   186  func TestDistribution(t *testing.T) {
   187  	type endpointSockopts struct {
   188  		reuse        int
   189  		bindToDevice string
   190  	}
   191  	for _, test := range []struct {
   192  		name string
   193  		// endpoints will received the inject packets.
   194  		endpoints []endpointSockopts
   195  		// wantedDistribution is the wanted ratio of packets received on each
   196  		// endpoint for each NIC on which packets are injected.
   197  		wantedDistributions map[string][]float64
   198  	}{
   199  		{
   200  			"BindPortReuse",
   201  			// 5 endpoints that all have reuse set.
   202  			[]endpointSockopts{
   203  				endpointSockopts{1, ""},
   204  				endpointSockopts{1, ""},
   205  				endpointSockopts{1, ""},
   206  				endpointSockopts{1, ""},
   207  				endpointSockopts{1, ""},
   208  			},
   209  			map[string][]float64{
   210  				// Injected packets on dev0 get distributed evenly.
   211  				"dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
   212  			},
   213  		},
   214  		{
   215  			"BindToDevice",
   216  			// 3 endpoints with various bindings.
   217  			[]endpointSockopts{
   218  				endpointSockopts{0, "dev0"},
   219  				endpointSockopts{0, "dev1"},
   220  				endpointSockopts{0, "dev2"},
   221  			},
   222  			map[string][]float64{
   223  				// Injected packets on dev0 go only to the endpoint bound to dev0.
   224  				"dev0": []float64{1, 0, 0},
   225  				// Injected packets on dev1 go only to the endpoint bound to dev1.
   226  				"dev1": []float64{0, 1, 0},
   227  				// Injected packets on dev2 go only to the endpoint bound to dev2.
   228  				"dev2": []float64{0, 0, 1},
   229  			},
   230  		},
   231  		{
   232  			"ReuseAndBindToDevice",
   233  			// 6 endpoints with various bindings.
   234  			[]endpointSockopts{
   235  				endpointSockopts{1, "dev0"},
   236  				endpointSockopts{1, "dev0"},
   237  				endpointSockopts{1, "dev1"},
   238  				endpointSockopts{1, "dev1"},
   239  				endpointSockopts{1, "dev1"},
   240  				endpointSockopts{1, ""},
   241  			},
   242  			map[string][]float64{
   243  				// Injected packets on dev0 get distributed among endpoints bound to
   244  				// dev0.
   245  				"dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
   246  				// Injected packets on dev1 get distributed among endpoints bound to
   247  				// dev1 or unbound.
   248  				"dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
   249  				// Injected packets on dev999 go only to the unbound.
   250  				"dev999": []float64{0, 0, 0, 0, 0, 1},
   251  			},
   252  		},
   253  	} {
   254  		t.Run(test.name, func(t *testing.T) {
   255  			for device, wantedDistribution := range test.wantedDistributions {
   256  				t.Run(device, func(t *testing.T) {
   257  					var devices []string
   258  					for d := range test.wantedDistributions {
   259  						devices = append(devices, d)
   260  					}
   261  					c := newDualTestContextMultiNic(t, defaultMTU, devices)
   262  					defer c.cleanup()
   263  
   264  					c.createV6Endpoint(false)
   265  
   266  					eps := make(map[tcpip.Endpoint]int)
   267  
   268  					pollChannel := make(chan tcpip.Endpoint)
   269  					for i, endpoint := range test.endpoints {
   270  						// Try to receive the data.
   271  						wq := waiter.Queue{}
   272  						we, ch := waiter.NewChannelEntry(nil)
   273  						wq.EventRegister(&we, waiter.EventIn)
   274  						defer wq.EventUnregister(&we)
   275  						defer close(ch)
   276  
   277  						var err *tcpip.Error
   278  						ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
   279  						if err != nil {
   280  							c.t.Fatalf("NewEndpoint failed: %v", err)
   281  						}
   282  						eps[ep] = i
   283  
   284  						go func(ep tcpip.Endpoint) {
   285  							for range ch {
   286  								pollChannel <- ep
   287  							}
   288  						}(ep)
   289  
   290  						defer ep.Close()
   291  						reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
   292  						if err := ep.SetSockOpt(reusePortOption); err != nil {
   293  							c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
   294  						}
   295  						bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
   296  						if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
   297  							c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
   298  						}
   299  						if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
   300  							t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
   301  						}
   302  					}
   303  
   304  					npackets := 100000
   305  					nports := 10000
   306  					if got, want := len(test.endpoints), len(wantedDistribution); got != want {
   307  						t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
   308  					}
   309  					ports := make(map[uint16]tcpip.Endpoint)
   310  					stats := make(map[tcpip.Endpoint]int)
   311  					for i := 0; i < npackets; i++ {
   312  						// Send a packet.
   313  						port := uint16(i % nports)
   314  						payload := newPayload()
   315  						c.sendV6Packet(payload,
   316  							&headers{
   317  								srcPort: testPort + port,
   318  								dstPort: stackPort},
   319  							device)
   320  
   321  						var addr tcpip.FullAddress
   322  						ep := <-pollChannel
   323  						_, _, err := ep.Read(&addr)
   324  						if err != nil {
   325  							c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
   326  						}
   327  						stats[ep]++
   328  						if i < nports {
   329  							ports[uint16(i)] = ep
   330  						} else {
   331  							// Check that all packets from one client are handled by the same
   332  							// socket.
   333  							if want, got := ports[port], ep; want != got {
   334  								t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
   335  							}
   336  						}
   337  					}
   338  
   339  					// Check that a packet distribution is as expected.
   340  					for ep, i := range eps {
   341  						wantedRatio := wantedDistribution[i]
   342  						wantedRecv := wantedRatio * float64(npackets)
   343  						actualRecv := stats[ep]
   344  						actualRatio := float64(stats[ep]) / float64(npackets)
   345  						// The deviation is less than 10%.
   346  						if math.Abs(actualRatio-wantedRatio) > 0.05 {
   347  							t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
   348  						}
   349  					}
   350  				})
   351  			}
   352  		})
   353  	}
   354  }