github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/stack/transport_demuxer_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack_test
    16  
    17  import (
    18  	"io/ioutil"
    19  	"math"
    20  	"math/rand"
    21  	"strconv"
    22  	"testing"
    23  
    24  	"github.com/SagerNet/gvisor/pkg/tcpip"
    25  	"github.com/SagerNet/gvisor/pkg/tcpip/buffer"
    26  	"github.com/SagerNet/gvisor/pkg/tcpip/header"
    27  	"github.com/SagerNet/gvisor/pkg/tcpip/link/channel"
    28  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv4"
    29  	"github.com/SagerNet/gvisor/pkg/tcpip/network/ipv6"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip/ports"
    31  	"github.com/SagerNet/gvisor/pkg/tcpip/stack"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip/transport/udp"
    33  	"github.com/SagerNet/gvisor/pkg/waiter"
    34  )
    35  
const (
	// IPv6 source and destination addresses written into injected IPv6 test
	// packets. testDstAddrV6 is also the IPv6 address assigned to every NIC
	// created by newDualTestContextMultiNIC.
	testSrcAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
	testDstAddrV6 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"

	// IPv4 source and destination addresses written into injected IPv4 test
	// packets. testDstAddrV4 is also the IPv4 address assigned to every NIC
	// created by newDualTestContextMultiNIC.
	testSrcAddrV4 = "\x0a\x00\x00\x01"
	testDstAddrV4 = "\x0a\x00\x00\x02"

	// testDstPort is the UDP port the test endpoints bind to. testSrcPort is
	// the base UDP source port; tests add a per-packet offset to it.
	testDstPort = 1234
	testSrcPort = 4096
)
    46  
// testContext bundles a stack with the channel link endpoints of its NICs so
// that tests can inject inbound packets on a chosen NIC.
type testContext struct {
	// linkEps maps each NIC ID to the channel endpoint attached to that NIC.
	linkEps map[tcpip.NICID]*channel.Endpoint
	// s is the stack under test.
	s       *stack.Stack
	// wq is left at its zero value by newDualTestContextMultiNIC.
	wq      waiter.Queue
}
    52  
    53  // newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
    54  func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
    55  	s := stack.New(stack.Options{
    56  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    57  		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
    58  	})
    59  	linkEps := make(map[tcpip.NICID]*channel.Endpoint)
    60  	for _, linkEpID := range linkEpIDs {
    61  		channelEp := channel.New(256, mtu, "")
    62  		if err := s.CreateNIC(linkEpID, channelEp); err != nil {
    63  			t.Fatalf("CreateNIC failed: %s", err)
    64  		}
    65  		linkEps[linkEpID] = channelEp
    66  
    67  		if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, testDstAddrV4); err != nil {
    68  			t.Fatalf("AddAddress IPv4 failed: %s", err)
    69  		}
    70  
    71  		if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, testDstAddrV6); err != nil {
    72  			t.Fatalf("AddAddress IPv6 failed: %s", err)
    73  		}
    74  	}
    75  
    76  	s.SetRouteTable([]tcpip.Route{
    77  		{Destination: header.IPv4EmptySubnet, NIC: 1},
    78  		{Destination: header.IPv6EmptySubnet, NIC: 1},
    79  	})
    80  
    81  	return &testContext{
    82  		s:       s,
    83  		linkEps: linkEps,
    84  	}
    85  }
    86  
// headers carries the UDP source and destination ports that sendV4Packet and
// sendV6Packet encode into the UDP header of an injected packet.
type headers struct {
	srcPort uint16
	dstPort uint16
}
    91  
    92  func newPayload() []byte {
    93  	b := make([]byte, 30+rand.Intn(100))
    94  	for i := range b {
    95  		b[i] = byte(rand.Intn(256))
    96  	}
    97  	return b
    98  }
    99  
   100  func (c *testContext) sendV4Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
   101  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
   102  	payloadStart := len(buf) - len(payload)
   103  	copy(buf[payloadStart:], payload)
   104  
   105  	// Initialize the IP header.
   106  	ip := header.IPv4(buf)
   107  	ip.Encode(&header.IPv4Fields{
   108  		TOS:         0x80,
   109  		TotalLength: uint16(len(buf)),
   110  		TTL:         65,
   111  		Protocol:    uint8(udp.ProtocolNumber),
   112  		SrcAddr:     testSrcAddrV4,
   113  		DstAddr:     testDstAddrV4,
   114  	})
   115  	ip.SetChecksum(^ip.CalculateChecksum())
   116  
   117  	// Initialize the UDP header.
   118  	u := header.UDP(buf[header.IPv4MinimumSize:])
   119  	u.Encode(&header.UDPFields{
   120  		SrcPort: h.srcPort,
   121  		DstPort: h.dstPort,
   122  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   123  	})
   124  
   125  	// Calculate the UDP pseudo-header checksum.
   126  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV4, testDstAddrV4, uint16(len(u)))
   127  
   128  	// Calculate the UDP checksum and set it.
   129  	xsum = header.Checksum(payload, xsum)
   130  	u.SetChecksum(^u.CalculateChecksum(xsum))
   131  
   132  	// Inject packet.
   133  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   134  		Data: buf.ToVectorisedView(),
   135  	})
   136  	c.linkEps[linkEpID].InjectInbound(ipv4.ProtocolNumber, pkt)
   137  }
   138  
   139  func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
   140  	// Allocate a buffer for data and headers.
   141  	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
   142  	copy(buf[len(buf)-len(payload):], payload)
   143  
   144  	// Initialize the IP header.
   145  	ip := header.IPv6(buf)
   146  	ip.Encode(&header.IPv6Fields{
   147  		PayloadLength:     uint16(header.UDPMinimumSize + len(payload)),
   148  		TransportProtocol: udp.ProtocolNumber,
   149  		HopLimit:          65,
   150  		SrcAddr:           testSrcAddrV6,
   151  		DstAddr:           testDstAddrV6,
   152  	})
   153  
   154  	// Initialize the UDP header.
   155  	u := header.UDP(buf[header.IPv6MinimumSize:])
   156  	u.Encode(&header.UDPFields{
   157  		SrcPort: h.srcPort,
   158  		DstPort: h.dstPort,
   159  		Length:  uint16(header.UDPMinimumSize + len(payload)),
   160  	})
   161  
   162  	// Calculate the UDP pseudo-header checksum.
   163  	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testSrcAddrV6, testDstAddrV6, uint16(len(u)))
   164  
   165  	// Calculate the UDP checksum and set it.
   166  	xsum = header.Checksum(payload, xsum)
   167  	u.SetChecksum(^u.CalculateChecksum(xsum))
   168  
   169  	// Inject packet.
   170  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
   171  		Data: buf.ToVectorisedView(),
   172  	})
   173  	c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, pkt)
   174  }
   175  
   176  func TestTransportDemuxerRegister(t *testing.T) {
   177  	for _, test := range []struct {
   178  		name  string
   179  		proto tcpip.NetworkProtocolNumber
   180  		want  tcpip.Error
   181  	}{
   182  		{"failure", ipv6.ProtocolNumber, &tcpip.ErrUnknownProtocol{}},
   183  		{"success", ipv4.ProtocolNumber, nil},
   184  	} {
   185  		t.Run(test.name, func(t *testing.T) {
   186  			s := stack.New(stack.Options{
   187  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   188  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   189  			})
   190  			var wq waiter.Queue
   191  			ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   192  			if err != nil {
   193  				t.Fatal(err)
   194  			}
   195  			tEP, ok := ep.(stack.TransportEndpoint)
   196  			if !ok {
   197  				t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
   198  			}
   199  			if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
   200  				t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
   201  			}
   202  		})
   203  	}
   204  }
   205  
   206  func TestTransportDemuxerRegisterMultiple(t *testing.T) {
   207  	type test struct {
   208  		flags ports.Flags
   209  		want  tcpip.Error
   210  	}
   211  	for _, subtest := range []struct {
   212  		name  string
   213  		tests []test
   214  	}{
   215  		{"zeroFlags", []test{
   216  			{ports.Flags{}, nil},
   217  			{ports.Flags{}, &tcpip.ErrPortInUse{}},
   218  		}},
   219  		{"multibindFlags", []test{
   220  			// Allow multiple registrations same TransportEndpointID with multibind flags.
   221  			{ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
   222  			{ports.Flags{LoadBalanced: true, MostRecent: true}, nil},
   223  			// Disallow registration w/same ID for a non-multibindflag.
   224  			{ports.Flags{TupleOnly: true}, &tcpip.ErrPortInUse{}},
   225  		}},
   226  	} {
   227  		t.Run(subtest.name, func(t *testing.T) {
   228  			s := stack.New(stack.Options{
   229  				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
   230  				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
   231  			})
   232  			var eps []tcpip.Endpoint
   233  			for idx, test := range subtest.tests {
   234  				var wq waiter.Queue
   235  				ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
   236  				if err != nil {
   237  					t.Fatal(err)
   238  				}
   239  				eps = append(eps, ep)
   240  				tEP, ok := ep.(stack.TransportEndpoint)
   241  				if !ok {
   242  					t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
   243  				}
   244  				id := stack.TransportEndpointID{LocalPort: 1}
   245  				if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}, udp.ProtocolNumber, id, tEP, test.flags, 0), test.want; got != want {
   246  					t.Fatalf("test index: %d, s.RegisterTransportEndpoint(ipv4.ProtocolNumber, udp.ProtocolNumber, _, _, %+v, 0) = %s, want %s", idx, test.flags, got, want)
   247  				}
   248  			}
   249  			for _, ep := range eps {
   250  				ep.Close()
   251  			}
   252  		})
   253  	}
   254  }
   255  
// TestBindToDeviceDistribution binds several UDP endpoints to the same
// address and port with different reuse-port and bind-to-device settings,
// injects packets from a range of source ports on one NIC, and checks that
// (a) packets from the same source port always reach the same endpoint and
// (b) the share of packets each endpoint receives matches the expected ratio.
// Every test case is run once per network protocol and per injection NIC.
func TestBindToDeviceDistribution(t *testing.T) {
	// endpointSockopts holds the socket options applied to one endpoint
	// before it is bound.
	type endpointSockopts struct {
		reuse        bool
		bindToDevice tcpip.NICID
	}
	tcs := []struct {
		name string
		// endpoints describes the endpoints that will receive the injected
		// packets.
		endpoints []endpointSockopts
		// wantDistributions is the expected ratio of packets received on each
		// endpoint, for each NIC on which packets are injected. The slice is
		// indexed in the same order as endpoints.
		wantDistributions map[tcpip.NICID][]float64
	}{
		{
			name: "BindPortReuse",
			// 5 endpoints that all have reuse set.
			endpoints: []endpointSockopts{
				{reuse: true, bindToDevice: 0},
				{reuse: true, bindToDevice: 0},
				{reuse: true, bindToDevice: 0},
				{reuse: true, bindToDevice: 0},
				{reuse: true, bindToDevice: 0},
			},
			wantDistributions: map[tcpip.NICID][]float64{
				// Injected packets on NIC 1 get distributed evenly.
				1: {0.2, 0.2, 0.2, 0.2, 0.2},
			},
		},
		{
			name: "BindToDevice",
			// 3 endpoints with various bindings.
			endpoints: []endpointSockopts{
				{reuse: false, bindToDevice: 1},
				{reuse: false, bindToDevice: 2},
				{reuse: false, bindToDevice: 3},
			},
			wantDistributions: map[tcpip.NICID][]float64{
				// Injected packets on NIC 1 go only to the endpoint bound to NIC 1.
				1: {1, 0, 0},
				// Injected packets on NIC 2 go only to the endpoint bound to NIC 2.
				2: {0, 1, 0},
				// Injected packets on NIC 3 go only to the endpoint bound to NIC 3.
				3: {0, 0, 1},
			},
		},
		{
			name: "ReuseAndBindToDevice",
			// 6 endpoints with various bindings.
			endpoints: []endpointSockopts{
				{reuse: true, bindToDevice: 1},
				{reuse: true, bindToDevice: 1},
				{reuse: true, bindToDevice: 2},
				{reuse: true, bindToDevice: 2},
				{reuse: true, bindToDevice: 2},
				{reuse: true, bindToDevice: 0},
			},
			wantDistributions: map[tcpip.NICID][]float64{
				// Injected packets on NIC 1 get distributed among the endpoints
				// bound to NIC 1.
				1: {0.5, 0.5, 0, 0, 0, 0},
				// Injected packets on NIC 2 get distributed among the endpoints
				// bound to NIC 2; the unbound endpoint is expected to get none.
				2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
				// Injected packets on NIC 1000, to which no endpoint is bound, go
				// only to the unbound endpoint.
				1000: {0, 0, 0, 0, 0, 1},
			},
		},
	}
	protos := map[string]tcpip.NetworkProtocolNumber{
		"IPv4": ipv4.ProtocolNumber,
		"IPv6": ipv6.ProtocolNumber,
	}

	for _, test := range tcs {
		for protoName, protoNum := range protos {
			for device, wantDistribution := range test.wantDistributions {
				t.Run(test.name+protoName+"-"+strconv.Itoa(int(device)), func(t *testing.T) {
					// Create the NICs: one per key of wantDistributions, so every
					// NIC a packet may be injected on exists.
					var devices []tcpip.NICID
					for d := range test.wantDistributions {
						devices = append(devices, d)
					}
					c := newDualTestContextMultiNIC(t, defaultMTU, devices)

					// Create endpoints and bind each to a NIC, sometimes reusing ports.
					// eps maps each endpoint to its index in test.endpoints.
					eps := make(map[tcpip.Endpoint]int)
					// pollChannel receives an endpoint each time that endpoint's
					// waiter entry is notified.
					pollChannel := make(chan tcpip.Endpoint)
					for i, endpoint := range test.endpoints {
						// Register a channel-backed waiter entry for readable events
						// on the queue handed to the endpoint.
						wq := waiter.Queue{}
						we, ch := waiter.NewChannelEntry(nil)
						wq.EventRegister(&we, waiter.ReadableEvents)
						t.Cleanup(func() {
							wq.EventUnregister(&we)
							// Closing ch ends the forwarding goroutine's range loop.
							close(ch)
						})

						var err tcpip.Error
						ep, err := c.s.NewEndpoint(udp.ProtocolNumber, protoNum, &wq)
						if err != nil {
							t.Fatalf("NewEndpoint failed: %s", err)
						}
						t.Cleanup(ep.Close)
						eps[ep] = i

						// Forward every notification on ch to pollChannel, tagged
						// with the endpoint it belongs to.
						go func(ep tcpip.Endpoint) {
							for range ch {
								pollChannel <- ep
							}
						}(ep)

						// Apply the socket options before binding.
						ep.SocketOptions().SetReusePort(endpoint.reuse)
						if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil {
							t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err)
						}

						// Bind to the destination address matching the protocol under
						// test.
						var dstAddr tcpip.Address
						switch protoNum {
						case ipv4.ProtocolNumber:
							dstAddr = testDstAddrV4
						case ipv6.ProtocolNumber:
							dstAddr = testDstAddrV6
						default:
							t.Fatalf("unexpected protocol number: %d", protoNum)
						}
						if err := ep.Bind(tcpip.FullAddress{Addr: dstAddr, Port: testDstPort}); err != nil {
							t.Fatalf("ep.Bind(...) on endpoint %d failed: %s", i, err)
						}
					}

					// Send packets across a range of ports, checking that packets from
					// the same source port are always demultiplexed to the same
					// destination endpoint.
					npackets := 10_000
					nports := 1_000
					if got, want := len(test.endpoints), len(wantDistribution); got != want {
						t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
					}
					// endpoints records which endpoint first received a packet from
					// each source-port offset; stats counts packets per endpoint.
					endpoints := make(map[uint16]tcpip.Endpoint)
					stats := make(map[tcpip.Endpoint]int)
					for i := 0; i < npackets; i++ {
						// Send a packet.
						port := uint16(i % nports)
						payload := newPayload()
						hdrs := &headers{
							srcPort: testSrcPort + port,
							dstPort: testDstPort,
						}
						switch protoNum {
						case ipv4.ProtocolNumber:
							c.sendV4Packet(payload, hdrs, device)
						case ipv6.ProtocolNumber:
							c.sendV6Packet(payload, hdrs, device)
						default:
							t.Fatalf("unexpected protocol number: %d", protoNum)
						}

						// Wait for the endpoint that was notified and drain the packet
						// from it.
						ep := <-pollChannel
						if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil {
							t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
						}
						stats[ep]++
						if i < nports {
							// First pass over the ports: remember the receiving endpoint.
							endpoints[uint16(i)] = ep
						} else {
							// Check that all packets from one client are handled by the same
							// socket.
							if want, got := endpoints[port], ep; want != got {
								t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
							}
						}
					}

					// Check that a packet distribution is as expected.
					for ep, i := range eps {
						wantRatio := wantDistribution[i]
						wantRecv := wantRatio * float64(npackets)
						actualRecv := stats[ep]
						actualRatio := float64(stats[ep]) / float64(npackets)
						// Tolerate an absolute deviation of up to 0.05 (5 percentage
						// points) between the actual and the expected ratio.
						if math.Abs(actualRatio-wantRatio) > 0.05 {
							t.Errorf("want about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantRatio*100, wantRecv, npackets, i, actualRatio*100, actualRecv, npackets)
						}
					}
				})
			}
		}
	}
}
   446  }