gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/raw/raw_test.go (about) 1 // Copyright 2022 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package raw_test 16 17 import ( 18 "os" 19 "testing" 20 21 "gvisor.dev/gvisor/pkg/refs" 22 "gvisor.dev/gvisor/pkg/tcpip" 23 "gvisor.dev/gvisor/pkg/tcpip/checker" 24 "gvisor.dev/gvisor/pkg/tcpip/header" 25 "gvisor.dev/gvisor/pkg/tcpip/stack" 26 "gvisor.dev/gvisor/pkg/tcpip/transport/testing/context" 27 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 28 ) 29 30 const ( 31 testTOS = 0x80 32 testTTL = 65 33 ) 34 35 func TestReceiveControlMessage(t *testing.T) { 36 var payload = [...]byte{0, 1, 2, 3, 4, 5} 37 38 for _, flow := range []context.TestFlow{context.UnicastV4, context.UnicastV6, context.UnicastV6Only, context.MulticastV4, context.MulticastV6, context.MulticastV6Only, context.Broadcast} { 39 t.Run(flow.String(), func(t *testing.T) { 40 for _, test := range []struct { 41 name string 42 optionProtocol tcpip.NetworkProtocolNumber 43 getReceiveOption func(tcpip.Endpoint) bool 44 setReceiveOption func(tcpip.Endpoint, bool) 45 presenceChecker checker.ControlMessagesChecker 46 absenceChecker checker.ControlMessagesChecker 47 }{ 48 { 49 name: "TOS", 50 optionProtocol: header.IPv4ProtocolNumber, 51 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTOS() }, 52 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTOS(value) }, 53 presenceChecker: checker.ReceiveTOS(testTOS), 54 absenceChecker: checker.NoTOSReceived(), 55 }, 56 { 57 name: "TClass", 58 optionProtocol: header.IPv6ProtocolNumber, 59 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTClass() }, 60 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTClass(value) }, 61 presenceChecker: checker.ReceiveTClass(testTOS), 62 absenceChecker: checker.NoTClassReceived(), 63 }, 64 { 65 name: "TTL", 66 optionProtocol: header.IPv4ProtocolNumber, 67 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveTTL() }, 68 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveTTL(value) }, 69 presenceChecker: checker.ReceiveTTL(testTTL), 70 absenceChecker: checker.NoTTLReceived(), 71 }, 72 { 73 name: "HopLimit", 74 optionProtocol: header.IPv6ProtocolNumber, 75 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceiveHopLimit() }, 76 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceiveHopLimit(value) }, 77 presenceChecker: checker.ReceiveHopLimit(testTTL), 78 absenceChecker: checker.NoHopLimitReceived(), 79 }, 80 { 81 name: "IPPacketInfo", 82 optionProtocol: header.IPv4ProtocolNumber, 83 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetReceivePacketInfo() }, 84 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetReceivePacketInfo(value) }, 85 presenceChecker: func() checker.ControlMessagesChecker { 86 h := flow.MakeHeader4Tuple(context.Incoming) 87 return checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ 88 NIC: context.NICID, 89 // TODO(https://gvisor.dev/issue/3556): Expect the NIC's address 90 // instead of the header destination address for the LocalAddr 91 // field. 92 LocalAddr: h.Dst.Addr, 93 DestinationAddr: h.Dst.Addr, 94 }) 95 }(), 96 absenceChecker: checker.NoIPPacketInfoReceived(), 97 }, 98 { 99 name: "IPv6PacketInfo", 100 optionProtocol: header.IPv6ProtocolNumber, 101 getReceiveOption: func(ep tcpip.Endpoint) bool { return ep.SocketOptions().GetIPv6ReceivePacketInfo() }, 102 setReceiveOption: func(ep tcpip.Endpoint, value bool) { ep.SocketOptions().SetIPv6ReceivePacketInfo(value) }, 103 presenceChecker: func() checker.ControlMessagesChecker { 104 h := flow.MakeHeader4Tuple(context.Incoming) 105 return checker.ReceiveIPv6PacketInfo(tcpip.IPv6PacketInfo{ 106 NIC: context.NICID, 107 Addr: h.Dst.Addr, 108 }) 109 }(), 110 absenceChecker: checker.NoIPv6PacketInfoReceived(), 111 }, 112 } { 113 t.Run(test.name, func(t *testing.T) { 114 c := context.New(t, []stack.TransportProtocolFactory{udp.NewProtocol}) 115 defer c.Cleanup() 116 117 c.CreateRawEndpointForFlow(flow, header.UDPProtocolNumber) 118 if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { 119 c.T.Fatalf("Bind failed: %s", err) 120 } 121 if flow.IsMulticast() { 122 netProto := flow.NetProto() 123 addr := flow.GetMulticastAddr() 124 if err := c.Stack.JoinGroup(netProto, context.NICID, addr); err != nil { 125 c.T.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, context.NICID, addr, err) 126 } 127 } 128 129 buf := context.BuildUDPPacket(payload[:], flow, context.Incoming, testTOS, testTTL, false) 130 expectedReadData := buf 131 if flow.IsV6() { 132 // Raw IPv6 endpoints do not return the network header. 133 expectedReadData = expectedReadData[header.IPv6MinimumSize:] 134 } 135 136 if test.getReceiveOption(c.EP) { 137 t.Fatal("got getReceiveOption() = true, want = false") 138 } 139 140 test.setReceiveOption(c.EP, true) 141 if !test.getReceiveOption(c.EP) { 142 t.Fatal("got getReceiveOption() = false, want = true") 143 } 144 145 c.InjectPacket(flow.NetProto(), buf) 146 if flow.NetProto() == test.optionProtocol { 147 c.ReadFromEndpointExpectSuccess(expectedReadData, flow, test.presenceChecker) 148 } else { 149 c.ReadFromEndpointExpectSuccess(expectedReadData, flow, test.absenceChecker) 150 } 151 }) 152 } 153 }) 154 } 155 } 156 157 func TestMain(m *testing.M) { 158 refs.SetLeakMode(refs.LeaksPanic) 159 code := m.Run() 160 refs.DoLeakCheck() 161 os.Exit(code) 162 }