gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/tests/integration/iptables_test.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package iptables_test 16 17 import ( 18 "bytes" 19 "fmt" 20 "math" 21 "testing" 22 23 "github.com/google/go-cmp/cmp" 24 "gvisor.dev/gvisor/pkg/buffer" 25 "gvisor.dev/gvisor/pkg/tcpip" 26 "gvisor.dev/gvisor/pkg/tcpip/checker" 27 "gvisor.dev/gvisor/pkg/tcpip/checksum" 28 "gvisor.dev/gvisor/pkg/tcpip/header" 29 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 30 "gvisor.dev/gvisor/pkg/tcpip/link/loopback" 31 "gvisor.dev/gvisor/pkg/tcpip/network/arp" 32 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 33 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 34 "gvisor.dev/gvisor/pkg/tcpip/prependable" 35 "gvisor.dev/gvisor/pkg/tcpip/stack" 36 "gvisor.dev/gvisor/pkg/tcpip/tests/utils" 37 "gvisor.dev/gvisor/pkg/tcpip/testutil" 38 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 39 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 40 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 41 "gvisor.dev/gvisor/pkg/waiter" 42 ) 43 44 type inputIfNameMatcher struct { 45 name string 46 } 47 48 var _ stack.Matcher = (*inputIfNameMatcher)(nil) 49 50 func (*inputIfNameMatcher) Name() string { 51 return "inputIfNameMatcher" 52 } 53 54 func (im *inputIfNameMatcher) Match(hook stack.Hook, _ *stack.PacketBuffer, inNicName, _ string) (bool, bool) { 55 return (hook == stack.Input && im.name != "" && im.name 
== inNicName), false 56 } 57 58 const ( 59 nicID = 1 60 nicName = "nic1" 61 anotherNicName = "nic2" 62 linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") 63 payloadSize = 20 64 ) 65 66 var ( 67 srcAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x01")) 68 dstAddrV4 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x02")) 69 srcAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")) 70 dstAddrV6 = tcpip.AddrFromSlice([]byte("\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")) 71 ) 72 73 func genStackV6(t *testing.T) (*stack.Stack, *channel.Endpoint) { 74 t.Helper() 75 s := stack.New(stack.Options{ 76 NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, 77 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 78 }) 79 e := channel.New(0, header.IPv6MinimumMTU, linkAddr) 80 nicOpts := stack.NICOptions{Name: nicName} 81 if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { 82 t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) 83 } 84 protocolAddr := tcpip.ProtocolAddress{ 85 Protocol: header.IPv6ProtocolNumber, 86 AddressWithPrefix: dstAddrV6.WithPrefix(), 87 } 88 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 89 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) 90 } 91 return s, e 92 } 93 94 func genStackV4(t *testing.T) (*stack.Stack, *channel.Endpoint) { 95 t.Helper() 96 s := stack.New(stack.Options{ 97 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, 98 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol}, 99 }) 100 e := channel.New(0, header.IPv4MinimumMTU, linkAddr) 101 nicOpts := stack.NICOptions{Name: nicName} 102 if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { 103 t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, nicOpts, err) 104 } 105 protocolAddr := 
tcpip.ProtocolAddress{ 106 Protocol: header.IPv4ProtocolNumber, 107 AddressWithPrefix: dstAddrV4.WithPrefix(), 108 } 109 if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { 110 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err) 111 } 112 return s, e 113 } 114 115 func genPacketV6() *stack.PacketBuffer { 116 pktSize := header.IPv6MinimumSize + payloadSize 117 hdr := prependable.New(pktSize) 118 ip := header.IPv6(hdr.Prepend(pktSize)) 119 ip.Encode(&header.IPv6Fields{ 120 PayloadLength: payloadSize, 121 TransportProtocol: 99, 122 HopLimit: 255, 123 SrcAddr: srcAddrV6, 124 DstAddr: dstAddrV6, 125 }) 126 buf := buffer.MakeWithData(hdr.View()) 127 return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf}) 128 } 129 130 func genPacketV4() *stack.PacketBuffer { 131 pktSize := header.IPv4MinimumSize + payloadSize 132 hdr := prependable.New(pktSize) 133 ip := header.IPv4(hdr.Prepend(pktSize)) 134 ip.Encode(&header.IPv4Fields{ 135 TOS: 0, 136 TotalLength: uint16(pktSize), 137 ID: 1, 138 Flags: 0, 139 FragmentOffset: 16, 140 TTL: 48, 141 Protocol: 99, 142 SrcAddr: srcAddrV4, 143 DstAddr: dstAddrV4, 144 }) 145 ip.SetChecksum(0) 146 ip.SetChecksum(^ip.CalculateChecksum()) 147 buf := buffer.MakeWithData(hdr.View()) 148 return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf}) 149 } 150 151 func TestIPTablesStatsForInput(t *testing.T) { 152 tests := []struct { 153 name string 154 setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint) 155 setupFilter func(*testing.T, *stack.Stack) 156 genPacket func() *stack.PacketBuffer 157 proto tcpip.NetworkProtocolNumber 158 expectReceived int 159 expectInputDropped int 160 }{ 161 { 162 name: "IPv6 Accept", 163 setupStack: genStackV6, 164 setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, 165 genPacket: genPacketV6, 166 proto: header.IPv6ProtocolNumber, 167 expectReceived: 1, 168 expectInputDropped: 0, 169 }, 170 { 171 name: "IPv4 
Accept", 172 setupStack: genStackV4, 173 setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, 174 genPacket: genPacketV4, 175 proto: header.IPv4ProtocolNumber, 176 expectReceived: 1, 177 expectInputDropped: 0, 178 }, 179 { 180 name: "IPv6 Drop (input interface matches)", 181 setupStack: genStackV6, 182 setupFilter: func(t *testing.T, s *stack.Stack) { 183 t.Helper() 184 ipt := s.IPTables() 185 filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) 186 ruleIdx := filter.BuiltinChains[stack.Input] 187 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} 188 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 189 filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} 190 // Make sure the packet is not dropped by the next rule. 191 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 192 ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */) 193 }, 194 genPacket: genPacketV6, 195 proto: header.IPv6ProtocolNumber, 196 expectReceived: 1, 197 expectInputDropped: 1, 198 }, 199 { 200 name: "IPv4 Drop (input interface matches)", 201 setupStack: genStackV4, 202 setupFilter: func(t *testing.T, s *stack.Stack) { 203 t.Helper() 204 ipt := s.IPTables() 205 filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) 206 ruleIdx := filter.BuiltinChains[stack.Input] 207 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: nicName} 208 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 209 filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} 210 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 211 ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */) 212 }, 213 genPacket: genPacketV4, 214 proto: header.IPv4ProtocolNumber, 215 expectReceived: 1, 216 expectInputDropped: 1, 217 }, 218 { 219 name: "IPv6 Accept (input interface does not match)", 220 setupStack: genStackV6, 221 setupFilter: func(t *testing.T, s *stack.Stack) { 222 t.Helper() 223 ipt := s.IPTables() 
224 filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) 225 ruleIdx := filter.BuiltinChains[stack.Input] 226 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} 227 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 228 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 229 ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */) 230 }, 231 genPacket: genPacketV6, 232 proto: header.IPv6ProtocolNumber, 233 expectReceived: 1, 234 expectInputDropped: 0, 235 }, 236 { 237 name: "IPv4 Accept (input interface does not match)", 238 setupStack: genStackV4, 239 setupFilter: func(t *testing.T, s *stack.Stack) { 240 t.Helper() 241 ipt := s.IPTables() 242 filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) 243 ruleIdx := filter.BuiltinChains[stack.Input] 244 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{InputInterface: anotherNicName} 245 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 246 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 247 ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */) 248 }, 249 genPacket: genPacketV4, 250 proto: header.IPv4ProtocolNumber, 251 expectReceived: 1, 252 expectInputDropped: 0, 253 }, 254 { 255 name: "IPv6 Drop (input interface does not match but invert is true)", 256 setupStack: genStackV6, 257 setupFilter: func(t *testing.T, s *stack.Stack) { 258 t.Helper() 259 ipt := s.IPTables() 260 filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) 261 ruleIdx := filter.BuiltinChains[stack.Input] 262 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ 263 InputInterface: anotherNicName, 264 InputInterfaceInvert: true, 265 } 266 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 267 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 268 ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */) 269 }, 270 genPacket: genPacketV6, 271 proto: header.IPv6ProtocolNumber, 272 expectReceived: 1, 273 expectInputDropped: 1, 274 }, 275 { 276 name: "IPv4 Drop (input interface 
does not match but invert is true)", 277 setupStack: genStackV4, 278 setupFilter: func(t *testing.T, s *stack.Stack) { 279 t.Helper() 280 ipt := s.IPTables() 281 filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) 282 ruleIdx := filter.BuiltinChains[stack.Input] 283 filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{ 284 InputInterface: anotherNicName, 285 InputInterfaceInvert: true, 286 } 287 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 288 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 289 ipt.ForceReplaceTable(stack.FilterID, filter, false /* ipv6 */) 290 }, 291 genPacket: genPacketV4, 292 proto: header.IPv4ProtocolNumber, 293 expectReceived: 1, 294 expectInputDropped: 1, 295 }, 296 { 297 name: "IPv6 Accept (input interface does not match using a matcher)", 298 setupStack: genStackV6, 299 setupFilter: func(t *testing.T, s *stack.Stack) { 300 t.Helper() 301 ipt := s.IPTables() 302 filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) 303 ruleIdx := filter.BuiltinChains[stack.Input] 304 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 305 filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} 306 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 307 ipt.ForceReplaceTable(stack.FilterID, filter, true /* ipv6 */) 308 }, 309 genPacket: genPacketV6, 310 proto: header.IPv6ProtocolNumber, 311 expectReceived: 1, 312 expectInputDropped: 0, 313 }, 314 { 315 name: "IPv4 Accept (input interface does not match using a matcher)", 316 setupStack: genStackV4, 317 setupFilter: func(t *testing.T, s *stack.Stack) { 318 t.Helper() 319 ipt := s.IPTables() 320 filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) 321 ruleIdx := filter.BuiltinChains[stack.Input] 322 filter.Rules[ruleIdx].Target = &stack.DropTarget{} 323 filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} 324 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 325 ipt.ForceReplaceTable(stack.FilterID, filter, false /* 
ipv6 */) 326 }, 327 genPacket: genPacketV4, 328 proto: header.IPv4ProtocolNumber, 329 expectReceived: 1, 330 expectInputDropped: 0, 331 }, 332 } 333 334 for _, test := range tests { 335 t.Run(test.name, func(t *testing.T) { 336 s, e := test.setupStack(t) 337 defer s.Destroy() 338 test.setupFilter(t, s) 339 e.InjectInbound(test.proto, test.genPacket()) 340 341 if got := int(s.Stats().IP.PacketsReceived.Value()); got != test.expectReceived { 342 t.Errorf("got PacketReceived = %d, want = %d", got, test.expectReceived) 343 } 344 if got := int(s.Stats().IP.IPTablesInputDropped.Value()); got != test.expectInputDropped { 345 t.Errorf("got IPTablesInputDropped = %d, want = %d", got, test.expectInputDropped) 346 } 347 }) 348 } 349 } 350 351 var _ stack.LinkEndpoint = (*channelEndpoint)(nil) 352 353 type channelEndpoint struct { 354 *channel.Endpoint 355 356 t *testing.T 357 } 358 359 var _ stack.Matcher = (*udpSourcePortMatcher)(nil) 360 361 type udpSourcePortMatcher struct { 362 port uint16 363 } 364 365 func (*udpSourcePortMatcher) Name() string { 366 return "udpSourcePortMatcher" 367 } 368 369 func (m *udpSourcePortMatcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches, hotdrop bool) { 370 udp := header.UDP(pkt.TransportHeader().Slice()) 371 if len(udp) < header.UDPMinimumSize { 372 // Drop immediately as the packet is invalid. 
373 return false, true 374 } 375 376 return udp.SourcePort() == m.port, false 377 } 378 379 func TestIPTableWritePackets(t *testing.T) { 380 const ( 381 nicID = 1 382 383 dropLocalPort = utils.LocalPort - 1 384 acceptPackets = 2 385 dropPackets = 3 386 ) 387 388 udpHdr := func(hdr []byte, srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16) { 389 u := header.UDP(hdr) 390 u.Encode(&header.UDPFields{ 391 SrcPort: srcPort, 392 DstPort: dstPort, 393 Length: header.UDPMinimumSize, 394 }) 395 sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, srcAddr, dstAddr, header.UDPMinimumSize) 396 sum = checksum.Checksum(hdr, sum) 397 u.SetChecksum(^u.CalculateChecksum(sum)) 398 } 399 400 tests := []struct { 401 name string 402 setupFilter func(*testing.T, *stack.Stack) 403 genPacket func(*stack.Route) stack.PacketBufferList 404 proto tcpip.NetworkProtocolNumber 405 remoteAddr tcpip.Address 406 expectSent uint64 407 expectOutputDropped uint64 408 }{ 409 { 410 name: "IPv4 Accept", 411 setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, 412 genPacket: func(r *stack.Route) stack.PacketBufferList { 413 var pkts stack.PacketBufferList 414 415 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 416 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 417 }) 418 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 419 udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) 420 pkts.PushBack(pkt) 421 422 return pkts 423 }, 424 proto: header.IPv4ProtocolNumber, 425 remoteAddr: dstAddrV4, 426 expectSent: 1, 427 expectOutputDropped: 0, 428 }, 429 { 430 name: "IPv4 Drop Other Port", 431 setupFilter: func(t *testing.T, s *stack.Stack) { 432 t.Helper() 433 434 table := stack.Table{ 435 Rules: []stack.Rule{ 436 { 437 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, 438 }, 439 { 440 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, 441 }, 442 { 443 Matchers: 
[]stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, 444 Target: &stack.DropTarget{NetworkProtocol: header.IPv4ProtocolNumber}, 445 }, 446 { 447 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}, 448 }, 449 { 450 Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}, 451 }, 452 }, 453 BuiltinChains: [stack.NumHooks]int{ 454 stack.Prerouting: stack.HookUnset, 455 stack.Input: 0, 456 stack.Forward: 1, 457 stack.Output: 2, 458 stack.Postrouting: stack.HookUnset, 459 }, 460 Underflows: [stack.NumHooks]int{ 461 stack.Prerouting: stack.HookUnset, 462 stack.Input: 0, 463 stack.Forward: 1, 464 stack.Output: 2, 465 stack.Postrouting: stack.HookUnset, 466 }, 467 } 468 469 s.IPTables().ForceReplaceTable(stack.FilterID, table, false /* ipv4 */) 470 }, 471 genPacket: func(r *stack.Route) stack.PacketBufferList { 472 var pkts stack.PacketBufferList 473 474 for i := 0; i < acceptPackets; i++ { 475 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 476 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 477 }) 478 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 479 udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) 480 pkts.PushBack(pkt) 481 } 482 for i := 0; i < dropPackets; i++ { 483 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 484 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 485 }) 486 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 487 udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) 488 pkts.PushBack(pkt) 489 } 490 491 return pkts 492 }, 493 proto: header.IPv4ProtocolNumber, 494 remoteAddr: dstAddrV4, 495 expectSent: acceptPackets, 496 expectOutputDropped: dropPackets, 497 }, 498 { 499 name: "IPv6 Accept", 500 setupFilter: func(*testing.T, *stack.Stack) { /* no filter */ }, 501 genPacket: func(r *stack.Route) stack.PacketBufferList { 502 var pkts stack.PacketBufferList 503 504 pkt := 
stack.NewPacketBuffer(stack.PacketBufferOptions{ 505 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 506 }) 507 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 508 udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), utils.LocalPort, utils.RemotePort) 509 pkts.PushBack(pkt) 510 511 return pkts 512 }, 513 proto: header.IPv6ProtocolNumber, 514 remoteAddr: dstAddrV6, 515 expectSent: 1, 516 expectOutputDropped: 0, 517 }, 518 { 519 name: "IPv6 Drop Other Port", 520 setupFilter: func(t *testing.T, s *stack.Stack) { 521 t.Helper() 522 523 table := stack.Table{ 524 Rules: []stack.Rule{ 525 { 526 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, 527 }, 528 { 529 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, 530 }, 531 { 532 Matchers: []stack.Matcher{&udpSourcePortMatcher{port: dropLocalPort}}, 533 Target: &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}, 534 }, 535 { 536 Target: &stack.AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}, 537 }, 538 { 539 Target: &stack.ErrorTarget{NetworkProtocol: header.IPv6ProtocolNumber}, 540 }, 541 }, 542 BuiltinChains: [stack.NumHooks]int{ 543 stack.Prerouting: stack.HookUnset, 544 stack.Input: 0, 545 stack.Forward: 1, 546 stack.Output: 2, 547 stack.Postrouting: stack.HookUnset, 548 }, 549 Underflows: [stack.NumHooks]int{ 550 stack.Prerouting: stack.HookUnset, 551 stack.Input: 0, 552 stack.Forward: 1, 553 stack.Output: 2, 554 stack.Postrouting: stack.HookUnset, 555 }, 556 } 557 558 s.IPTables().ForceReplaceTable(stack.FilterID, table, true /* ipv6 */) 559 }, 560 genPacket: func(r *stack.Route) stack.PacketBufferList { 561 var pkts stack.PacketBufferList 562 563 for i := 0; i < acceptPackets; i++ { 564 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 565 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 566 }) 567 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 568 udpHdr(hdr, r.LocalAddress(), 
r.RemoteAddress(), utils.LocalPort, utils.RemotePort) 569 pkts.PushBack(pkt) 570 } 571 for i := 0; i < dropPackets; i++ { 572 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 573 ReserveHeaderBytes: int(r.MaxHeaderLength() + header.UDPMinimumSize), 574 }) 575 hdr := pkt.TransportHeader().Push(header.UDPMinimumSize) 576 udpHdr(hdr, r.LocalAddress(), r.RemoteAddress(), dropLocalPort, utils.RemotePort) 577 pkts.PushBack(pkt) 578 } 579 580 return pkts 581 }, 582 proto: header.IPv6ProtocolNumber, 583 remoteAddr: dstAddrV6, 584 expectSent: acceptPackets, 585 expectOutputDropped: dropPackets, 586 }, 587 } 588 589 for _, test := range tests { 590 t.Run(test.name, func(t *testing.T) { 591 s := stack.New(stack.Options{ 592 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 593 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, 594 }) 595 defer s.Destroy() 596 e := channelEndpoint{ 597 Endpoint: channel.New(4, header.IPv6MinimumMTU, linkAddr), 598 t: t, 599 } 600 if err := s.CreateNIC(nicID, &e); err != nil { 601 t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) 602 } 603 protocolAddrV6 := tcpip.ProtocolAddress{ 604 Protocol: header.IPv6ProtocolNumber, 605 AddressWithPrefix: srcAddrV6.WithPrefix(), 606 } 607 if err := s.AddProtocolAddress(nicID, protocolAddrV6, stack.AddressProperties{}); err != nil { 608 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV6, err) 609 } 610 protocolAddrV4 := tcpip.ProtocolAddress{ 611 Protocol: header.IPv4ProtocolNumber, 612 AddressWithPrefix: srcAddrV4.WithPrefix(), 613 } 614 if err := s.AddProtocolAddress(nicID, protocolAddrV4, stack.AddressProperties{}); err != nil { 615 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddrV4, err) 616 } 617 618 s.SetRouteTable([]tcpip.Route{ 619 { 620 Destination: header.IPv4EmptySubnet, 621 NIC: nicID, 622 }, 623 { 624 Destination: header.IPv6EmptySubnet, 625 NIC: nicID, 626 }, 627 }) 628 629 test.setupFilter(t, s) 
630 631 r, err := s.FindRoute(nicID, tcpip.Address{}, test.remoteAddr, test.proto, false) 632 if err != nil { 633 t.Fatalf("FindRoute(%d, '', %s, %d, false): %s", nicID, test.remoteAddr, test.proto, err) 634 } 635 defer r.Release() 636 637 pkts := test.genPacket(r) 638 for _, pkt := range pkts.AsSlice() { 639 if err := r.WritePacket(stack.NetworkHeaderParams{ 640 Protocol: header.UDPProtocolNumber, 641 TTL: 64, 642 }, pkt); err != nil { 643 t.Fatalf("WritePacket(...): %s", err) 644 } 645 pkt.DecRef() 646 } 647 648 if got := s.Stats().IP.PacketsSent.Value(); got != test.expectSent { 649 t.Errorf("got PacketSent = %d, want = %d", got, test.expectSent) 650 } 651 if got := s.Stats().IP.IPTablesOutputDropped.Value(); got != test.expectOutputDropped { 652 t.Errorf("got IPTablesOutputDropped = %d, want = %d", got, test.expectOutputDropped) 653 } 654 }) 655 } 656 } 657 658 const ttl = 64 659 660 var ( 661 ipv4GlobalMulticastAddr = testutil.MustParse4("224.0.1.10") 662 ipv6GlobalMulticastAddr = testutil.MustParse6("ff0e::a") 663 ) 664 665 func rxICMPv4EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { 666 utils.RxICMPv4EchoReply(e, src, dst, ttl) 667 } 668 669 func rxICMPv6EchoReply(e *channel.Endpoint, src, dst tcpip.Address) { 670 utils.RxICMPv6EchoReply(e, src, dst, ttl) 671 } 672 673 func forwardedICMPv4EchoReplyChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) { 674 checker.IPv4(t, v, 675 checker.SrcAddr(src), 676 checker.DstAddr(dst), 677 checker.TTL(ttl-1), 678 checker.ICMPv4( 679 checker.ICMPv4Type(header.ICMPv4EchoReply))) 680 } 681 682 func forwardedICMPv6EchoReplyChecker(t *testing.T, v *buffer.View, src, dst tcpip.Address) { 683 checker.IPv6(t, v, 684 checker.SrcAddr(src), 685 checker.DstAddr(dst), 686 checker.TTL(ttl-1), 687 checker.ICMPv6( 688 checker.ICMPv6Type(header.ICMPv6EchoReply))) 689 } 690 691 func boolToInt(v bool) uint64 { 692 if v { 693 return 1 694 } 695 return 0 696 } 697 698 func setupDropFilter(hook stack.Hook, f 
stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { 699 return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { 700 t.Helper() 701 702 ipv6 := netProto == ipv6.ProtocolNumber 703 704 ipt := s.IPTables() 705 filter := ipt.GetTable(stack.FilterID, ipv6) 706 ruleIdx := filter.BuiltinChains[hook] 707 filter.Rules[ruleIdx].Filter = f 708 filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} 709 // Make sure the packet is not dropped by the next rule. 710 filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} 711 ipt.ForceReplaceTable(stack.FilterID, filter, ipv6) 712 } 713 } 714 715 func TestForwardingHook(t *testing.T) { 716 const ( 717 nicID1 = 1 718 nicID2 = 2 719 720 nic1Name = "nic1" 721 nic2Name = "nic2" 722 723 otherNICName = "otherNIC" 724 ) 725 726 tests := []struct { 727 name string 728 netProto tcpip.NetworkProtocolNumber 729 local bool 730 srcAddr, dstAddr tcpip.Address 731 rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) 732 checker func(*testing.T, *buffer.View) 733 }{ 734 { 735 name: "IPv4 remote", 736 netProto: ipv4.ProtocolNumber, 737 local: false, 738 srcAddr: utils.RemoteIPv4Addr, 739 dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, 740 rx: rxICMPv4EchoReply, 741 checker: func(t *testing.T, v *buffer.View) { 742 forwardedICMPv4EchoReplyChecker(t, v, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) 743 }, 744 }, 745 { 746 name: "IPv4 local", 747 netProto: ipv4.ProtocolNumber, 748 local: true, 749 srcAddr: utils.RemoteIPv4Addr, 750 dstAddr: utils.Ipv4Addr.Address, 751 rx: rxICMPv4EchoReply, 752 }, 753 { 754 name: "IPv6 remote", 755 netProto: ipv6.ProtocolNumber, 756 local: false, 757 srcAddr: utils.RemoteIPv6Addr, 758 dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, 759 rx: rxICMPv6EchoReply, 760 checker: func(t *testing.T, v *buffer.View) { 761 forwardedICMPv6EchoReplyChecker(t, v, utils.RemoteIPv6Addr, 
utils.Ipv6Addr2.AddressWithPrefix.Address) 762 }, 763 }, 764 { 765 name: "IPv6 local", 766 netProto: ipv6.ProtocolNumber, 767 local: true, 768 srcAddr: utils.RemoteIPv6Addr, 769 dstAddr: utils.Ipv6Addr.Address, 770 rx: rxICMPv6EchoReply, 771 }, 772 } 773 774 subTests := []struct { 775 name string 776 setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) 777 expectForward bool 778 }{ 779 { 780 name: "Accept", 781 setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, 782 expectForward: true, 783 }, 784 785 { 786 name: "Drop", 787 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), 788 expectForward: false, 789 }, 790 { 791 name: "Drop with input NIC filtering", 792 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), 793 expectForward: false, 794 }, 795 { 796 name: "Drop with output NIC filtering", 797 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), 798 expectForward: false, 799 }, 800 { 801 name: "Drop with input and output NIC filtering", 802 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), 803 expectForward: false, 804 }, 805 806 { 807 name: "Drop with other input NIC filtering", 808 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), 809 expectForward: true, 810 }, 811 { 812 name: "Drop with other output NIC filtering", 813 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), 814 expectForward: true, 815 }, 816 { 817 name: "Drop with other input and output NIC filtering", 818 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), 819 expectForward: true, 820 }, 821 { 822 name: "Drop with input and other output NIC filtering", 823 setupFilter: setupDropFilter(stack.Forward, 
stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), 824 expectForward: true, 825 }, 826 { 827 name: "Drop with other input and other output NIC filtering", 828 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), 829 expectForward: true, 830 }, 831 832 { 833 name: "Drop with inverted input NIC filtering", 834 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), 835 expectForward: true, 836 }, 837 { 838 name: "Drop with inverted output NIC filtering", 839 setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), 840 expectForward: true, 841 }, 842 } 843 844 for _, test := range tests { 845 t.Run(test.name, func(t *testing.T) { 846 for _, subTest := range subTests { 847 t.Run(subTest.name, func(t *testing.T) { 848 s := stack.New(stack.Options{ 849 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 850 }) 851 defer s.Destroy() 852 853 subTest.setupFilter(t, s, test.netProto) 854 855 e1 := channel.New(1, header.IPv6MinimumMTU, "") 856 if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { 857 t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) 858 } 859 860 e2 := channel.New(1, header.IPv6MinimumMTU, "") 861 if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { 862 t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) 863 } 864 865 protocolAddrV4 := tcpip.ProtocolAddress{ 866 Protocol: ipv4.ProtocolNumber, 867 AddressWithPrefix: utils.Ipv4Addr.Address.WithPrefix(), 868 } 869 if err := s.AddProtocolAddress(nicID2, protocolAddrV4, stack.AddressProperties{}); err != nil { 870 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV4, err) 871 } 872 protocolAddrV6 := tcpip.ProtocolAddress{ 873 
Protocol: ipv6.ProtocolNumber, 874 AddressWithPrefix: utils.Ipv6Addr.Address.WithPrefix(), 875 } 876 if err := s.AddProtocolAddress(nicID2, protocolAddrV6, stack.AddressProperties{}); err != nil { 877 t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID2, protocolAddrV6, err) 878 } 879 880 if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { 881 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) 882 } 883 if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { 884 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) 885 } 886 887 s.SetRouteTable([]tcpip.Route{ 888 { 889 Destination: header.IPv4EmptySubnet, 890 NIC: nicID2, 891 }, 892 { 893 Destination: header.IPv6EmptySubnet, 894 NIC: nicID2, 895 }, 896 }) 897 898 test.rx(e1, test.srcAddr, test.dstAddr) 899 900 expectTransmitPacket := subTest.expectForward && !test.local 901 902 ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) 903 if err != nil { 904 t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) 905 } 906 ep1Stats := ep1.Stats() 907 ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) 908 if !ok { 909 t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) 910 } 911 ip1Stats := ipEP1Stats.IPStats() 912 913 if got := ip1Stats.PacketsReceived.Value(); got != 1 { 914 t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) 915 } 916 if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { 917 t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) 918 } 919 if got, want := ip1Stats.IPTablesForwardDropped.Value(), boolToInt(!subTest.expectForward); got != want { 920 t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, want) 921 } 922 if got := ip1Stats.PacketsSent.Value(); got != 0 { 923 t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = 0", got) 924 } 925 926 ep2, err := 
s.GetNetworkEndpoint(nicID2, test.netProto) 927 if err != nil { 928 t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) 929 } 930 ep2Stats := ep2.Stats() 931 ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) 932 if !ok { 933 t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) 934 } 935 ip2Stats := ipEP2Stats.IPStats() 936 if got := ip2Stats.PacketsReceived.Value(); got != 0 { 937 t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) 938 } 939 if got, want := ip2Stats.ValidPacketsReceived.Value(), boolToInt(subTest.expectForward && test.local); got != want { 940 t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, want) 941 } 942 if got, want := ip2Stats.PacketsSent.Value(), boolToInt(expectTransmitPacket); got != want { 943 t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = %d", got, want) 944 } 945 946 p := e2.Read() 947 if (p != nil) != expectTransmitPacket { 948 t.Fatalf("got e2.Read() = %#v, want = (_ == nil) = %t", p, expectTransmitPacket) 949 } 950 if expectTransmitPacket { 951 payload := stack.PayloadSince(p.NetworkHeader()) 952 defer payload.Release() 953 test.checker(t, payload) 954 p.DecRef() 955 } 956 }) 957 } 958 }) 959 } 960 } 961 962 func TestFilteringEchoPacketsWithLocalForwarding(t *testing.T) { 963 const ( 964 nicID1 = 1 965 nicID2 = 2 966 967 nic1Name = "nic1" 968 nic2Name = "nic2" 969 970 otherNICName = "otherNIC" 971 ) 972 973 tests := []struct { 974 name string 975 netProto tcpip.NetworkProtocolNumber 976 rx func(*channel.Endpoint) 977 checker func(*testing.T, *buffer.View) 978 }{ 979 { 980 name: "IPv4", 981 netProto: ipv4.ProtocolNumber, 982 rx: func(e *channel.Endpoint) { 983 utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) 984 }, 985 checker: func(t *testing.T, v *buffer.View) { 986 checker.IPv4(t, v, 987 checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), 988 
checker.DstAddr(utils.RemoteIPv4Addr), 989 checker.ICMPv4( 990 checker.ICMPv4Type(header.ICMPv4EchoReply))) 991 }, 992 }, 993 { 994 name: "IPv6", 995 netProto: ipv6.ProtocolNumber, 996 rx: func(e *channel.Endpoint) { 997 utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) 998 }, 999 checker: func(t *testing.T, v *buffer.View) { 1000 checker.IPv6(t, v, 1001 checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), 1002 checker.DstAddr(utils.RemoteIPv6Addr), 1003 checker.ICMPv6( 1004 checker.ICMPv6Type(header.ICMPv6EchoReply))) 1005 }, 1006 }, 1007 } 1008 1009 type droppedEcho int 1010 const ( 1011 _ droppedEcho = iota 1012 noneDropped 1013 echoRequestDroppedAtInput 1014 echoRequestDroppedAtForward 1015 echoReplyDropped 1016 ) 1017 1018 subTests := []struct { 1019 name string 1020 setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) 1021 expectResult droppedEcho 1022 }{ 1023 { 1024 name: "Accept", 1025 setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, 1026 expectResult: noneDropped, 1027 }, 1028 1029 { 1030 name: "Input Drop", 1031 setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), 1032 expectResult: echoRequestDroppedAtInput, 1033 }, 1034 { 1035 name: "Input Drop with input NIC filtering on arrival NIC", 1036 setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), 1037 expectResult: echoRequestDroppedAtInput, 1038 }, 1039 { 1040 name: "Input Drop with input NIC filtering on delivered NIC", 1041 setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), 1042 expectResult: noneDropped, 1043 }, 1044 1045 { 1046 name: "Input Drop with input NIC filtering on other NIC", 1047 setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), 1048 expectResult: noneDropped, 1049 }, 1050 1051 { 1052 name: "Forward Drop", 1053 setupFilter: 
setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), 1054 expectResult: echoRequestDroppedAtForward, 1055 }, 1056 1057 { 1058 name: "Output Drop", 1059 setupFilter: setupDropFilter(stack.Output, stack.IPHeaderFilter{}), 1060 expectResult: echoReplyDropped, 1061 }, 1062 } 1063 1064 for _, test := range tests { 1065 t.Run(test.name, func(t *testing.T) { 1066 for _, subTest := range subTests { 1067 t.Run(subTest.name, func(t *testing.T) { 1068 s := stack.New(stack.Options{ 1069 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 1070 }) 1071 defer s.Destroy() 1072 1073 subTest.setupFilter(t, s, test.netProto) 1074 1075 e1 := channel.New(1, header.IPv6MinimumMTU, "") 1076 if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { 1077 t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) 1078 } 1079 if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1, stack.AddressProperties{}); err != nil { 1080 t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv4Addr1, err) 1081 } 1082 if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1, stack.AddressProperties{}); err != nil { 1083 t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID1, utils.Ipv6Addr1, err) 1084 } 1085 1086 e2 := channel.New(1, header.IPv6MinimumMTU, "") 1087 if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { 1088 t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) 1089 } 1090 if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2, stack.AddressProperties{}); err != nil { 1091 t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv4Addr2, err) 1092 } 1093 if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2, stack.AddressProperties{}); err != nil { 1094 t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID2, utils.Ipv6Addr2, err) 1095 } 1096 1097 if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { 1098 
t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) 1099 } 1100 if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { 1101 t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) 1102 } 1103 1104 s.SetRouteTable([]tcpip.Route{ 1105 { 1106 Destination: header.IPv4EmptySubnet, 1107 NIC: nicID1, 1108 }, 1109 { 1110 Destination: header.IPv6EmptySubnet, 1111 NIC: nicID1, 1112 }, 1113 }) 1114 1115 test.rx(e1) 1116 1117 ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) 1118 if err != nil { 1119 t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) 1120 } 1121 ep1Stats := ep1.Stats() 1122 ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) 1123 if !ok { 1124 t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) 1125 } 1126 ip1Stats := ipEP1Stats.IPStats() 1127 1128 if got := ip1Stats.PacketsReceived.Value(); got != 1 { 1129 t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) 1130 } 1131 if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { 1132 t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) 1133 } 1134 1135 expectedIP1StatIPTablesForawrdDropped := uint64(0) 1136 expectedIP1StatIPTablesOutputDropped := uint64(0) 1137 expectedIP1StatPacketsSent := uint64(0) 1138 expectedIP2StatValidPacketsReceived := uint64(1) 1139 expectedIP2StatIPTablesInputDropped := uint64(0) 1140 switch subTest.expectResult { 1141 case noneDropped: 1142 expectedIP1StatPacketsSent = 1 1143 case echoRequestDroppedAtInput: 1144 expectedIP2StatIPTablesInputDropped = 1 1145 case echoRequestDroppedAtForward: 1146 expectedIP1StatIPTablesForawrdDropped = 1 1147 expectedIP2StatValidPacketsReceived = 0 1148 case echoReplyDropped: 1149 expectedIP1StatIPTablesOutputDropped = 1 1150 default: 1151 t.Fatalf("unhandled expectResult = %d", subTest.expectResult) 1152 } 1153 1154 if got := ip1Stats.IPTablesForwardDropped.Value(); 
got != expectedIP1StatIPTablesForawrdDropped { 1155 t.Errorf("got ip1Stats.IPTablesForwardDropped.Value() = %d, want = %d", got, expectedIP1StatIPTablesForawrdDropped) 1156 } 1157 if got := ip1Stats.IPTablesOutputDropped.Value(); got != expectedIP1StatIPTablesOutputDropped { 1158 t.Errorf("got ip1Stats.IPTablesOutputDropped.Value() = %d, want = %d", got, expectedIP1StatIPTablesOutputDropped) 1159 } 1160 if got := ip1Stats.PacketsSent.Value(); got != expectedIP1StatPacketsSent { 1161 t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, expectedIP1StatPacketsSent) 1162 } 1163 1164 ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) 1165 if err != nil { 1166 t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) 1167 } 1168 ep2Stats := ep2.Stats() 1169 ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) 1170 if !ok { 1171 t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) 1172 } 1173 ip2Stats := ipEP2Stats.IPStats() 1174 if got := ip2Stats.PacketsReceived.Value(); got != 0 { 1175 t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) 1176 } 1177 if got := ip2Stats.ValidPacketsReceived.Value(); got != expectedIP2StatValidPacketsReceived { 1178 t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = %d", got, expectedIP2StatValidPacketsReceived) 1179 } 1180 if got := ip2Stats.IPTablesInputDropped.Value(); got != expectedIP2StatIPTablesInputDropped { 1181 t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, expectedIP2StatIPTablesInputDropped) 1182 } 1183 if got := ip2Stats.PacketsSent.Value(); got != 0 { 1184 t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) 1185 } 1186 1187 expectPacket := subTest.expectResult == noneDropped 1188 p := e1.Read() 1189 if (p != nil) != expectPacket { 1190 t.Errorf("got e1.Read() = %#v, want = (_ == nil) = %t", p, expectPacket) 1191 } 1192 if p != nil { 1193 payload := stack.PayloadSince(p.NetworkHeader()) 1194 
defer payload.Release() 1195 test.checker(t, payload) 1196 p.DecRef() 1197 } 1198 if p := e2.Read(); p != nil { 1199 t.Errorf("got e1.Read() = %#v, want = nil)", p) 1200 p.DecRef() 1201 } 1202 }) 1203 } 1204 }) 1205 } 1206 } 1207 1208 func setupNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, hook stack.Hook, filter stack.IPHeaderFilter, target stack.Target) { 1209 t.Helper() 1210 1211 ipv6 := netProto == ipv6.ProtocolNumber 1212 ipt := s.IPTables() 1213 table := ipt.GetTable(stack.NATID, ipv6) 1214 ruleIdx := table.BuiltinChains[hook] 1215 table.Rules[ruleIdx].Filter = filter 1216 table.Rules[ruleIdx].Target = target 1217 // Make sure the packet is not dropped by the next rule. 1218 table.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} 1219 ipt.ForceReplaceTable(stack.NATID, table, ipv6) 1220 } 1221 1222 func setupDNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { 1223 t.Helper() 1224 1225 setupNAT( 1226 t, 1227 s, 1228 netProto, 1229 stack.Prerouting, 1230 stack.IPHeaderFilter{ 1231 Protocol: transProto, 1232 CheckProtocol: true, 1233 InputInterface: utils.RouterNIC2Name, 1234 }, 1235 target) 1236 } 1237 1238 func setupSNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, target stack.Target) { 1239 t.Helper() 1240 1241 setupNAT( 1242 t, 1243 s, 1244 netProto, 1245 stack.Postrouting, 1246 stack.IPHeaderFilter{ 1247 Protocol: transProto, 1248 CheckProtocol: true, 1249 OutputInterface: utils.RouterNIC1Name, 1250 }, 1251 target) 1252 } 1253 1254 func setupTwiceNAT(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, dnatAddr tcpip.Address, dnatTarget, snatTarget stack.Target) { 1255 t.Helper() 1256 1257 ipv6 := netProto == ipv6.ProtocolNumber 1258 ipt := s.IPTables() 1259 1260 table := stack.Table{ 1261 Rules: []stack.Rule{ 1262 // Prerouting 
1263 { 1264 Filter: stack.IPHeaderFilter{ 1265 Protocol: transProto, 1266 CheckProtocol: true, 1267 InputInterface: utils.RouterNIC2Name, 1268 }, 1269 Target: dnatTarget, 1270 }, 1271 { 1272 Target: &stack.AcceptTarget{}, 1273 }, 1274 1275 // Input 1276 { 1277 Target: &stack.AcceptTarget{}, 1278 }, 1279 1280 // Forward 1281 { 1282 Target: &stack.AcceptTarget{}, 1283 }, 1284 1285 // Output 1286 { 1287 Target: &stack.AcceptTarget{}, 1288 }, 1289 1290 // Postrouting 1291 { 1292 Filter: stack.IPHeaderFilter{ 1293 Protocol: transProto, 1294 CheckProtocol: true, 1295 OutputInterface: utils.RouterNIC1Name, 1296 }, 1297 Target: snatTarget, 1298 }, 1299 { 1300 Target: &stack.AcceptTarget{}, 1301 }, 1302 }, 1303 BuiltinChains: [stack.NumHooks]int{ 1304 stack.Prerouting: 0, 1305 stack.Input: 2, 1306 stack.Forward: 3, 1307 stack.Output: 4, 1308 stack.Postrouting: 5, 1309 }, 1310 } 1311 1312 ipt.ForceReplaceTable(stack.NATID, table, ipv6) 1313 } 1314 1315 type natType struct { 1316 name string 1317 setupNAT func(_ *testing.T, _ *stack.Stack, _ tcpip.NetworkProtocolNumber, _ tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16) 1318 } 1319 1320 var ( 1321 snatTypes = []natType{ 1322 { 1323 name: "SNAT", 1324 setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, _ tcpip.Address, _ uint16) { 1325 t.Helper() 1326 1327 setupSNAT(t, s, netProto, transProto, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr, ChangeAddress: true, ChangePort: true}) 1328 }, 1329 }, 1330 { 1331 name: "Masquerade", 1332 setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address, _ uint16) { 1333 t.Helper() 1334 1335 setupSNAT(t, s, netProto, transProto, &stack.MasqueradeTarget{NetworkProtocol: netProto}) 1336 }, 1337 }, 1338 } 1339 1340 dnatTarget = natType{ 1341 name: "DNAT", 1342 setupNAT: func(t 
*testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, dnatAddr tcpip.Address, dnatPort uint16) { 1343 t.Helper() 1344 1345 setupDNAT(t, s, netProto, transProto, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true}) 1346 }, 1347 } 1348 1349 dnatTypes = []natType{ 1350 { 1351 name: "Redirect", 1352 setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, _, _ tcpip.Address, dnatPort uint16) { 1353 t.Helper() 1354 1355 setupDNAT(t, s, netProto, transProto, &stack.RedirectTarget{NetworkProtocol: netProto, Port: dnatPort}) 1356 }, 1357 }, 1358 dnatTarget, 1359 } 1360 1361 twiceNATTypes = []natType{ 1362 { 1363 name: "DNAT-Masquerade", 1364 setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16) { 1365 t.Helper() 1366 1367 setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true}, &stack.MasqueradeTarget{NetworkProtocol: netProto}) 1368 }, 1369 }, 1370 { 1371 name: "DNAT-SNAT", 1372 setupNAT: func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, snatAddr, dnatAddr tcpip.Address, dnatPort uint16) { 1373 t.Helper() 1374 1375 setupTwiceNAT(t, s, netProto, transProto, dnatAddr, &stack.DNATTarget{NetworkProtocol: netProto, Addr: dnatAddr, Port: dnatPort, ChangeAddress: true, ChangePort: true}, &stack.SNATTarget{NetworkProtocol: netProto, Addr: snatAddr, ChangeAddress: true, ChangePort: true}) 1376 }, 1377 }, 1378 } 1379 ) 1380 1381 func TestNATEcho(t *testing.T) { 1382 const ident = 1 1383 1384 v4EchoPkt := func(srcAddr, dstAddr tcpip.Address, reply bool) []byte { 1385 icmpType := header.ICMPv4Echo 1386 if 
reply { 1387 icmpType = header.ICMPv4EchoReply 1388 } 1389 1390 return icmpv4Packet(srcAddr, dstAddr, icmpType, ident) 1391 } 1392 1393 checkV4EchoPkt := func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool) { 1394 t.Helper() 1395 1396 icmpType := header.ICMPv4Echo 1397 if reply { 1398 icmpType = header.ICMPv4EchoReply 1399 } 1400 1401 checker.IPv4(t, v, 1402 checker.SrcAddr(srcAddr), 1403 checker.DstAddr(dstAddr), 1404 checker.ICMPv4( 1405 checker.ICMPv4Type(icmpType), 1406 checker.ICMPv4Checksum(), 1407 ), 1408 ) 1409 } 1410 1411 v6EchoPkt := func(srcAddr, dstAddr tcpip.Address, reply bool) []byte { 1412 icmpType := header.ICMPv6EchoRequest 1413 if reply { 1414 icmpType = header.ICMPv6EchoReply 1415 } 1416 1417 return icmpv6Packet(srcAddr, dstAddr, icmpType, ident) 1418 } 1419 1420 checkV6EchoPkt := func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool) { 1421 t.Helper() 1422 1423 icmpType := header.ICMPv6EchoRequest 1424 if reply { 1425 icmpType = header.ICMPv6EchoReply 1426 } 1427 1428 checker.IPv6(t, v, 1429 checker.SrcAddr(srcAddr), 1430 checker.DstAddr(dstAddr), 1431 checker.ICMPv6( 1432 checker.ICMPv6Type(icmpType), 1433 ), 1434 ) 1435 } 1436 1437 type natTypeTest struct { 1438 name string 1439 natTypes []natType 1440 requestSrc, requestDst tcpip.Address 1441 expectedRequestSrc, expectedRequestDst tcpip.Address 1442 } 1443 1444 tests := []struct { 1445 name string 1446 netProto tcpip.NetworkProtocolNumber 1447 transProto tcpip.TransportProtocolNumber 1448 echoPkt func(srcAddr, dstAddr tcpip.Address, reply bool) []byte 1449 checkEchoPkt func(t *testing.T, v *buffer.View, srcAddr, dstAddr tcpip.Address, reply bool) 1450 1451 natTypes []natTypeTest 1452 }{ 1453 { 1454 name: "IPv4", 1455 netProto: header.IPv4ProtocolNumber, 1456 transProto: header.ICMPv4ProtocolNumber, 1457 echoPkt: v4EchoPkt, 1458 checkEchoPkt: checkV4EchoPkt, 1459 1460 natTypes: []natTypeTest{ 1461 { 1462 name: "SNAT", 1463 natTypes: snatTypes, 
1464 requestSrc: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1465 requestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1466 expectedRequestSrc: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, 1467 expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1468 }, 1469 { 1470 name: "DNAT", 1471 natTypes: []natType{dnatTarget}, 1472 requestSrc: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1473 requestDst: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, 1474 expectedRequestSrc: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1475 expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1476 }, 1477 { 1478 name: "Twice-NAT", 1479 natTypes: twiceNATTypes, 1480 requestSrc: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1481 requestDst: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, 1482 expectedRequestSrc: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, 1483 expectedRequestDst: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1484 }, 1485 }, 1486 }, 1487 { 1488 name: "IPv6", 1489 netProto: header.IPv6ProtocolNumber, 1490 transProto: header.ICMPv6ProtocolNumber, 1491 echoPkt: v6EchoPkt, 1492 checkEchoPkt: checkV6EchoPkt, 1493 1494 natTypes: []natTypeTest{ 1495 { 1496 name: "SNAT", 1497 natTypes: snatTypes, 1498 requestSrc: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1499 requestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1500 expectedRequestSrc: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, 1501 expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1502 }, 1503 { 1504 name: "DNAT", 1505 natTypes: []natType{dnatTarget}, 1506 requestSrc: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1507 requestDst: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, 1508 expectedRequestSrc: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1509 expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1510 }, 1511 { 1512 name: "Twice-NAT", 1513 natTypes: twiceNATTypes, 1514 requestSrc: 
utils.Host2IPv6Addr.AddressWithPrefix.Address, 1515 requestDst: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, 1516 expectedRequestSrc: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, 1517 expectedRequestDst: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1518 }, 1519 }, 1520 }, 1521 } 1522 1523 for _, test := range tests { 1524 t.Run(test.name, func(t *testing.T) { 1525 for _, natTypeTest := range test.natTypes { 1526 t.Run(natTypeTest.name, func(t *testing.T) { 1527 for _, natType := range natTypeTest.natTypes { 1528 t.Run(natType.name, func(t *testing.T) { 1529 s := stack.New(stack.Options{ 1530 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 1531 TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, 1532 }) 1533 defer s.Destroy() 1534 1535 ep1 := channel.New(1, header.IPv6MinimumMTU, "") 1536 ep2 := channel.New(1, header.IPv6MinimumMTU, "") 1537 utils.SetupRouterStack(t, s, ep1, ep2) 1538 1539 natType.setupNAT(t, s, test.netProto, test.transProto, natTypeTest.expectedRequestSrc, natTypeTest.expectedRequestDst, 0 /* dnatPort */) 1540 1541 // Send and check the Echo Request. 1542 { 1543 ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ 1544 Payload: buffer.MakeWithData(test.echoPkt(natTypeTest.requestSrc, natTypeTest.requestDst, false /* reply */)), 1545 })) 1546 pkt := ep1.Read() 1547 if pkt == nil { 1548 t.Fatal("expected to read a packet on ep1") 1549 } 1550 payload := stack.PayloadSince(pkt.NetworkHeader()) 1551 defer payload.Release() 1552 test.checkEchoPkt(t, payload, natTypeTest.expectedRequestSrc, natTypeTest.expectedRequestDst, false /* reply */) 1553 pkt.DecRef() 1554 } 1555 1556 if t.Failed() { 1557 t.FailNow() 1558 } 1559 1560 // Send and check the Echo Reply. 
1561 { 1562 ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ 1563 Payload: buffer.MakeWithData(test.echoPkt(natTypeTest.expectedRequestDst, natTypeTest.expectedRequestSrc, true /* reply */)), 1564 })) 1565 pkt := ep2.Read() 1566 if pkt == nil { 1567 t.Fatal("expected to read a packet on ep2") 1568 } 1569 payload := stack.PayloadSince(pkt.NetworkHeader()) 1570 defer payload.Release() 1571 test.checkEchoPkt(t, payload, natTypeTest.requestDst, natTypeTest.requestSrc, true /* reply */) 1572 pkt.DecRef() 1573 } 1574 }) 1575 } 1576 }) 1577 } 1578 }) 1579 } 1580 } 1581 1582 func TestNAT(t *testing.T) { 1583 const listenPort uint16 = 8080 1584 1585 type endpointAndAddresses struct { 1586 serverEP tcpip.Endpoint 1587 serverAddr tcpip.FullAddress 1588 serverReadableCH chan struct{} 1589 serverConnectAddr tcpip.Address 1590 1591 clientEP tcpip.Endpoint 1592 clientAddr tcpip.Address 1593 clientReadableCH chan struct{} 1594 clientConnectAddr tcpip.FullAddress 1595 } 1596 1597 newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { 1598 t.Helper() 1599 var wq waiter.Queue 1600 we, ch := waiter.NewChannelEntry(waiter.ReadableEvents) 1601 wq.EventRegister(&we) 1602 t.Cleanup(func() { 1603 wq.EventUnregister(&we) 1604 }) 1605 1606 ep, err := s.NewEndpoint(transProto, netProto, &wq) 1607 if err != nil { 1608 t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) 1609 } 1610 t.Cleanup(ep.Close) 1611 1612 return ep, ch 1613 } 1614 1615 tests := []struct { 1616 name string 1617 netProto tcpip.NetworkProtocolNumber 1618 // Setups up the stacks in such a way that: 1619 // 1620 // - Host2 is the client for all tests. 1621 // - When performing SNAT only: 1622 // + Host1 is the server. 1623 // + NAT will transform client-originating packets' source addresses to 1624 // the router's NIC1's address before reaching Host1. 
1625 // - When performing DNAT only: 1626 // + Router is the server. 1627 // + Client will send packets directed to Host1. 1628 // + NAT will transform client-originating packets' destination addresses 1629 // to the router's NIC2's address. 1630 // - When performing Twice-NAT: 1631 // + Host1 is the server. 1632 // + Client will send packets directed to router's NIC2. 1633 // + NAT will transform client originating packets' destination addresses 1634 // to Host1's address. 1635 // + NAT will transform client-originating packets' source addresses to 1636 // the router's NIC1's address before reaching Host1. 1637 epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses 1638 natTypes []natType 1639 }{ 1640 { 1641 name: "IPv4 SNAT", 1642 netProto: ipv4.ProtocolNumber, 1643 epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1644 t.Helper() 1645 1646 listenerStack := host1Stack 1647 serverAddr := tcpip.FullAddress{ 1648 Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1649 Port: listenPort, 1650 } 1651 serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address 1652 clientConnectPort := serverAddr.Port 1653 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) 1654 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) 1655 return endpointAndAddresses{ 1656 serverEP: ep1, 1657 serverAddr: serverAddr, 1658 serverReadableCH: ep1WECH, 1659 serverConnectAddr: serverConnectAddr, 1660 1661 clientEP: ep2, 1662 clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1663 clientReadableCH: ep2WECH, 1664 clientConnectAddr: tcpip.FullAddress{ 1665 Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1666 Port: clientConnectPort, 1667 }, 1668 } 1669 }, 1670 natTypes: snatTypes, 1671 }, 1672 { 1673 name: "IPv4 DNAT", 1674 netProto: ipv4.ProtocolNumber, 1675 epAndAddrs: 
func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1676 t.Helper() 1677 1678 // If we are performing DNAT, then the packet will be redirected 1679 // to the router. 1680 listenerStack := routerStack 1681 serverAddr := tcpip.FullAddress{ 1682 Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, 1683 Port: listenPort, 1684 } 1685 serverConnectAddr := utils.Host2IPv4Addr.AddressWithPrefix.Address 1686 // DNAT will update the destination port to what the server is 1687 // bound to. 1688 clientConnectPort := serverAddr.Port + 1 1689 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) 1690 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) 1691 return endpointAndAddresses{ 1692 serverEP: ep1, 1693 serverAddr: serverAddr, 1694 serverReadableCH: ep1WECH, 1695 serverConnectAddr: serverConnectAddr, 1696 1697 clientEP: ep2, 1698 clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1699 clientReadableCH: ep2WECH, 1700 clientConnectAddr: tcpip.FullAddress{ 1701 Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1702 Port: clientConnectPort, 1703 }, 1704 } 1705 }, 1706 natTypes: dnatTypes, 1707 }, 1708 { 1709 name: "IPv4 Twice-NAT", 1710 netProto: ipv4.ProtocolNumber, 1711 epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1712 t.Helper() 1713 1714 listenerStack := host1Stack 1715 serverAddr := tcpip.FullAddress{ 1716 Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address, 1717 Port: listenPort, 1718 } 1719 serverConnectAddr := utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address 1720 clientConnectPort := serverAddr.Port 1721 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv4.ProtocolNumber) 1722 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) 1723 return endpointAndAddresses{ 1724 serverEP: ep1, 1725 serverAddr: serverAddr, 1726 serverReadableCH: ep1WECH, 
1727 serverConnectAddr: serverConnectAddr, 1728 1729 clientEP: ep2, 1730 clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, 1731 clientReadableCH: ep2WECH, 1732 clientConnectAddr: tcpip.FullAddress{ 1733 Addr: utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, 1734 Port: clientConnectPort, 1735 }, 1736 } 1737 }, 1738 natTypes: twiceNATTypes, 1739 }, 1740 { 1741 name: "IPv6 SNAT", 1742 netProto: ipv6.ProtocolNumber, 1743 epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1744 t.Helper() 1745 1746 listenerStack := host1Stack 1747 serverAddr := tcpip.FullAddress{ 1748 Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1749 Port: listenPort, 1750 } 1751 serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address 1752 clientConnectPort := serverAddr.Port 1753 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) 1754 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) 1755 return endpointAndAddresses{ 1756 serverEP: ep1, 1757 serverAddr: serverAddr, 1758 serverReadableCH: ep1WECH, 1759 serverConnectAddr: serverConnectAddr, 1760 1761 clientEP: ep2, 1762 clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1763 clientReadableCH: ep2WECH, 1764 clientConnectAddr: tcpip.FullAddress{ 1765 Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1766 Port: clientConnectPort, 1767 }, 1768 } 1769 }, 1770 natTypes: snatTypes, 1771 }, 1772 { 1773 name: "IPv6 DNAT", 1774 netProto: ipv6.ProtocolNumber, 1775 epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1776 t.Helper() 1777 1778 // If we are performing DNAT, then the packet will be redirected 1779 // to the router. 
1780 listenerStack := routerStack 1781 serverAddr := tcpip.FullAddress{ 1782 Addr: utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, 1783 Port: listenPort, 1784 } 1785 serverConnectAddr := utils.Host2IPv6Addr.AddressWithPrefix.Address 1786 // DNAT will update the destination port to what the server is 1787 // bound to. 1788 clientConnectPort := serverAddr.Port + 1 1789 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) 1790 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) 1791 return endpointAndAddresses{ 1792 serverEP: ep1, 1793 serverAddr: serverAddr, 1794 serverReadableCH: ep1WECH, 1795 serverConnectAddr: serverConnectAddr, 1796 1797 clientEP: ep2, 1798 clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1799 clientReadableCH: ep2WECH, 1800 clientConnectAddr: tcpip.FullAddress{ 1801 Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1802 Port: clientConnectPort, 1803 }, 1804 } 1805 }, 1806 natTypes: dnatTypes, 1807 }, 1808 { 1809 name: "IPv6 Twice-NAT", 1810 netProto: ipv6.ProtocolNumber, 1811 epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { 1812 t.Helper() 1813 1814 listenerStack := host1Stack 1815 serverAddr := tcpip.FullAddress{ 1816 Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, 1817 Port: listenPort, 1818 } 1819 serverConnectAddr := utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address 1820 clientConnectPort := serverAddr.Port 1821 ep1, ep1WECH := newEP(t, listenerStack, proto, ipv6.ProtocolNumber) 1822 ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) 1823 return endpointAndAddresses{ 1824 serverEP: ep1, 1825 serverAddr: serverAddr, 1826 serverReadableCH: ep1WECH, 1827 serverConnectAddr: serverConnectAddr, 1828 1829 clientEP: ep2, 1830 clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, 1831 clientReadableCH: ep2WECH, 1832 clientConnectAddr: tcpip.FullAddress{ 1833 Addr: 
utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, 1834 Port: clientConnectPort, 1835 }, 1836 } 1837 }, 1838 natTypes: twiceNATTypes, 1839 }, 1840 } 1841 1842 subTests := []struct { 1843 name string 1844 proto tcpip.TransportProtocolNumber 1845 expectedConnectErr tcpip.Error 1846 setupServer func(t *testing.T, ep tcpip.Endpoint) 1847 setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) 1848 needRemoteAddr bool 1849 }{ 1850 { 1851 name: "UDP", 1852 proto: udp.ProtocolNumber, 1853 expectedConnectErr: nil, 1854 setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { 1855 t.Helper() 1856 1857 if err := ep.Connect(clientAddr); err != nil { 1858 t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) 1859 } 1860 return nil, nil 1861 }, 1862 needRemoteAddr: true, 1863 }, 1864 { 1865 name: "TCP", 1866 proto: tcp.ProtocolNumber, 1867 expectedConnectErr: &tcpip.ErrConnectStarted{}, 1868 setupServer: func(t *testing.T, ep tcpip.Endpoint) { 1869 t.Helper() 1870 1871 if err := ep.Listen(1); err != nil { 1872 t.Fatalf("ep.Listen(1): %s", err) 1873 } 1874 }, 1875 setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { 1876 t.Helper() 1877 1878 var addr tcpip.FullAddress 1879 for { 1880 newEP, wq, err := ep.Accept(&addr) 1881 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 1882 <-ch 1883 continue 1884 } 1885 if err != nil { 1886 t.Fatalf("ep.Accept(_): %s", err) 1887 } 1888 if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( 1889 "NIC", 1890 )); diff != "" { 1891 t.Errorf("accepted address mismatch (-want +got):\n%s", diff) 1892 } 1893 1894 we, newCH := waiter.NewChannelEntry(waiter.ReadableEvents) 1895 wq.EventRegister(&we) 1896 return newEP, newCH 1897 } 1898 }, 1899 needRemoteAddr: false, 1900 }, 1901 } 1902 1903 for _, test := range 
tests { 1904 t.Run(test.name, func(t *testing.T) { 1905 for _, subTest := range subTests { 1906 t.Run(subTest.name, func(t *testing.T) { 1907 for _, natType := range test.natTypes { 1908 t.Run(natType.name, func(t *testing.T) { 1909 stackOpts := stack.Options{ 1910 NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, 1911 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 1912 } 1913 1914 host1Stack := stack.New(stackOpts) 1915 defer host1Stack.Destroy() 1916 routerStack := stack.New(stackOpts) 1917 defer routerStack.Destroy() 1918 host2Stack := stack.New(stackOpts) 1919 defer host2Stack.Destroy() 1920 utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) 1921 1922 epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) 1923 natType.setupNAT(t, routerStack, test.netProto, subTest.proto, epsAndAddrs.serverConnectAddr, epsAndAddrs.serverAddr.Addr, listenPort) 1924 1925 if err := epsAndAddrs.serverEP.Bind(epsAndAddrs.serverAddr); err != nil { 1926 t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", epsAndAddrs.serverAddr, err) 1927 } 1928 clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} 1929 if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { 1930 t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) 1931 } 1932 1933 if subTest.setupServer != nil { 1934 subTest.setupServer(t, epsAndAddrs.serverEP) 1935 } 1936 { 1937 err := epsAndAddrs.clientEP.Connect(epsAndAddrs.clientConnectAddr) 1938 if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { 1939 t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", epsAndAddrs.clientConnectAddr, diff) 1940 } 1941 } 1942 serverConnectAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverConnectAddr} 1943 if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { 1944 t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) 1945 } 
else { 1946 serverConnectAddr.Port = addr.Port 1947 } 1948 1949 serverEP := epsAndAddrs.serverEP 1950 serverCH := epsAndAddrs.serverReadableCH 1951 if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, serverConnectAddr); ep != nil { 1952 defer ep.Close() 1953 serverEP = ep 1954 serverCH = ch 1955 } 1956 1957 write := func(ep tcpip.Endpoint, data []byte) { 1958 t.Helper() 1959 1960 var r bytes.Reader 1961 r.Reset(data) 1962 var wOpts tcpip.WriteOptions 1963 n, err := ep.Write(&r, wOpts) 1964 if err != nil { 1965 t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) 1966 } 1967 if want := int64(len(data)); n != want { 1968 t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) 1969 } 1970 } 1971 1972 read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { 1973 t.Helper() 1974 1975 var buf bytes.Buffer 1976 var res tcpip.ReadResult 1977 for { 1978 var err tcpip.Error 1979 opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} 1980 res, err = ep.Read(&buf, opts) 1981 if _, ok := err.(*tcpip.ErrWouldBlock); ok { 1982 <-ch 1983 continue 1984 } 1985 if err != nil { 1986 t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) 1987 } 1988 break 1989 } 1990 1991 readResult := tcpip.ReadResult{ 1992 Count: len(data), 1993 Total: len(data), 1994 } 1995 if subTest.needRemoteAddr { 1996 readResult.RemoteAddr = expectedFrom 1997 } 1998 if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( 1999 "ControlMessages", 2000 "RemoteAddr.NIC", 2001 )); diff != "" { 2002 t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) 2003 } 2004 if diff := cmp.Diff(buf.Bytes(), data); diff != "" { 2005 t.Errorf("received data mismatch (-want +got):\n%s", diff) 2006 } 2007 2008 if t.Failed() { 2009 t.FailNow() 2010 } 2011 } 2012 2013 { 2014 data := []byte{1, 2, 3, 4} 2015 write(epsAndAddrs.clientEP, data) 2016 read(serverCH, serverEP, data, serverConnectAddr) 2017 } 2018 2019 { 2020 data := []byte{5, 6, 7, 8, 9, 
10, 11, 12} 2021 write(serverEP, data) 2022 read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, epsAndAddrs.clientConnectAddr) 2023 } 2024 }) 2025 } 2026 }) 2027 } 2028 }) 2029 } 2030 } 2031 2032 func encodeIPv4Header(v []byte, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) { 2033 ip := header.IPv4(v) 2034 ip.Encode(&header.IPv4Fields{ 2035 TotalLength: uint16(totalLen), 2036 Protocol: uint8(transProto), 2037 TTL: 64, 2038 SrcAddr: srcAddr, 2039 DstAddr: dstAddr, 2040 }) 2041 ip.SetChecksum(^ip.CalculateChecksum()) 2042 } 2043 2044 func encodeIPv6Header(v []byte, payloadLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) { 2045 ip := header.IPv6(v) 2046 ip.Encode(&header.IPv6Fields{ 2047 PayloadLength: uint16(payloadLen), 2048 TransportProtocol: transProto, 2049 HopLimit: 64, 2050 SrcAddr: srcAddr, 2051 DstAddr: dstAddr, 2052 }) 2053 } 2054 2055 func udpv4Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte { 2056 udpSize := header.UDPMinimumSize + dataSize 2057 hdr := prependable.New(header.IPv4MinimumSize + udpSize) 2058 udp := header.UDP(hdr.Prepend(udpSize)) 2059 udp.SetSourcePort(srcPort) 2060 udp.SetDestinationPort(dstPort) 2061 udp.SetLength(uint16(udpSize)) 2062 udp.SetChecksum(0) 2063 udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( 2064 header.UDPProtocolNumber, 2065 srcAddr, 2066 dstAddr, 2067 uint16(len(udp)), 2068 ))) 2069 encodeIPv4Header( 2070 hdr.Prepend(header.IPv4MinimumSize), 2071 hdr.UsedLength(), 2072 header.UDPProtocolNumber, 2073 srcAddr, 2074 dstAddr, 2075 ) 2076 return hdr.View() 2077 } 2078 2079 func tcpv4Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte { 2080 tcpSize := header.TCPMinimumSize + dataSize 2081 hdr := prependable.New(header.IPv4MinimumSize + tcpSize) 2082 tcp := header.TCP(hdr.Prepend(tcpSize)) 2083 tcp.SetSourcePort(srcPort) 2084 
tcp.SetDestinationPort(dstPort) 2085 tcp.SetDataOffset(header.TCPMinimumSize) 2086 tcp.SetChecksum(0) 2087 tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum( 2088 header.TCPProtocolNumber, 2089 srcAddr, 2090 dstAddr, 2091 uint16(len(tcp)), 2092 ))) 2093 encodeIPv4Header( 2094 hdr.Prepend(header.IPv4MinimumSize), 2095 hdr.UsedLength(), 2096 header.TCPProtocolNumber, 2097 srcAddr, 2098 dstAddr, 2099 ) 2100 return hdr.View() 2101 } 2102 2103 func icmpv4Packet(srcAddr, dstAddr tcpip.Address, icmpType header.ICMPv4Type, ident uint16) []byte { 2104 hdr := prependable.New(header.IPv4MinimumSize + header.ICMPv4MinimumSize) 2105 icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) 2106 icmp.SetType(icmpType) 2107 icmp.SetIdent(ident) 2108 icmp.SetChecksum(0) 2109 icmp.SetChecksum(^checksum.Checksum(icmp, 0)) 2110 encodeIPv4Header( 2111 hdr.Prepend(header.IPv4MinimumSize), 2112 hdr.UsedLength(), 2113 header.ICMPv4ProtocolNumber, 2114 srcAddr, 2115 dstAddr, 2116 ) 2117 return hdr.View() 2118 } 2119 2120 func udpv6Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte { 2121 udpSize := header.UDPMinimumSize + dataSize 2122 hdr := prependable.New(header.IPv6MinimumSize + udpSize) 2123 udp := header.UDP(hdr.Prepend(udpSize)) 2124 udp.SetSourcePort(srcPort) 2125 udp.SetDestinationPort(dstPort) 2126 udp.SetLength(uint16(udpSize)) 2127 udp.SetChecksum(0) 2128 udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum( 2129 header.UDPProtocolNumber, 2130 srcAddr, 2131 dstAddr, 2132 uint16(len(udp)), 2133 ))) 2134 encodeIPv6Header( 2135 hdr.Prepend(header.IPv6MinimumSize), 2136 len(udp), 2137 header.UDPProtocolNumber, 2138 srcAddr, 2139 dstAddr, 2140 ) 2141 return hdr.View() 2142 } 2143 2144 func tcpv6Packet(srcAddr, dstAddr tcpip.Address, srcPort, dstPort uint16, dataSize int) []byte { 2145 tcpSize := header.TCPMinimumSize + dataSize 2146 hdr := prependable.New(header.IPv6MinimumSize + tcpSize) 2147 tcp := 
header.TCP(hdr.Prepend(tcpSize)) 2148 tcp.SetSourcePort(srcPort) 2149 tcp.SetDestinationPort(dstPort) 2150 tcp.SetDataOffset(header.TCPMinimumSize) 2151 tcp.SetChecksum(0) 2152 tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum( 2153 header.TCPProtocolNumber, 2154 srcAddr, 2155 dstAddr, 2156 uint16(len(tcp)), 2157 ))) 2158 encodeIPv6Header( 2159 hdr.Prepend(header.IPv6MinimumSize), 2160 len(tcp), 2161 header.TCPProtocolNumber, 2162 srcAddr, 2163 dstAddr, 2164 ) 2165 return hdr.View() 2166 } 2167 2168 func icmpv6Packet(srcAddr, dstAddr tcpip.Address, icmpType header.ICMPv6Type, ident uint16) []byte { 2169 hdr := prependable.New(header.IPv6MinimumSize + header.ICMPv6MinimumSize) 2170 icmp := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) 2171 icmp.SetType(icmpType) 2172 icmp.SetIdent(ident) 2173 icmp.SetChecksum(0) 2174 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 2175 Header: icmp, 2176 Src: srcAddr, 2177 Dst: dstAddr, 2178 })) 2179 encodeIPv6Header( 2180 hdr.Prepend(header.IPv6MinimumSize), 2181 len(icmp), 2182 header.ICMPv6ProtocolNumber, 2183 srcAddr, 2184 dstAddr, 2185 ) 2186 return hdr.View() 2187 } 2188 2189 func TestNATICMPError(t *testing.T) { 2190 const ( 2191 srcPort = 1234 2192 dstPort = 5432 2193 dataSize = 4 2194 ) 2195 2196 type icmpTypeTest struct { 2197 name string 2198 val uint8 2199 expectResponse bool 2200 } 2201 2202 type transportTypeTest struct { 2203 name string 2204 proto tcpip.TransportProtocolNumber 2205 buf []byte 2206 checkNATed func(*testing.T, *buffer.View) 2207 } 2208 2209 tests := []struct { 2210 name string 2211 netProto tcpip.NetworkProtocolNumber 2212 host1Addr tcpip.Address 2213 icmpError func(*testing.T, []byte, uint8) []byte 2214 decrementTTL func([]byte) 2215 checkNATedError func(*testing.T, *buffer.View, []byte, uint8) 2216 2217 transportTypes []transportTypeTest 2218 icmpTypes []icmpTypeTest 2219 }{ 2220 { 2221 name: "IPv4", 2222 netProto: ipv4.ProtocolNumber, 2223 host1Addr: 
utils.Host1IPv4Addr.AddressWithPrefix.Address, 2224 icmpError: func(t *testing.T, original []byte, icmpType uint8) []byte { 2225 hdr := prependable.New(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)) 2226 if n := copy(hdr.Prepend(len(original)), original); n != len(original) { 2227 t.Fatalf("got copy(...) = %d, want = %d", n, len(original)) 2228 } 2229 icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) 2230 icmp.SetType(header.ICMPv4Type(icmpType)) 2231 icmp.SetChecksum(0) 2232 icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0)) 2233 encodeIPv4Header( 2234 hdr.Prepend(header.IPv4MinimumSize), 2235 hdr.UsedLength(), 2236 header.ICMPv4ProtocolNumber, 2237 utils.Host1IPv4Addr.AddressWithPrefix.Address, 2238 utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, 2239 ) 2240 return hdr.View() 2241 }, 2242 decrementTTL: func(v []byte) { 2243 ip := header.IPv4(v) 2244 ip.SetTTL(ip.TTL() - 1) 2245 ip.SetChecksum(0) 2246 ip.SetChecksum(^ip.CalculateChecksum()) 2247 }, 2248 checkNATedError: func(t *testing.T, v *buffer.View, original []byte, icmpType uint8) { 2249 checker.IPv4(t, v, 2250 checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address), 2251 checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address), 2252 checker.ICMPv4( 2253 checker.ICMPv4Type(header.ICMPv4Type(icmpType)), 2254 checker.ICMPv4Checksum(), 2255 checker.ICMPv4Payload(original), 2256 ), 2257 ) 2258 }, 2259 transportTypes: []transportTypeTest{ 2260 { 2261 name: "UDP", 2262 proto: header.UDPProtocolNumber, 2263 buf: func() []byte { 2264 return udpv4Packet(utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize) 2265 }(), 2266 checkNATed: func(t *testing.T, v *buffer.View) { 2267 checker.IPv4(t, v, 2268 checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), 2269 checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), 2270 checker.UDP( 2271 checker.SrcPort(srcPort), 2272 
checker.DstPort(dstPort), 2273 ), 2274 ) 2275 }, 2276 }, 2277 { 2278 name: "TCP", 2279 proto: header.TCPProtocolNumber, 2280 buf: func() []byte { 2281 return tcpv4Packet(utils.Host2IPv4Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize) 2282 }(), 2283 checkNATed: func(t *testing.T, v *buffer.View) { 2284 checker.IPv4(t, v, 2285 checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), 2286 checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), 2287 checker.TCP( 2288 checker.SrcPort(srcPort), 2289 checker.DstPort(dstPort), 2290 ), 2291 ) 2292 }, 2293 }, 2294 }, 2295 icmpTypes: []icmpTypeTest{ 2296 { 2297 name: "Destination Unreachable", 2298 val: uint8(header.ICMPv4DstUnreachable), 2299 expectResponse: true, 2300 }, 2301 { 2302 name: "Time Exceeded", 2303 val: uint8(header.ICMPv4TimeExceeded), 2304 expectResponse: true, 2305 }, 2306 { 2307 name: "Parameter Problem", 2308 val: uint8(header.ICMPv4ParamProblem), 2309 expectResponse: true, 2310 }, 2311 { 2312 name: "Echo Request", 2313 val: uint8(header.ICMPv4Echo), 2314 expectResponse: false, 2315 }, 2316 { 2317 name: "Echo Reply", 2318 val: uint8(header.ICMPv4EchoReply), 2319 expectResponse: false, 2320 }, 2321 }, 2322 }, 2323 { 2324 name: "IPv6", 2325 netProto: ipv6.ProtocolNumber, 2326 host1Addr: utils.Host1IPv6Addr.AddressWithPrefix.Address, 2327 icmpError: func(t *testing.T, original []byte, icmpType uint8) []byte { 2328 payloadLen := header.ICMPv6MinimumSize + len(original) 2329 hdr := prependable.New(header.IPv6MinimumSize + payloadLen) 2330 icmp := header.ICMPv6(hdr.Prepend(payloadLen)) 2331 icmp.SetType(header.ICMPv6Type(icmpType)) 2332 if n := copy(icmp.Payload(), original); n != len(original) { 2333 t.Fatalf("got copy(...) 
= %d, want = %d", n, len(original)) 2334 } 2335 icmp.SetChecksum(0) 2336 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ 2337 Header: icmp, 2338 Src: utils.Host1IPv6Addr.AddressWithPrefix.Address, 2339 Dst: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, 2340 })) 2341 encodeIPv6Header( 2342 hdr.Prepend(header.IPv6MinimumSize), 2343 payloadLen, 2344 header.ICMPv6ProtocolNumber, 2345 utils.Host1IPv6Addr.AddressWithPrefix.Address, 2346 utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, 2347 ) 2348 return hdr.View() 2349 }, 2350 decrementTTL: func(v []byte) { 2351 ip := header.IPv6(v) 2352 ip.SetHopLimit(ip.HopLimit() - 1) 2353 }, 2354 checkNATedError: func(t *testing.T, v *buffer.View, original []byte, icmpType uint8) { 2355 checker.IPv6(t, v, 2356 checker.SrcAddr(utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address), 2357 checker.DstAddr(utils.Host2IPv6Addr.AddressWithPrefix.Address), 2358 checker.ICMPv6( 2359 checker.ICMPv6Type(header.ICMPv6Type(icmpType)), 2360 checker.ICMPv6Payload(original), 2361 ), 2362 ) 2363 }, 2364 transportTypes: []transportTypeTest{ 2365 { 2366 name: "UDP", 2367 proto: header.UDPProtocolNumber, 2368 buf: func() []byte { 2369 return udpv6Packet(utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize) 2370 }(), 2371 checkNATed: func(t *testing.T, v *buffer.View) { 2372 checker.IPv6(t, v, 2373 checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), 2374 checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), 2375 checker.UDP( 2376 checker.SrcPort(srcPort), 2377 checker.DstPort(dstPort), 2378 ), 2379 ) 2380 }, 2381 }, 2382 { 2383 name: "TCP", 2384 proto: header.TCPProtocolNumber, 2385 buf: func() []byte { 2386 return tcpv6Packet(utils.Host2IPv6Addr.AddressWithPrefix.Address, utils.RouterNIC2IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, dataSize) 2387 }(), 2388 checkNATed: func(t *testing.T, v *buffer.View) { 2389 
checker.IPv6(t, v, 2390 checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), 2391 checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), 2392 checker.TCP( 2393 checker.SrcPort(srcPort), 2394 checker.DstPort(dstPort), 2395 ), 2396 ) 2397 }, 2398 }, 2399 }, 2400 icmpTypes: []icmpTypeTest{ 2401 { 2402 name: "Destination Unreachable", 2403 val: uint8(header.ICMPv6DstUnreachable), 2404 expectResponse: true, 2405 }, 2406 { 2407 name: "Packet Too Big", 2408 val: uint8(header.ICMPv6PacketTooBig), 2409 expectResponse: true, 2410 }, 2411 { 2412 name: "Time Exceeded", 2413 val: uint8(header.ICMPv6TimeExceeded), 2414 expectResponse: true, 2415 }, 2416 { 2417 name: "Parameter Problem", 2418 val: uint8(header.ICMPv6ParamProblem), 2419 expectResponse: true, 2420 }, 2421 { 2422 name: "Echo Request", 2423 val: uint8(header.ICMPv6EchoRequest), 2424 expectResponse: false, 2425 }, 2426 { 2427 name: "Echo Reply", 2428 val: uint8(header.ICMPv6EchoReply), 2429 expectResponse: false, 2430 }, 2431 }, 2432 }, 2433 } 2434 2435 trimTests := []struct { 2436 name string 2437 trimLen int 2438 expectNATedICMP bool 2439 }{ 2440 { 2441 name: "Trim nothing", 2442 trimLen: 0, 2443 expectNATedICMP: true, 2444 }, 2445 { 2446 name: "Trim data", 2447 trimLen: dataSize, 2448 expectNATedICMP: true, 2449 }, 2450 { 2451 name: "Trim data and transport header", 2452 trimLen: dataSize + 1, 2453 expectNATedICMP: false, 2454 }, 2455 } 2456 2457 for _, test := range tests { 2458 t.Run(test.name, func(t *testing.T) { 2459 for _, transportType := range test.transportTypes { 2460 t.Run(transportType.name, func(t *testing.T) { 2461 for _, icmpType := range test.icmpTypes { 2462 t.Run(icmpType.name, func(t *testing.T) { 2463 for _, trimTest := range trimTests { 2464 t.Run(trimTest.name, func(t *testing.T) { 2465 s := stack.New(stack.Options{ 2466 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 2467 TransportProtocols: 
[]stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 2468 }) 2469 defer s.Destroy() 2470 2471 ep1 := channel.New(1, header.IPv6MinimumMTU, "") 2472 ep2 := channel.New(1, header.IPv6MinimumMTU, "") 2473 utils.SetupRouterStack(t, s, ep1, ep2) 2474 2475 ipv6 := test.netProto == ipv6.ProtocolNumber 2476 ipt := s.IPTables() 2477 2478 table := stack.Table{ 2479 Rules: []stack.Rule{ 2480 // Prerouting 2481 { 2482 Filter: stack.IPHeaderFilter{ 2483 Protocol: transportType.proto, 2484 CheckProtocol: true, 2485 InputInterface: utils.RouterNIC2Name, 2486 }, 2487 Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort, ChangeAddress: true, ChangePort: true}, 2488 }, 2489 { 2490 Target: &stack.AcceptTarget{}, 2491 }, 2492 2493 // Input 2494 { 2495 Target: &stack.AcceptTarget{}, 2496 }, 2497 2498 // Forward 2499 { 2500 Target: &stack.AcceptTarget{}, 2501 }, 2502 2503 // Output 2504 { 2505 Target: &stack.AcceptTarget{}, 2506 }, 2507 2508 // Postrouting 2509 { 2510 Filter: stack.IPHeaderFilter{ 2511 Protocol: transportType.proto, 2512 CheckProtocol: true, 2513 OutputInterface: utils.RouterNIC1Name, 2514 }, 2515 Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto}, 2516 }, 2517 { 2518 Target: &stack.AcceptTarget{}, 2519 }, 2520 }, 2521 BuiltinChains: [stack.NumHooks]int{ 2522 stack.Prerouting: 0, 2523 stack.Input: 2, 2524 stack.Forward: 3, 2525 stack.Output: 4, 2526 stack.Postrouting: 5, 2527 }, 2528 } 2529 2530 ipt.ForceReplaceTable(stack.NATID, table, ipv6) 2531 2532 buf := transportType.buf 2533 2534 ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2535 Payload: buffer.MakeWithData(append([]byte{}, buf...)), 2536 })) 2537 2538 { 2539 pkt := ep1.Read() 2540 if pkt == nil { 2541 t.Fatal("expected to read a packet on ep1") 2542 } 2543 pktView := stack.PayloadSince(pkt.NetworkHeader()) 2544 defer pktView.Release() 2545 pkt.DecRef() 2546 transportType.checkNATed(t, pktView) 2547 if 
t.Failed() { 2548 t.FailNow() 2549 } 2550 2551 pktSlice := pktView.AsSlice()[:pktView.Size()-trimTest.trimLen] 2552 buf = buf[:len(buf)-trimTest.trimLen] 2553 2554 ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2555 Payload: buffer.MakeWithData(test.icmpError(t, pktSlice, icmpType.val)), 2556 })) 2557 } 2558 2559 pkt := ep2.Read() 2560 expectResponse := icmpType.expectResponse && trimTest.expectNATedICMP 2561 if (pkt != nil) != expectResponse { 2562 t.Fatalf("got ep2.Read() = %#v, want = (_ == nil) = %t", pkt, expectResponse) 2563 } 2564 if !expectResponse { 2565 return 2566 } 2567 test.decrementTTL(buf) 2568 payload := stack.PayloadSince(pkt.NetworkHeader()) 2569 defer payload.Release() 2570 test.checkNATedError(t, payload, buf, icmpType.val) 2571 pkt.DecRef() 2572 }) 2573 } 2574 }) 2575 } 2576 }) 2577 } 2578 }) 2579 } 2580 } 2581 2582 func TestSNATHandlePortOrIdentConflicts(t *testing.T) { 2583 const dstPort = 5432 2584 2585 type portOrIdentRange struct { 2586 first uint16 2587 last uint16 2588 } 2589 2590 type srcPortOrIdentRangeTest struct { 2591 name string 2592 originalRange portOrIdentRange 2593 targetRange portOrIdentRange 2594 } 2595 2596 srcPortRanges := []srcPortOrIdentRangeTest{ 2597 { 2598 name: "Less than 512", 2599 originalRange: portOrIdentRange{first: 1, last: 511}, 2600 targetRange: portOrIdentRange{first: 1, last: 511}, 2601 }, 2602 { 2603 name: "Greater than or equal to 512 but less than 1024", 2604 originalRange: portOrIdentRange{first: 512, last: 1023}, 2605 targetRange: portOrIdentRange{first: 1, last: 1023}, 2606 }, 2607 { 2608 name: "Greater than or equal to 1024", 2609 originalRange: portOrIdentRange{first: 1024, last: math.MaxUint16}, 2610 targetRange: portOrIdentRange{first: 1024, last: math.MaxUint16}, 2611 }, 2612 } 2613 2614 // Unlike TCP/UDP, the Ident may be mapped to any 16-bit value. 
2615 identRanges := []srcPortOrIdentRangeTest{ 2616 { 2617 name: "Less than 512", 2618 originalRange: portOrIdentRange{first: 0, last: 511}, 2619 targetRange: portOrIdentRange{first: 0, last: math.MaxUint16}, 2620 }, 2621 { 2622 name: "Greater than or equal to 512 but less than 1024", 2623 originalRange: portOrIdentRange{first: 512, last: 1023}, 2624 targetRange: portOrIdentRange{first: 0, last: math.MaxUint16}, 2625 }, 2626 { 2627 name: "Greater than or equal to 1024", 2628 originalRange: portOrIdentRange{first: 1024, last: math.MaxUint16}, 2629 targetRange: portOrIdentRange{first: 0, last: math.MaxUint16}, 2630 }, 2631 } 2632 2633 type transportTypeTest struct { 2634 name string 2635 proto tcpip.TransportProtocolNumber 2636 buf func(tcpip.Address, uint16) []byte 2637 checkNATed func(*testing.T, *buffer.View, uint16, bool, portOrIdentRange) 2638 srcPortOrIdentRanges []srcPortOrIdentRangeTest 2639 } 2640 2641 compareSrcPortOrIdent := func(t *testing.T, gotPort uint16, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) { 2642 t.Helper() 2643 2644 if firstPacket { 2645 if gotPort != originalSrcPort { 2646 t.Errorf("got port/ident = %d, want = %d", gotPort, originalSrcPort) 2647 } 2648 return 2649 } 2650 2651 if gotPort < expectedRange.first || gotPort > expectedRange.last { 2652 t.Errorf("got port/ident = %d, want in range [%d, %d]", gotPort, expectedRange.first, expectedRange.last) 2653 } 2654 } 2655 2656 tests := []struct { 2657 name string 2658 netProto tcpip.NetworkProtocolNumber 2659 routerNIC1Addr tcpip.Address 2660 srcAddrs []tcpip.Address 2661 transportTypes []transportTypeTest 2662 }{ 2663 { 2664 name: "IPv4", 2665 netProto: ipv4.ProtocolNumber, 2666 routerNIC1Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, 2667 srcAddrs: []tcpip.Address{ 2668 utils.Ipv4Addr1.AddressWithPrefix.Address, 2669 utils.Ipv4Addr2.AddressWithPrefix.Address, 2670 utils.Ipv4Addr3.AddressWithPrefix.Address, 2671 }, 2672 transportTypes: 
[]transportTypeTest{ 2673 { 2674 name: "UDP", 2675 proto: header.UDPProtocolNumber, 2676 buf: func(srcAddr tcpip.Address, srcPort uint16) []byte { 2677 return udpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */) 2678 }, 2679 checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) { 2680 checker.IPv4(t, v, 2681 checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), 2682 checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), 2683 checker.UDP( 2684 checker.DstPort(dstPort), 2685 ), 2686 ) 2687 2688 if !t.Failed() { 2689 compareSrcPortOrIdent(t, header.UDP(header.IPv4(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange) 2690 } 2691 }, 2692 srcPortOrIdentRanges: srcPortRanges, 2693 }, 2694 { 2695 name: "TCP", 2696 proto: header.TCPProtocolNumber, 2697 buf: func(srcAddr tcpip.Address, srcPort uint16) []byte { 2698 return tcpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */) 2699 }, 2700 checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) { 2701 checker.IPv4(t, v, 2702 checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), 2703 checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), 2704 checker.TCP( 2705 checker.DstPort(dstPort), 2706 ), 2707 ) 2708 2709 if !t.Failed() { 2710 compareSrcPortOrIdent(t, header.TCP(header.IPv4(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange) 2711 } 2712 }, 2713 srcPortOrIdentRanges: srcPortRanges, 2714 }, 2715 { 2716 name: "ICMP Echo", 2717 proto: header.ICMPv4ProtocolNumber, 2718 buf: func(srcAddr tcpip.Address, ident uint16) []byte { 2719 return icmpv4Packet(srcAddr, utils.Host1IPv4Addr.AddressWithPrefix.Address, header.ICMPv4Echo, ident) 2720 }, 2721 checkNATed: func(t *testing.T, v *buffer.View, 
originalIdent uint16, firstPacket bool, expectedRange portOrIdentRange) { 2722 checker.IPv4(t, v, 2723 checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address), 2724 checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address), 2725 checker.ICMPv4( 2726 checker.ICMPv4Type(header.ICMPv4Echo), 2727 checker.ICMPv4Checksum(), 2728 ), 2729 ) 2730 2731 if !t.Failed() { 2732 compareSrcPortOrIdent(t, header.ICMPv4(header.IPv4(v.AsSlice()).Payload()).Ident(), originalIdent, firstPacket, expectedRange) 2733 } 2734 }, 2735 srcPortOrIdentRanges: identRanges, 2736 }, 2737 }, 2738 }, 2739 { 2740 name: "IPv6", 2741 netProto: ipv6.ProtocolNumber, 2742 routerNIC1Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, 2743 srcAddrs: []tcpip.Address{ 2744 utils.Ipv6Addr1.AddressWithPrefix.Address, 2745 utils.Ipv6Addr2.AddressWithPrefix.Address, 2746 utils.Ipv6Addr2.AddressWithPrefix.Address, 2747 }, 2748 transportTypes: []transportTypeTest{ 2749 { 2750 name: "UDP", 2751 proto: header.UDPProtocolNumber, 2752 buf: func(srcAddr tcpip.Address, srcPort uint16) []byte { 2753 return udpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, srcPort, dstPort, 0 /* dataSize */) 2754 }, 2755 checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) { 2756 checker.IPv6(t, v, 2757 checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), 2758 checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), 2759 checker.UDP( 2760 checker.DstPort(dstPort), 2761 ), 2762 ) 2763 2764 if !t.Failed() { 2765 compareSrcPortOrIdent(t, header.UDP(header.IPv6(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange) 2766 } 2767 }, 2768 srcPortOrIdentRanges: srcPortRanges, 2769 }, 2770 { 2771 name: "TCP", 2772 proto: header.TCPProtocolNumber, 2773 buf: func(srcAddr tcpip.Address, srcPort uint16) []byte { 2774 return tcpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, 
srcPort, dstPort, 0 /* dataSize */) 2775 }, 2776 checkNATed: func(t *testing.T, v *buffer.View, originalSrcPort uint16, firstPacket bool, expectedRange portOrIdentRange) { 2777 checker.IPv6(t, v, 2778 checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), 2779 checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), 2780 checker.TCP( 2781 checker.DstPort(dstPort), 2782 ), 2783 ) 2784 2785 if !t.Failed() { 2786 compareSrcPortOrIdent(t, header.TCP(header.IPv6(v.AsSlice()).Payload()).SourcePort(), originalSrcPort, firstPacket, expectedRange) 2787 } 2788 }, 2789 srcPortOrIdentRanges: srcPortRanges, 2790 }, 2791 { 2792 name: "ICMP Echo", 2793 proto: header.ICMPv6ProtocolNumber, 2794 buf: func(srcAddr tcpip.Address, ident uint16) []byte { 2795 return icmpv6Packet(srcAddr, utils.Host1IPv6Addr.AddressWithPrefix.Address, header.ICMPv6EchoRequest, ident) 2796 }, 2797 checkNATed: func(t *testing.T, v *buffer.View, originalIdent uint16, firstPacket bool, expectedRange portOrIdentRange) { 2798 checker.IPv6(t, v, 2799 checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address), 2800 checker.DstAddr(utils.Host1IPv6Addr.AddressWithPrefix.Address), 2801 checker.ICMPv6( 2802 checker.ICMPv6Type(header.ICMPv6EchoRequest), 2803 ), 2804 ) 2805 2806 if !t.Failed() { 2807 compareSrcPortOrIdent(t, header.ICMPv6(header.IPv6(v.AsSlice()).Payload()).Ident(), originalIdent, firstPacket, expectedRange) 2808 } 2809 }, 2810 srcPortOrIdentRanges: identRanges, 2811 }, 2812 }, 2813 }, 2814 } 2815 2816 natTypes := []struct { 2817 name string 2818 target func(tcpip.NetworkProtocolNumber, tcpip.Address) stack.Target 2819 }{ 2820 { 2821 name: "Masquerade", 2822 target: func(netProto tcpip.NetworkProtocolNumber, _ tcpip.Address) stack.Target { 2823 return &stack.MasqueradeTarget{NetworkProtocol: netProto} 2824 }, 2825 }, 2826 { 2827 name: "SNAT", 2828 target: func(netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) stack.Target { 2829 return 
&stack.SNATTarget{NetworkProtocol: netProto, Addr: addr, ChangeAddress: true, ChangePort: true} 2830 }, 2831 }, 2832 } 2833 2834 for _, test := range tests { 2835 t.Run(test.name, func(t *testing.T) { 2836 for _, transportType := range test.transportTypes { 2837 t.Run(transportType.name, func(t *testing.T) { 2838 for _, natType := range natTypes { 2839 t.Run(natType.name, func(t *testing.T) { 2840 for _, srcPortOrIdentRange := range transportType.srcPortOrIdentRanges { 2841 t.Run(srcPortOrIdentRange.name, func(t *testing.T) { 2842 for _, srcPortOrIdent := range [2]uint16{srcPortOrIdentRange.originalRange.first, srcPortOrIdentRange.originalRange.last} { 2843 t.Run(fmt.Sprintf("OriginalSrcPortOrIdent=%d", srcPortOrIdent), func(t *testing.T) { 2844 s := stack.New(stack.Options{ 2845 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 2846 TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, 2847 }) 2848 defer s.Destroy() 2849 2850 ep1 := channel.New(1, header.IPv6MinimumMTU, "") 2851 ep2 := channel.New(1, header.IPv6MinimumMTU, "") 2852 utils.SetupRouterStack(t, s, ep1, ep2) 2853 2854 ipv6 := test.netProto == ipv6.ProtocolNumber 2855 ipt := s.IPTables() 2856 2857 table := stack.Table{ 2858 Rules: []stack.Rule{ 2859 // Prerouting 2860 { 2861 Target: &stack.AcceptTarget{}, 2862 }, 2863 2864 // Input 2865 { 2866 Target: &stack.AcceptTarget{}, 2867 }, 2868 2869 // Forward 2870 { 2871 Target: &stack.AcceptTarget{}, 2872 }, 2873 2874 // Output 2875 { 2876 Target: &stack.AcceptTarget{}, 2877 }, 2878 2879 // Postrouting 2880 { 2881 Filter: stack.IPHeaderFilter{ 2882 Protocol: transportType.proto, 2883 CheckProtocol: true, 2884 OutputInterface: utils.RouterNIC1Name, 2885 }, 2886 Target: natType.target(test.netProto, test.routerNIC1Addr), 2887 }, 2888 { 2889 Target: &stack.AcceptTarget{}, 2890 }, 2891 }, 2892 BuiltinChains: [stack.NumHooks]int{ 2893 stack.Prerouting: 0, 2894 stack.Input: 1, 2895 stack.Forward: 
2, 2896 stack.Output: 3, 2897 stack.Postrouting: 4, 2898 }, 2899 } 2900 2901 ipt.ForceReplaceTable(stack.NATID, table, ipv6) 2902 2903 for i, srcAddr := range test.srcAddrs { 2904 t.Run(fmt.Sprintf("Packet#%d", i), func(t *testing.T) { 2905 ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{ 2906 Payload: buffer.MakeWithData(transportType.buf(srcAddr, srcPortOrIdent)), 2907 })) 2908 2909 pkt := ep1.Read() 2910 if pkt == nil { 2911 t.Fatal("expected to read a packet on ep1") 2912 } 2913 pktView := stack.PayloadSince(pkt.NetworkHeader()) 2914 defer pktView.Release() 2915 pkt.DecRef() 2916 transportType.checkNATed(t, pktView, srcPortOrIdent, i == 0, srcPortOrIdentRange.targetRange) 2917 }) 2918 } 2919 }) 2920 } 2921 }) 2922 } 2923 }) 2924 } 2925 }) 2926 } 2927 }) 2928 } 2929 } 2930

// TestSNATLocallyGeneratedTrafficPorts checks source port handling when a
// Masquerade NAT rule is installed on the router's second NIC.
//
// It first forwards a UDP packet from ep1 to ep2 (which gets masqueraded to
// the router's NIC2 address), then sends a packet from a socket bound on the
// router itself to the same address and port pair. The locally generated
// packet must leave with a different source port, and replies addressed to
// either port must reach the right destination.
func TestSNATLocallyGeneratedTrafficPorts(t *testing.T) {
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
	})
	defer s.Destroy()

	ep1 := channel.New(1, header.IPv4MinimumMTU, "")
	ep2 := channel.New(1, header.IPv4MinimumMTU, "")
	utils.SetupRouterStack(t, s, ep1, ep2)

	// Configure Masquerade NAT on the router stack.
	ipt := s.IPTables()
	table := stack.Table{
		Rules: []stack.Rule{
			// Prerouting
			{
				Target: &stack.AcceptTarget{},
			},

			// Input
			{
				Target: &stack.AcceptTarget{},
			},

			// Forward
			{
				Target: &stack.AcceptTarget{},
			},

			// Output
			{
				Target: &stack.AcceptTarget{},
			},

			// Postrouting
			{
				Filter: stack.IPHeaderFilter{
					Protocol:        udp.ProtocolNumber,
					CheckProtocol:   true,
					OutputInterface: utils.RouterNIC2Name,
				},
				Target: &stack.MasqueradeTarget{NetworkProtocol: ipv4.ProtocolNumber},
			},
			{
				Target: &stack.AcceptTarget{},
			},
		},
		BuiltinChains: [stack.NumHooks]int{
			stack.Prerouting:  0,
			stack.Input:       1,
			stack.Forward:     2,
			stack.Output:      3,
			stack.Postrouting: 4,
		},
	}
	ipt.ForceReplaceTable(stack.NATID, table, false /* ipv6 */)

	routerNIC2Addr := utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address
	ep1Addr := utils.Host1IPv4Addr.AddressWithPrefix.Address
	var ep1Port uint16 = 1234
	ep2Addr := utils.Host2IPv4Addr.AddressWithPrefix.Address
	var ep2Port uint16 = 2345

	// Inject an incoming packet on NIC1 destined to an address that will be
	// routed out of NIC2. Expect that we can read the packet on ep2 coming from
	// the stack's address assigned on NIC2, because it should have performed
	// Masquerade NAT on the forwarded traffic.
	ep1.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData(udpv4Packet(ep1Addr, ep2Addr, ep1Port, ep2Port, 0 /* dataSize */)),
	}))
	pkt := ep2.Read()
	if pkt == nil {
		t.Fatal("expected to read a packet on ep2")
	}
	pktView := stack.PayloadSince(pkt.NetworkHeader())
	defer pktView.Release()
	pkt.DecRef()
	checker.IPv4(t, pktView,
		checker.SrcAddr(routerNIC2Addr),
		checker.DstAddr(ep2Addr),
		checker.UDP(
			checker.SrcPort(ep1Port),
			checker.DstPort(ep2Port),
		),
	)

	// Now bind a UDP socket on the stack itself to the same port used by the
	// previous packet, and send a packet to the same address.
	var wq waiter.Queue
	we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
	wq.EventRegister(&we)
	defer wq.EventUnregister(&we)

	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
	if err != nil {
		t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, ipv4.ProtocolNumber, err)
	}
	defer ep.Close()

	srcAddr := tcpip.FullAddress{Addr: routerNIC2Addr, Port: ep1Port}
	if err := ep.Bind(srcAddr); err != nil {
		t.Fatalf("ep.Bind(%#v): %s", srcAddr, err)
	}
	dstAddr := tcpip.FullAddress{Addr: ep2Addr, Port: ep2Port}
	if err := ep.Connect(dstAddr); err != nil {
		t.Fatalf("ep.Connect(%#v): %s", dstAddr, err)
	}

	data := []byte{1, 2, 3, 4}
	var r bytes.Reader
	r.Reset(data)
	var wOpts tcpip.WriteOptions
	n, err := ep.Write(&r, wOpts)
	if err != nil {
		t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
	}
	if want := int64(len(data)); n != want {
		t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
	}

	// The router should perform source port remapping for the locally generated
	// traffic so that it does not conflict with the existing conntrack entry, so
	// ep2 should observe the traffic as coming from the router's address, but
	// *not* from the same port as the traffic from ep1 before.
	pkt = ep2.Read()
	if pkt == nil {
		t.Fatal("expected to read a packet on ep2")
	}
	pktView = stack.PayloadSince(pkt.NetworkHeader())
	defer pktView.Release()
	pkt.DecRef()
	checker.IPv4(t, pktView,
		checker.SrcAddr(routerNIC2Addr),
		checker.DstAddr(ep2Addr),
		checker.UDP(
			checker.DstPort(ep2Port),
			checker.Payload(data),
		),
	)
	gotPort := header.UDP(header.IPv4(pktView.AsSlice()).Payload()).SourcePort()
	if gotPort == ep1Port {
		t.Errorf("got src port == ep1Port (%d), should be remapped to avoid conflict", gotPort)
	}

	// We should also be able to reply on either connection, by injecting inbound
	// traffic on ep2 destined to the router.
	//
	// Traffic destined to the port originally used in the traffic injected on ep1
	// should go to ep1.
	ep2.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData(udpv4Packet(ep2Addr, routerNIC2Addr, ep2Port, ep1Port, 0 /* dataSize */)),
	}))
	pkt = ep1.Read()
	if pkt == nil {
		// This read is on ep1, so the failure must name ep1.
		t.Fatal("expected to read a packet on ep1")
	}
	pktView = stack.PayloadSince(pkt.NetworkHeader())
	defer pktView.Release()
	pkt.DecRef()
	checker.IPv4(t, pktView,
		checker.SrcAddr(ep2Addr),
		checker.DstAddr(ep1Addr),
		checker.UDP(
			checker.SrcPort(ep2Port),
			checker.DstPort(ep1Port),
		),
	)

	// And traffic destined to the remapped source port chosen by conntrack for
	// the socket bound on the stack should go to the socket.
	//
	// NOTE(review): data is appended after udpv4Packet built the packet with a
	// dataSize of 0, and the read below expects an empty result; presumably the
	// header lengths exclude the appended bytes. Confirm against udpv4Packet.
	reply := udpv4Packet(ep2Addr, routerNIC2Addr, ep2Port, gotPort, 0 /* dataSize */)
	reply = append(reply, data...)
	ep2.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData(reply),
	}))
	var buf bytes.Buffer
	var res tcpip.ReadResult
	for {
		var err tcpip.Error
		res, err = ep.Read(&buf, tcpip.ReadOptions{})
		if _, ok := err.(*tcpip.ErrWouldBlock); ok {
			// Nothing queued yet; wait for the readable notification.
			<-ch
			continue
		}
		if err != nil {
			t.Fatalf("ep.Read(_, {}): %s", err)
		}
		break
	}
	if diff := cmp.Diff(
		tcpip.ReadResult{
			Count: 0,
			Total: 0,
		},
		res,
		checker.IgnoreCmpPath("ControlMessages"),
	); diff != "" {
		t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
	}
}

// TestLocallyRoutedPackets checks that a UDP endpoint bound and connected to
// the same local address on a loopback NIC receives the data it writes, for
// both IPv4 and IPv6, after the filter table has been replaced.
func TestLocallyRoutedPackets(t *testing.T) {
	const nicID = 1

	tests := []struct {
		name     string
		netProto tcpip.NetworkProtocolNumber
		addr     tcpip.Address
	}{
		{
			name:     "IPv4",
			netProto: ipv4.ProtocolNumber,
			addr:     utils.Host1IPv4Addr.AddressWithPrefix.Address,
		},
		{
			name:     "IPv6",
			netProto: ipv6.ProtocolNumber,
			addr:     utils.Host1IPv6Addr.AddressWithPrefix.Address,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			s := stack.New(stack.Options{
				NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
				TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
			})
			defer s.Destroy()

			if err := s.CreateNIC(nicID, loopback.New()); err != nil {
				t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
			}
			protocolAddr := tcpip.ProtocolAddress{
				Protocol:          test.netProto,
				AddressWithPrefix: test.addr.WithPrefix(),
			}
			if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
				t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
			}

			s.SetRouteTable([]tcpip.Route{
				{
					Destination: protocolAddr.AddressWithPrefix.Subnet(),
					NIC:         nicID,
				},
			})

			// Set IPTables so we create entries in the conntrack table.
			{
				ipv6 := test.netProto == ipv6.ProtocolNumber
				ipt := s.IPTables()
				filter := ipt.GetTable(stack.FilterID, ipv6)
				ipt.ForceReplaceTable(stack.FilterID, filter, ipv6)
			}

			var wq waiter.Queue
			we, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
			wq.EventRegister(&we)
			defer wq.EventUnregister(&we)

			ep, err := s.NewEndpoint(udp.ProtocolNumber, test.netProto, &wq)
			if err != nil {
				t.Fatalf("s.NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.netProto, err)
			}
			defer ep.Close()

			// The endpoint talks to itself: same address for bind and connect.
			fullAddr := tcpip.FullAddress{Addr: test.addr, Port: 1234}
			if err := ep.Bind(fullAddr); err != nil {
				t.Fatalf("ep.Bind(%#v): %s", fullAddr, err)
			}
			if err := ep.Connect(fullAddr); err != nil {
				t.Fatalf("ep.Connect(%#v): %s", fullAddr, err)
			}

			data := []byte{1, 2, 3, 4}

			var r bytes.Reader
			r.Reset(data)
			var wOpts tcpip.WriteOptions
			n, err := ep.Write(&r, wOpts)
			if err != nil {
				t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
			}
			if want := int64(len(data)); n != want {
				t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want)
			}

			var buf bytes.Buffer
			var res tcpip.ReadResult
			for {
				var err tcpip.Error
				res, err = ep.Read(&buf, tcpip.ReadOptions{})
				if _, ok := err.(*tcpip.ErrWouldBlock); ok {
					// Nothing queued yet; wait for the readable notification.
					<-ch
					continue
				}
				if err != nil {
					t.Fatalf("ep.Read(_, {}): %s", err)
				}
				break
			}
			if diff := cmp.Diff(
				tcpip.ReadResult{
					Count: len(data),
					Total: len(data),
				},
				res,
				checker.IgnoreCmpPath("ControlMessages"),
			); diff != "" {
				t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
			}
			// cmp.Diff(x, y) reports x as "-" and y as "+", so want goes first.
			if diff := cmp.Diff(data, buf.Bytes()); diff != "" {
				t.Errorf("received data mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// icmpv4Matcher matches IPv4 packets carrying an ICMPv4 message of icmpType.
type icmpv4Matcher struct {
	icmpType header.ICMPv4Type
}

var _ stack.Matcher = (*icmpv4Matcher)(nil)

// Match implements stack.Matcher.
//
// It reports a match only for IPv4/ICMPv4 packets whose ICMP type equals
// m.icmpType, and never requests a hotdrop.
func (m *icmpv4Matcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches bool, hotdrop bool) {
	if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber {
		return false, false
	}

	if pkt.TransportProtocolNumber != header.ICMPv4ProtocolNumber {
		return false, false
	}

	return header.ICMPv4(pkt.TransportHeader().Slice()).Type() == m.icmpType, false
}

// icmpv6Matcher matches IPv6 packets carrying an ICMPv6 message of icmpType.
type icmpv6Matcher struct {
	icmpType header.ICMPv6Type
}

var _ stack.Matcher = (*icmpv6Matcher)(nil)

// Match implements stack.Matcher.
//
// It reports a match only for IPv6/ICMPv6 packets whose ICMP type equals
// m.icmpType, and never requests a hotdrop.
func (m *icmpv6Matcher) Match(_ stack.Hook, pkt *stack.PacketBuffer, _, _ string) (matches bool, hotdrop bool) {
	if pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
		return false, false
	}

	if pkt.TransportProtocolNumber != header.ICMPv6ProtocolNumber {
		return false, false
	}

	return header.ICMPv6(pkt.TransportHeader().Slice()).Type() == m.icmpType, false
}

// TestRejectWith checks that a reject target installed in the filter table at
// the Input, Forward or Output hook makes the router emit the expected ICMP
// destination unreachable error on ep1, for every supported reject code and
// for both IPv4 and IPv6.
func TestRejectWith(t *testing.T) {
	// natHook describes one hook under test: where the echo request is sent,
	// which packets the rule matches, and what the resulting ICMP error
	// should look like.
	type natHook struct {
		hook    stack.Hook
		dstAddr tcpip.Address
		matcher stack.Matcher

		errorICMPDstAddr tcpip.Address
		errorICMPPayload []byte
	}

	// rejectWithVal pairs a reject-with setting with the ICMP code it is
	// expected to produce.
	type rejectWithVal struct {
		name          string
		val           int
		errorICMPCode uint8
	}

	rxICMPv4EchoRequest := func(dst tcpip.Address) []byte {
		return utils.ICMPv4Echo(utils.Host1IPv4Addr.AddressWithPrefix.Address, dst, ttl, header.ICMPv4Echo)
	}

	rxICMPv6EchoRequest := func(dst tcpip.Address) []byte {
		return utils.ICMPv6Echo(utils.Host1IPv6Addr.AddressWithPrefix.Address, dst, ttl, header.ICMPv6EchoRequest)
	}

	tests := []struct {
		name              string
		netProto          tcpip.NetworkProtocolNumber
		rxICMPEchoRequest func(tcpip.Address) []byte
		icmpChecker       func(*testing.T, *buffer.View, tcpip.Address, uint8, uint8, []byte)

		natHooks []natHook

		rejectTarget   func(*testing.T, stack.NetworkProtocol, int) stack.Target
		rejectWithVals []rejectWithVal
		errorICMPType  uint8
	}{
		{
			name:              "IPv4",
			netProto:          header.IPv4ProtocolNumber,
			rxICMPEchoRequest: rxICMPv4EchoRequest,

			icmpChecker: func(t *testing.T, v *buffer.View, dstAddr tcpip.Address, icmpType, icmpCode uint8, origPayload []byte) {
				t.Helper()

				checker.IPv4(t, v,
					checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
					checker.DstAddr(dstAddr),
					checker.ICMPv4(
						checker.ICMPv4Checksum(),
						checker.ICMPv4Type(header.ICMPv4Type(icmpType)),
						checker.ICMPv4Code(header.ICMPv4Code(icmpCode)),
						checker.ICMPv4Payload(origPayload),
					),
				)
			},
			natHooks: []natHook{
				{
					hook:             stack.Input,
					dstAddr:          utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4Echo},
					errorICMPDstAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
					errorICMPPayload: rxICMPv4EchoRequest(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
				},
				{
					hook:             stack.Forward,
					dstAddr:          utils.Host2IPv4Addr.AddressWithPrefix.Address,
					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4Echo},
					errorICMPDstAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
					errorICMPPayload: rxICMPv4EchoRequest(utils.Host2IPv4Addr.AddressWithPrefix.Address),
				},
				{
					hook:             stack.Output,
					dstAddr:          utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
					matcher:          &icmpv4Matcher{icmpType: header.ICMPv4EchoReply},
					errorICMPDstAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
					errorICMPPayload: utils.ICMPv4Echo(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, utils.Host1IPv4Addr.AddressWithPrefix.Address, ttl, header.ICMPv4EchoReply),
				},
			},
			rejectTarget: func(t *testing.T, netProto stack.NetworkProtocol, rejectWith int) stack.Target {
				handler, ok := netProto.(stack.RejectIPv4WithHandler)
				if !ok {
					t.Fatalf("expected %T to implement %T", netProto, handler)
				}

				return &stack.RejectIPv4Target{
					Handler:    handler,
					RejectWith: stack.RejectIPv4WithICMPType(rejectWith),
				}
			},
			rejectWithVals: []rejectWithVal{
				{
					name:          "ICMP Network Unreachable",
					val:           int(stack.RejectIPv4WithICMPNetUnreachable),
					errorICMPCode: uint8(header.ICMPv4NetUnreachable),
				},
				{
					name:          "ICMP Host Unreachable",
					val:           int(stack.RejectIPv4WithICMPHostUnreachable),
					errorICMPCode: uint8(header.ICMPv4HostUnreachable),
				},
				{
					name:          "ICMP Port Unreachable",
					val:           int(stack.RejectIPv4WithICMPPortUnreachable),
					errorICMPCode: uint8(header.ICMPv4PortUnreachable),
				},
				{
					name:          "ICMP Network Prohibited",
					val:           int(stack.RejectIPv4WithICMPNetProhibited),
					errorICMPCode: uint8(header.ICMPv4NetProhibited),
				},
				{
					name:          "ICMP Host Prohibited",
					val:           int(stack.RejectIPv4WithICMPHostProhibited),
					errorICMPCode: uint8(header.ICMPv4HostProhibited),
				},
				{
					name:          "ICMP Administratively Prohibited",
					val:           int(stack.RejectIPv4WithICMPAdminProhibited),
					errorICMPCode: uint8(header.ICMPv4AdminProhibited),
				},
			},
			errorICMPType: uint8(header.ICMPv4DstUnreachable),
		},
		{
			name:              "IPv6",
			netProto:          header.IPv6ProtocolNumber,
			rxICMPEchoRequest: rxICMPv6EchoRequest,

			icmpChecker: func(t *testing.T, v *buffer.View, dstAddr tcpip.Address, icmpType, icmpCode uint8, origPayload []byte) {
				t.Helper()

				checker.IPv6(t, v,
					checker.SrcAddr(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
					checker.DstAddr(dstAddr),
					checker.ICMPv6(
						checker.ICMPv6Type(header.ICMPv6Type(icmpType)),
						checker.ICMPv6Code(header.ICMPv6Code(icmpCode)),
						checker.ICMPv6Payload(origPayload),
					),
				)
			},
			natHooks: []natHook{
				{
					hook:             stack.Input,
					dstAddr:          utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoRequest},
					errorICMPDstAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
					errorICMPPayload: rxICMPv6EchoRequest(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address),
				},
				{
					hook:             stack.Forward,
					dstAddr:          utils.Host2IPv6Addr.AddressWithPrefix.Address,
					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoRequest},
					errorICMPDstAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address,
					errorICMPPayload: rxICMPv6EchoRequest(utils.Host2IPv6Addr.AddressWithPrefix.Address),
				},
				{
					hook:             stack.Output,
					dstAddr:          utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
					matcher:          &icmpv6Matcher{icmpType: header.ICMPv6EchoReply},
					errorICMPDstAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address,
					errorICMPPayload: utils.ICMPv6Echo(utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, utils.Host1IPv6Addr.AddressWithPrefix.Address, ttl, header.ICMPv6EchoReply),
				},
			},
			rejectTarget: func(t *testing.T, netProto stack.NetworkProtocol, rejectWith int) stack.Target {
				handler, ok := netProto.(stack.RejectIPv6WithHandler)
				if !ok {
					t.Fatalf("expected %T to implement %T", netProto, handler)
				}

				return &stack.RejectIPv6Target{
					Handler:    handler,
					RejectWith: stack.RejectIPv6WithICMPType(rejectWith),
				}
			},
			rejectWithVals: []rejectWithVal{
				{
					name:          "ICMP No Route",
					val:           int(stack.RejectIPv6WithICMPNoRoute),
					errorICMPCode: uint8(header.ICMPv6NetworkUnreachable),
				},
				{
					name:          "ICMP Address Unreachable",
					val:           int(stack.RejectIPv6WithICMPAddrUnreachable),
					errorICMPCode: uint8(header.ICMPv6AddressUnreachable),
				},
				{
					name:          "ICMP Port Unreachable",
					val:           int(stack.RejectIPv6WithICMPPortUnreachable),
					errorICMPCode: uint8(header.ICMPv6PortUnreachable),
				},
				{
					name:          "ICMP Administratively Prohibited",
					val:           int(stack.RejectIPv6WithICMPAdminProhibited),
					errorICMPCode: uint8(header.ICMPv6Prohibited),
				},
			},
			errorICMPType: uint8(header.ICMPv6DstUnreachable),
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			for _, natHook := range test.natHooks {
				t.Run(natHook.hook.String(), func(t *testing.T) {
					for _, rejectWith := range test.rejectWithVals {
						t.Run(rejectWith.name, func(t *testing.T) {
							s := stack.New(stack.Options{
								NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
								TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
							})
							defer s.Destroy()

							ep1 := channel.New(1, header.IPv6MinimumMTU, "")
							ep2 := channel.New(1, header.IPv6MinimumMTU, "")
							utils.SetupRouterStack(t, s, ep1, ep2)

							// Install the reject rule at the head of the
							// chain for the hook under test.
							{
								ipv6 := test.netProto == ipv6.ProtocolNumber
								ipt := s.IPTables()
								filter := ipt.GetTable(stack.FilterID, ipv6)
								ruleIdx := filter.BuiltinChains[natHook.hook]
								filter.Rules[ruleIdx].Matchers = []stack.Matcher{natHook.matcher}
								filter.Rules[ruleIdx].Target = test.rejectTarget(t, s.NetworkProtocolInstance(test.netProto), rejectWith.val)
								// Make sure the packet is not dropped by the next rule.
								filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
								ipt.ForceReplaceTable(stack.FilterID, filter, ipv6)
							}

							// Scoped in a closure so the injected packet's
							// reference is dropped right after injection.
							func() {
								pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
									Payload: buffer.MakeWithData(test.rxICMPEchoRequest(natHook.dstAddr)),
								})
								defer pkt.DecRef()
								ep1.InjectInbound(test.netProto, pkt)
							}()

							{
								pkt := ep1.Read()
								if pkt == nil {
									t.Fatal("expected to read a packet on ep1")
								}
								payload := stack.PayloadSince(pkt.NetworkHeader())
								defer payload.Release()
								test.icmpChecker(
									t,
									payload,
									natHook.errorICMPDstAddr,
									test.errorICMPType,
									rejectWith.errorICMPCode,
									natHook.errorICMPPayload,
								)
								pkt.DecRef()
							}
						})
					}
				})
			}
		})
	}
}

// TestInvalidTransportHeader tests that bad transport headers (with a bad
// length/offset field) don't panic.
//
// Each case is injected with its own network protocol number and enables the
// matching filter table, so the IPv6 cases are delivered as IPv6 rather than
// being handed to the IPv4 path.
func TestInvalidTransportHeader(t *testing.T) {
	tests := []struct {
		name       string
		netProto   tcpip.NetworkProtocolNumber
		setupStack func(*testing.T) (*stack.Stack, *channel.Endpoint)
		genPacket  func(int8) *stack.PacketBuffer
		offset     int8
	}{
		{
			name:       "TCP4 offset small",
			netProto:   header.IPv4ProtocolNumber,
			setupStack: genStackV4,
			genPacket:  genTCP4,
			offset:     -1,
		},
		{
			name:       "TCP4 offset large",
			netProto:   header.IPv4ProtocolNumber,
			setupStack: genStackV4,
			genPacket:  genTCP4,
			offset:     1,
		},
		{
			name:       "UDP4 offset small",
			netProto:   header.IPv4ProtocolNumber,
			setupStack: genStackV4,
			genPacket:  genUDP4,
			offset:     -1,
		},
		{
			name:       "UDP4 offset large",
			netProto:   header.IPv4ProtocolNumber,
			setupStack: genStackV4,
			genPacket:  genUDP4,
			offset:     1,
		},
		{
			name:       "TCP6 offset small",
			netProto:   header.IPv6ProtocolNumber,
			setupStack: genStackV6,
			genPacket:  genTCP6,
			offset:     -1,
		},
		{
			name:       "TCP6 offset large",
			netProto:   header.IPv6ProtocolNumber,
			setupStack: genStackV6,
			genPacket:  genTCP6,
			offset:     1,
		},
		{
			name:       "UDP6 offset small",
			netProto:   header.IPv6ProtocolNumber,
			setupStack: genStackV6,
			genPacket:  genUDP6,
			offset:     -1,
		},
		{
			name:       "UDP6 offset large",
			netProto:   header.IPv6ProtocolNumber,
			setupStack: genStackV6,
			genPacket:  genUDP6,
			offset:     1,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			s, e := test.setupStack(t)

			// Enable iptables and conntrack for the protocol under test.
			isIPv6 := test.netProto == header.IPv6ProtocolNumber
			ipt := s.IPTables()
			filter := ipt.GetTable(stack.FilterID, isIPv6)
			ipt.ForceReplaceTable(stack.FilterID, filter, isIPv6)

			// This can panic if conntrack isn't checking lengths.
			e.InjectInbound(test.netProto, test.genPacket(test.offset))
		})
	}
}

// genTCP4 returns an IPv4 packet holding a minimum-size TCP SYN header whose
// data offset is the minimum header size shifted by offset 32-bit words.
//
// The shift is computed in uint8 arithmetic, which wraps, so a negative
// offset yields a data offset below the minimum.
func genTCP4(offset int8) *stack.PacketBuffer {
	pktSize := header.IPv4MinimumSize + header.TCPMinimumSize
	hdr := prependable.New(pktSize)

	tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
	tcp.Encode(&header.TCPFields{
		SeqNum:     0,
		AckNum:     0,
		DataOffset: header.TCPMinimumSize + uint8(offset)*4, // DataOffset must be a multiple of 4.
		Flags:      header.TCPFlagSyn,
		Checksum:   0,
	})

	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
	ip.Encode(&header.IPv4Fields{
		TOS:            0,
		TotalLength:    uint16(pktSize),
		ID:             1,
		Flags:          0,
		FragmentOffset: 0,
		TTL:            48,
		Protocol:       uint8(header.TCPProtocolNumber),
		SrcAddr:        srcAddrV4,
		DstAddr:        dstAddrV4,
	})
	// Zero the field first so it does not feed into its own checksum.
	ip.SetChecksum(0)
	ip.SetChecksum(^ip.CalculateChecksum())

	// Copy the header bytes so the packet does not alias hdr's storage.
	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
}

// genTCP6 returns an IPv6 packet holding a minimum-size TCP SYN header whose
// data offset is the minimum header size shifted by offset 32-bit words.
//
// The shift is computed in uint8 arithmetic, which wraps, so a negative
// offset yields a data offset below the minimum.
func genTCP6(offset int8) *stack.PacketBuffer {
	pktSize := header.IPv6MinimumSize + header.TCPMinimumSize
	hdr := prependable.New(pktSize)

	tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
	tcp.Encode(&header.TCPFields{
		SeqNum:     0,
		AckNum:     0,
		DataOffset: header.TCPMinimumSize + uint8(offset)*4, // DataOffset must be a multiple of 4.
		Flags:      header.TCPFlagSyn,
		Checksum:   0,
	})

	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
	ip.Encode(&header.IPv6Fields{
		PayloadLength:     header.TCPMinimumSize,
		TransportProtocol: header.TCPProtocolNumber,
		HopLimit:          255,
		SrcAddr:           srcAddrV6,
		DstAddr:           dstAddrV6,
	})

	// Copy the header bytes so the packet does not alias hdr's storage.
	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
}

// genUDP4 returns an IPv4 packet holding a UDP header with no payload whose
// length field is the minimum header size plus offset.
//
// The sum is computed in uint16 arithmetic, which wraps, so a negative offset
// yields a length below the minimum.
func genUDP4(offset int8) *stack.PacketBuffer {
	pktSize := header.IPv4MinimumSize + header.UDPMinimumSize
	hdr := prependable.New(pktSize)

	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
	udp.Encode(&header.UDPFields{
		SrcPort:  343,
		DstPort:  2401,
		Length:   header.UDPMinimumSize + uint16(offset),
		Checksum: 0,
	})

	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
	ip.Encode(&header.IPv4Fields{
		TOS:            0,
		TotalLength:    uint16(pktSize),
		ID:             1,
		Flags:          0,
		FragmentOffset: 0,
		TTL:            48,
		Protocol:       uint8(header.UDPProtocolNumber),
		SrcAddr:        srcAddrV4,
		DstAddr:        dstAddrV4,
	})
	// Zero the field first so it does not feed into its own checksum.
	ip.SetChecksum(0)
	ip.SetChecksum(^ip.CalculateChecksum())

	// Copy the header bytes so the packet does not alias hdr's storage.
	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
}

// genUDP6 returns an IPv6 packet holding a UDP header with no payload whose
// length field is the minimum header size plus offset.
//
// The sum is computed in uint16 arithmetic, which wraps, so a negative offset
// yields a length below the minimum.
func genUDP6(offset int8) *stack.PacketBuffer {
	pktSize := header.IPv6MinimumSize + header.UDPMinimumSize
	hdr := prependable.New(pktSize)

	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
	udp.Encode(&header.UDPFields{
		SrcPort:  343,
		DstPort:  2401,
		Length:   header.UDPMinimumSize + uint16(offset),
		Checksum: 0,
	})

	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
	ip.Encode(&header.IPv6Fields{
		PayloadLength:     header.UDPMinimumSize,
		TransportProtocol: header.UDPProtocolNumber,
		HopLimit:          255,
		SrcAddr:           srcAddrV6,
		DstAddr:           dstAddrV6,
	})

	// Copy the header bytes so the packet does not alias hdr's storage.
	buf := buffer.MakeWithData(append([]byte{}, hdr.View()...))
	return stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buf})
}