gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/transport/testing/context/context.go (about) 1 // Copyright 2022 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package context provides a context used by datagram-based network endpoints 16 // tests. It also defines the TestFlow type to facilitate IP configurations. 17 package context 18 19 import ( 20 "bytes" 21 "reflect" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 "golang.org/x/time/rate" 26 "gvisor.dev/gvisor/pkg/buffer" 27 "gvisor.dev/gvisor/pkg/refs" 28 "gvisor.dev/gvisor/pkg/tcpip" 29 "gvisor.dev/gvisor/pkg/tcpip/checker" 30 "gvisor.dev/gvisor/pkg/tcpip/faketime" 31 "gvisor.dev/gvisor/pkg/tcpip/header" 32 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 33 "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" 34 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 35 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 36 "gvisor.dev/gvisor/pkg/tcpip/stack" 37 "gvisor.dev/gvisor/pkg/tcpip/transport/raw" 38 "gvisor.dev/gvisor/pkg/waiter" 39 ) 40 41 const ( 42 // NICID is the id of the nic created by the Context. 43 NICID = 1 44 45 // DefaultMTU is the MTU used by the Context, except where another value is 46 // explicitly specified during initialization. It is chosen to match the MTU 47 // of loopback interfaces on linux systems. 48 DefaultMTU = 65536 49 ) 50 51 // Context is a testing context for datagram-based network endpoints. 52 type Context struct { 53 // T is the testing context. 54 T *testing.T 55 56 // LinkEP is the link endpoint that is attached to the stack's NIC. 57 LinkEP *channel.Endpoint 58 59 // Stack is the networking stack owned by the context. 60 Stack *stack.Stack 61 62 // EP is the transport endpoint owned by the context. 63 EP tcpip.Endpoint 64 65 // WQ is the wait queue associated with EP and is used to block for events on 66 // EP. 67 WQ waiter.Queue 68 } 69 70 // Options contains options for creating a new test context. 71 type Options struct { 72 // MTU is the mtu that the link endpoint will be initialized with. 73 MTU uint32 74 75 // HandleLocal specifies if non-loopback interfaces are allowed to loop 76 // packets. 77 HandleLocal bool 78 } 79 80 // New allocates and initializes a test context containing a configured stack. 81 func New(t *testing.T, transportProtocols []stack.TransportProtocolFactory) *Context { 82 t.Helper() 83 84 options := Options{ 85 MTU: DefaultMTU, 86 HandleLocal: true, 87 } 88 89 return NewWithOptions(t, transportProtocols, options) 90 } 91 92 // NewWithOptions allocates and initializes a test context containing a 93 // configured stack with the provided options. 94 func NewWithOptions(t *testing.T, transportProtocols []stack.TransportProtocolFactory, options Options) *Context { 95 t.Helper() 96 97 stackOptions := stack.Options{ 98 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 99 TransportProtocols: transportProtocols, 100 HandleLocal: options.HandleLocal, 101 Clock: &faketime.NullClock{}, 102 RawFactory: &raw.EndpointFactory{}, 103 } 104 105 s := stack.New(stackOptions) 106 // Disable ICMP rate limiter since we're using Null clock, which never 107 // advances time and thus never allows ICMP messages. 108 s.SetICMPLimit(rate.Inf) 109 ep := channel.New(256, options.MTU, "") 110 wep := stack.LinkEndpoint(ep) 111 112 if testing.Verbose() { 113 wep = sniffer.New(ep) 114 } 115 if err := s.CreateNIC(NICID, wep); err != nil { 116 t.Fatalf("CreateNIC(%d, _): %s", NICID, err) 117 } 118 119 protocolAddrV4 := tcpip.ProtocolAddress{ 120 Protocol: ipv4.ProtocolNumber, 121 AddressWithPrefix: tcpip.Address(StackAddr).WithPrefix(), 122 } 123 if err := s.AddProtocolAddress(NICID, protocolAddrV4, stack.AddressProperties{}); err != nil { 124 t.Fatalf("AddProtocolAddress(%d, %#v, {}): %s", NICID, protocolAddrV4, err) 125 } 126 127 protocolAddrV6 := tcpip.ProtocolAddress{ 128 Protocol: ipv6.ProtocolNumber, 129 AddressWithPrefix: tcpip.Address(StackV6Addr).WithPrefix(), 130 } 131 if err := s.AddProtocolAddress(NICID, protocolAddrV6, stack.AddressProperties{}); err != nil { 132 t.Fatalf("AddProtocolAddress(%d, %#v, {}): %s", NICID, protocolAddrV6, err) 133 } 134 135 s.SetRouteTable([]tcpip.Route{ 136 { 137 Destination: header.IPv4EmptySubnet, 138 NIC: NICID, 139 }, 140 { 141 Destination: header.IPv6EmptySubnet, 142 NIC: NICID, 143 }, 144 }) 145 146 return &Context{ 147 T: t, 148 Stack: s, 149 LinkEP: ep, 150 } 151 } 152 153 // Cleanup closes the context endpoint if required. 154 func (c *Context) Cleanup() { 155 _ = c.LinkEP.Drain() 156 if c.EP != nil { 157 c.EP.Close() 158 } 159 c.Stack.Destroy() 160 c.Stack = nil 161 refs.DoRepeatedLeakCheck() 162 } 163 164 // CreateEndpoint creates the Context's Endpoint. 165 func (c *Context) CreateEndpoint(network tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber) { 166 c.T.Helper() 167 168 var err tcpip.Error 169 c.EP, err = c.Stack.NewEndpoint(transport, network, &c.WQ) 170 if err != nil { 171 c.T.Fatalf("c.Stack.NewEndpoint(%d, %d, _) failed: %s", transport, network, err) 172 } 173 } 174 175 // CreateEndpointForFlow creates the Context's Endpoint and configured it 176 // according to the given TestFlow. 177 func (c *Context) CreateEndpointForFlow(flow TestFlow, transport tcpip.TransportProtocolNumber) { 178 c.T.Helper() 179 180 c.CreateEndpoint(flow.SockProto(), transport) 181 if flow.isV6Only() { 182 c.EP.SocketOptions().SetV6Only(true) 183 } else if flow.isBroadcast() { 184 c.EP.SocketOptions().SetBroadcast(true) 185 } 186 } 187 188 // CreateRawEndpoint creates the Context's Endpoint. 189 func (c *Context) CreateRawEndpoint(network tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber) { 190 c.T.Helper() 191 192 var err tcpip.Error 193 c.EP, err = c.Stack.NewRawEndpoint(transport, network, &c.WQ, true /* associated */) 194 if err != nil { 195 c.T.Fatal("c.Stack.NewRawEndpoint failed: ", err) 196 } 197 } 198 199 // CreateRawEndpointForFlow creates the Context's Endpoint and configured it 200 // according to the given TestFlow. 201 func (c *Context) CreateRawEndpointForFlow(flow TestFlow, transport tcpip.TransportProtocolNumber) { 202 c.T.Helper() 203 204 c.CreateRawEndpoint(flow.SockProto(), transport) 205 if flow.isV6Only() { 206 c.EP.SocketOptions().SetV6Only(true) 207 } else if flow.isBroadcast() { 208 c.EP.SocketOptions().SetBroadcast(true) 209 } 210 } 211 212 // CheckEndpointWriteStats checks that the write statistic related to the given 213 // error has been incremented as expected. 214 func (c *Context) CheckEndpointWriteStats(incr uint64, want *tcpip.TransportEndpointStats, err tcpip.Error) { 215 var got tcpip.TransportEndpointStats 216 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&got) 217 switch err.(type) { 218 case nil: 219 want.PacketsSent.IncrementBy(incr) 220 case *tcpip.ErrMessageTooLong, *tcpip.ErrInvalidOptionValue: 221 want.WriteErrors.InvalidArgs.IncrementBy(incr) 222 case *tcpip.ErrClosedForSend: 223 want.WriteErrors.WriteClosed.IncrementBy(incr) 224 case *tcpip.ErrInvalidEndpointState: 225 want.WriteErrors.InvalidEndpointState.IncrementBy(incr) 226 case *tcpip.ErrHostUnreachable, *tcpip.ErrBroadcastDisabled, *tcpip.ErrNetworkUnreachable: 227 want.SendErrors.NoRoute.IncrementBy(incr) 228 default: 229 want.SendErrors.SendToNetworkFailed.IncrementBy(incr) 230 } 231 if !reflect.DeepEqual(&got, want) { 232 c.T.Errorf("Endpoint stats not matching for error %s: got %#v, want %#v", err, &got, want) 233 } 234 } 235 236 // CheckEndpointReadStats checks that the read statistic related to the given 237 // error has been incremented as expected. 238 func (c *Context) CheckEndpointReadStats(incr uint64, want *tcpip.TransportEndpointStats, err tcpip.Error) { 239 c.T.Helper() 240 241 var got tcpip.TransportEndpointStats 242 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&got) 243 switch err.(type) { 244 case nil, *tcpip.ErrWouldBlock: 245 case *tcpip.ErrClosedForReceive: 246 want.ReadErrors.ReadClosed.IncrementBy(incr) 247 default: 248 c.T.Errorf("Endpoint error missing stats update for err %s", err) 249 } 250 if !reflect.DeepEqual(&got, want) { 251 c.T.Errorf("Endpoint stats not matching for error %s: got %#v, want %#v", err, &got, want) 252 } 253 } 254 255 // InjectPacket injects a packet into the context's link endpoint. 256 func (c *Context) InjectPacket(netProto tcpip.NetworkProtocolNumber, buf []byte) { 257 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 258 Payload: buffer.MakeWithData(buf), 259 }) 260 defer pkt.DecRef() 261 c.LinkEP.InjectInbound(netProto, pkt) 262 } 263 264 // readExpectations holds information about the expected outcome when reading 265 // from the context's endpoint. 266 type readExpectations struct { 267 nothingToRead bool 268 payload []byte 269 addresses Header4Tuple 270 readShouldFail bool 271 } 272 273 // readFromEndpoint attempts to read a packet from the endpoint and compares the 274 // outcome with the given expectations. 275 func (c *Context) readFromEndpoint(expectations readExpectations, checkers ...checker.ControlMessagesChecker) { 276 c.T.Helper() 277 278 // Try to receive the data. 279 we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) 280 c.WQ.EventRegister(&we) 281 defer c.WQ.EventUnregister(&we) 282 283 // Take a snapshot of the stats to validate them at the end of the test. 284 var epstats tcpip.TransportEndpointStats 285 c.EP.Stats().(*tcpip.TransportEndpointStats).Clone(&epstats) 286 287 var buf bytes.Buffer 288 res, err := c.EP.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) 289 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 290 select { 291 case <-ch: 292 res, err = c.EP.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) 293 default: 294 if expectations.nothingToRead { 295 return 296 } 297 c.T.Fatal("timed out waiting for data") 298 } 299 } 300 301 if expectations.readShouldFail && err != nil { 302 c.CheckEndpointReadStats(1, &epstats, err) 303 return 304 } 305 306 if err != nil { 307 c.T.Fatal("Read failed:", err) 308 } 309 310 if expectations.nothingToRead { 311 c.T.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) 312 } 313 314 // Check the read result. 315 if diff := cmp.Diff(tcpip.ReadResult{ 316 Count: buf.Len(), 317 Total: buf.Len(), 318 RemoteAddr: tcpip.FullAddress{Addr: expectations.addresses.Src.Addr}, 319 }, res, checker.IgnoreCmpPath( 320 "ControlMessages", // ControlMessages are checked below. 321 "RemoteAddr.NIC", 322 "RemoteAddr.Port", 323 )); diff != "" { 324 c.T.Fatalf("Read: unexpected result (-want +got):\n%s", diff) 325 } 326 327 // Check the payload. 328 v := buf.Bytes() 329 if !bytes.Equal(expectations.payload, v) { 330 c.T.Fatalf("got payload = %x, want = %x", v, expectations.payload) 331 } 332 333 // Run any checkers against the ControlMessages. 334 for _, f := range checkers { 335 f(c.T, res.ControlMessages) 336 } 337 338 c.CheckEndpointReadStats(1, &epstats, err) 339 } 340 341 // ReadFromEndpointExpectSuccess attempts to reads from the endpoint and 342 // performs checks on the received packet, according to the given flow and 343 // checkers. 344 func (c *Context) ReadFromEndpointExpectSuccess(payload []byte, flow TestFlow, checkers ...checker.ControlMessagesChecker) { 345 c.T.Helper() 346 347 c.readFromEndpoint(readExpectations{ 348 payload: payload, 349 addresses: flow.MakeHeader4Tuple(Incoming), 350 }, checkers...) 351 } 352 353 // ReadFromEndpointExpectNoPacket reads from the endpoint and checks that no 354 // packets was received. 355 func (c *Context) ReadFromEndpointExpectNoPacket() { 356 c.T.Helper() 357 358 c.readFromEndpoint(readExpectations{ 359 nothingToRead: true, 360 }) 361 } 362 363 // ReadFromEndpointExpectError reads from the endpoint and checks that an 364 // error was returned. 365 func (c *Context) ReadFromEndpointExpectError() { 366 c.T.Helper() 367 368 c.readFromEndpoint(readExpectations{ 369 readShouldFail: true, 370 }) 371 }