gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/raw/raw_test.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package raw_test
    16  
    17  import (
    18  	"os"
    19  	"testing"
    20  
    21  	"gvisor.dev/gvisor/pkg/refs"
    22  	"gvisor.dev/gvisor/pkg/tcpip"
    23  	"gvisor.dev/gvisor/pkg/tcpip/checker"
    24  	"gvisor.dev/gvisor/pkg/tcpip/header"
    25  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    26  	"gvisor.dev/gvisor/pkg/tcpip/transport/testing/context"
    27  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    28  )
    29  
    30  const (
    31  	testTOS = 0x80
    32  	testTTL = 65
    33  )
    34  
    35  func TestReceiveControlMessage(t *testing.T) {
    36  	var payload = [...]byte{0, 1, 2, 3, 4, 5}
    37  
    38  	for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6, context.UnicastV6Only, context.MulticastV4, context.MulticastV6, context.MulticastV6Only, context.Broadcast} {
    39  		t.Run(flow.String(), func(t *testing.T) {
    40  			for _, test := range []struct {
    41  				name             string
    42  				optionProtocol   tcpip.NetworkProtocolNumber
    43  				getReceiveOption func(tcpip.Endpoint) bool
    44  				setReceiveOption func(tcpip.Endpoint, bool)
    45  				presenceChecker  checker.ControlMessagesChecker
    46  				absenceChecker   checker.ControlMessagesChecker
    47  			}{
    48  				{
    49  					name:             "TOS",
    50  					optionProtocol:   header.IPv4ProtocolNumber,
    51  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTOS() },
    52  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTOS(value) },
    53  					presenceChecker:  checker.ReceiveTOS(testTOS),
    54  					absenceChecker:   checker.NoTOSReceived(),
    55  				},
    56  				{
    57  					name:             "TClass",
    58  					optionProtocol:   header.IPv6ProtocolNumber,
    59  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTClass() },
    60  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTClass(value) },
    61  					presenceChecker:  checker.ReceiveTClass(testTOS),
    62  					absenceChecker:   checker.NoTClassReceived(),
    63  				},
    64  				{
    65  					name:             "TTL",
    66  					optionProtocol:   header.IPv4ProtocolNumber,
    67  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTTL() },
    68  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTTL(value) },
    69  					presenceChecker:  checker.ReceiveTTL(testTTL),
    70  					absenceChecker:   checker.NoTTLReceived(),
    71  				},
    72  				{
    73  					name:             "HopLimit",
    74  					optionProtocol:   header.IPv6ProtocolNumber,
    75  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveHopLimit() },
    76  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveHopLimit(value) },
    77  					presenceChecker:  checker.ReceiveHopLimit(testTTL),
    78  					absenceChecker:   checker.NoHopLimitReceived(),
    79  				},
    80  				{
    81  					name:             "IPPacketInfo",
    82  					optionProtocol:   header.IPv4ProtocolNumber,
    83  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceivePacketInfo() },
    84  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceivePacketInfo(value) },
    85  					presenceChecker: func() checker.ControlMessagesChecker {
    86  						h := flow.MakeHeader4Tuple(context.Incoming)
    87  						return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
    88  							NIC: context.NICID,
    89  							// TODO(https://gvisor.dev/issue/3556): Expect the NIC's address
    90  							// instead of the header destination address for the LocalAddr
    91  							// field.
    92  							LocalAddr:       h.Dst.Addr,
    93  							DestinationAddr: h.Dst.Addr,
    94  						})
    95  					}(),
    96  					absenceChecker: checker.NoIPPacketInfoReceived(),
    97  				},
    98  				{
    99  					name:             "IPv6PacketInfo",
   100  					optionProtocol:   header.IPv6ProtocolNumber,
   101  					getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetIPv6ReceivePacketInfo() },
   102  					setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetIPv6ReceivePacketInfo(value) },
   103  					presenceChecker: func() checker.ControlMessagesChecker {
   104  						h := flow.MakeHeader4Tuple(context.Incoming)
   105  						return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{
   106  							NIC:  context.NICID,
   107  							Addr: h.Dst.Addr,
   108  						})
   109  					}(),
   110  					absenceChecker: checker.NoIPv6PacketInfoReceived(),
   111  				},
   112  			} {
   113  				t.Run(test.name, func(t *testing.T) {
   114  					c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol})
   115  					defer c.Cleanup()
   116  
   117  					c.CreateRawEndpointForFlow(flow, header.UDPProtocolNumber)
   118  					if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
   119  						c.T.Fatalf("Bind failed: %s", err)
   120  					}
   121  					if flow.IsMulticast() {
   122  						netProto := flow.NetProto()
   123  						addr := flow.GetMulticastAddr()
   124  						if err := c.Stack.JoinGroup(netProto, context.NICID, addr); err != nil {
   125  							c.T.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, context.NICID, addr, err)
   126  						}
   127  					}
   128  
   129  					buf := context.BuildUDPPacket(payload[:], flow, context.Incoming, testTOS, testTTL, false)
   130  					expectedReadData := buf
   131  					if flow.IsV6() {
   132  						// Raw IPv6 endpoints do not return the network header.
   133  						expectedReadData = expectedReadData[header.IPv6MinimumSize:]
   134  					}
   135  
   136  					if test.getReceiveOption(c.EP) {
   137  						t.Fatal("got getReceiveOption() = true, want = false")
   138  					}
   139  
   140  					test.setReceiveOption(c.EP, true)
   141  					if !test.getReceiveOption(c.EP) {
   142  						t.Fatal("got getReceiveOption() = false, want = true")
   143  					}
   144  
   145  					c.InjectPacket(flow.NetProto(), buf)
   146  					if flow.NetProto() == test.optionProtocol {
   147  						c.ReadFromEndpointExpectSuccess(expectedReadData, flow, test.presenceChecker)
   148  					} else {
   149  						c.ReadFromEndpointExpectSuccess(expectedReadData, flow, test.absenceChecker)
   150  					}
   151  				})
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  func TestMain(m *testing.M) {
   158  	refs.SetLeakMode(refs.LeaksPanic)
   159  	code := m.Run()
   160  	refs.DoLeakCheck()
   161  	os.Exit(code)
   162  }