github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/sys_conn_oob_test.go (about)

     1  //go:build darwin || linux || freebsd
     2  
     3  package quic
     4  
     5  import (
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"golang.org/x/net/ipv4"
    11  	"golang.org/x/sys/unix"
    12  
    13  	"github.com/apernet/quic-go/internal/protocol"
    14  	"github.com/apernet/quic-go/internal/utils"
    15  
    16  	. "github.com/onsi/ginkgo/v2"
    17  	. "github.com/onsi/gomega"
    18  	"go.uber.org/mock/gomock"
    19  )
    20  
    21  type oobRecordingConn struct {
    22  	*net.UDPConn
    23  	oobs [][]byte
    24  }
    25  
    26  func (c *oobRecordingConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) {
    27  	c.oobs = append(c.oobs, oob)
    28  	return c.UDPConn.WriteMsgUDP(b, oob, addr)
    29  }
    30  
    31  var _ = Describe("OOB Conn Test", func() {
    32  	runServer := func(network, address string) (*net.UDPConn, <-chan receivedPacket) {
    33  		addr, err := net.ResolveUDPAddr(network, address)
    34  		Expect(err).ToNot(HaveOccurred())
    35  		udpConn, err := net.ListenUDP(network, addr)
    36  		Expect(err).ToNot(HaveOccurred())
    37  		oobConn, err := newConn(udpConn, true)
    38  		Expect(err).ToNot(HaveOccurred())
    39  		Expect(oobConn.capabilities().DF).To(BeTrue())
    40  
    41  		packetChan := make(chan receivedPacket)
    42  		go func() {
    43  			defer GinkgoRecover()
    44  			for {
    45  				p, err := oobConn.ReadPacket()
    46  				if err != nil {
    47  					return
    48  				}
    49  				packetChan <- p
    50  			}
    51  		}()
    52  
    53  		return udpConn, packetChan
    54  	}
    55  
    56  	Context("reading ECN-marked packets", func() {
    57  		sendPacketWithECN := func(network string, addr *net.UDPAddr, setECN func(uintptr)) net.Addr {
    58  			conn, err := net.DialUDP(network, nil, addr)
    59  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
    60  			rawConn, err := conn.SyscallConn()
    61  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
    62  			ExpectWithOffset(1, rawConn.Control(func(fd uintptr) {
    63  				setECN(fd)
    64  			})).To(Succeed())
    65  			_, err = conn.Write([]byte("foobar"))
    66  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
    67  			return conn.LocalAddr()
    68  		}
    69  
    70  		It("reads ECN flags on IPv4", func() {
    71  			conn, packetChan := runServer("udp4", "localhost:0")
    72  			defer conn.Close()
    73  
    74  			sentFrom := sendPacketWithECN(
    75  				"udp4",
    76  				conn.LocalAddr().(*net.UDPAddr),
    77  				func(fd uintptr) {
    78  					Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 2)).To(Succeed())
    79  				},
    80  			)
    81  
    82  			var p receivedPacket
    83  			Eventually(packetChan).Should(Receive(&p))
    84  			Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
    85  			Expect(p.data).To(Equal([]byte("foobar")))
    86  			Expect(p.remoteAddr).To(Equal(sentFrom))
    87  			Expect(p.ecn).To(Equal(protocol.ECT0))
    88  		})
    89  
    90  		It("reads ECN flags on IPv6", func() {
    91  			conn, packetChan := runServer("udp6", "[::]:0")
    92  			defer conn.Close()
    93  
    94  			sentFrom := sendPacketWithECN(
    95  				"udp6",
    96  				conn.LocalAddr().(*net.UDPAddr),
    97  				func(fd uintptr) {
    98  					Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 3)).To(Succeed())
    99  				},
   100  			)
   101  
   102  			var p receivedPacket
   103  			Eventually(packetChan).Should(Receive(&p))
   104  			Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
   105  			Expect(p.data).To(Equal([]byte("foobar")))
   106  			Expect(p.remoteAddr).To(Equal(sentFrom))
   107  			Expect(p.ecn).To(Equal(protocol.ECNCE))
   108  		})
   109  
   110  		It("reads ECN flags on a connection that supports both IPv4 and IPv6", func() {
   111  			conn, packetChan := runServer("udp", "0.0.0.0:0")
   112  			defer conn.Close()
   113  			port := conn.LocalAddr().(*net.UDPAddr).Port
   114  
   115  			// IPv4
   116  			sendPacketWithECN(
   117  				"udp4",
   118  				&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port},
   119  				func(fd uintptr) {
   120  					Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_TOS, 3)).To(Succeed())
   121  				},
   122  			)
   123  
   124  			var p receivedPacket
   125  			Eventually(packetChan).Should(Receive(&p))
   126  			Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
   127  			Expect(p.ecn).To(Equal(protocol.ECNCE))
   128  
   129  			// IPv6
   130  			sendPacketWithECN(
   131  				"udp6",
   132  				&net.UDPAddr{IP: net.IPv6loopback, Port: port},
   133  				func(fd uintptr) {
   134  					Expect(unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_TCLASS, 1)).To(Succeed())
   135  				},
   136  			)
   137  
   138  			Eventually(packetChan).Should(Receive(&p))
   139  			Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
   140  			Expect(p.ecn).To(Equal(protocol.ECT1))
   141  		})
   142  
   143  		It("sends packets with ECN on IPv4", func() {
   144  			conn, packetChan := runServer("udp4", "localhost:0")
   145  			defer conn.Close()
   146  
   147  			c, err := net.ListenUDP("udp4", nil)
   148  			Expect(err).ToNot(HaveOccurred())
   149  			defer c.Close()
   150  
   151  			for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} {
   152  				_, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv4ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr))
   153  				Expect(err).ToNot(HaveOccurred())
   154  				var p receivedPacket
   155  				Eventually(packetChan).Should(Receive(&p))
   156  				Expect(p.data).To(Equal([]byte("foobar")))
   157  				Expect(p.ecn).To(Equal(val))
   158  			}
   159  		})
   160  
   161  		It("sends packets with ECN on IPv6", func() {
   162  			conn, packetChan := runServer("udp6", "[::1]:0")
   163  			defer conn.Close()
   164  
   165  			c, err := net.ListenUDP("udp6", nil)
   166  			Expect(err).ToNot(HaveOccurred())
   167  			defer c.Close()
   168  
   169  			for _, val := range []protocol.ECN{protocol.ECNNon, protocol.ECT1, protocol.ECT0, protocol.ECNCE} {
   170  				_, _, err = c.WriteMsgUDP([]byte("foobar"), appendIPv6ECNMsg([]byte{}, val), conn.LocalAddr().(*net.UDPAddr))
   171  				Expect(err).ToNot(HaveOccurred())
   172  				var p receivedPacket
   173  				Eventually(packetChan).Should(Receive(&p))
   174  				Expect(p.data).To(Equal([]byte("foobar")))
   175  				Expect(p.ecn).To(Equal(val))
   176  			}
   177  		})
   178  	})
   179  
   180  	Context("Packet Info conn", func() {
   181  		sendPacket := func(network string, addr *net.UDPAddr) net.Addr {
   182  			conn, err := net.DialUDP(network, nil, addr)
   183  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   184  			_, err = conn.Write([]byte("foobar"))
   185  			ExpectWithOffset(1, err).ToNot(HaveOccurred())
   186  			return conn.LocalAddr()
   187  		}
   188  
   189  		It("reads packet info on IPv4", func() {
   190  			conn, packetChan := runServer("udp4", ":0")
   191  			defer conn.Close()
   192  
   193  			addr := conn.LocalAddr().(*net.UDPAddr)
   194  			ip := net.ParseIP("127.0.0.1").To4()
   195  			addr.IP = ip
   196  			sentFrom := sendPacket("udp4", addr)
   197  
   198  			var p receivedPacket
   199  			Eventually(packetChan).Should(Receive(&p))
   200  			Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
   201  			Expect(p.data).To(Equal([]byte("foobar")))
   202  			Expect(p.remoteAddr).To(Equal(sentFrom))
   203  			Expect(p.info.addr.IsValid()).To(BeTrue())
   204  			Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip))
   205  		})
   206  
   207  		It("reads packet info on IPv6", func() {
   208  			conn, packetChan := runServer("udp6", ":0")
   209  			defer conn.Close()
   210  
   211  			addr := conn.LocalAddr().(*net.UDPAddr)
   212  			ip := net.ParseIP("::1")
   213  			addr.IP = ip
   214  			sentFrom := sendPacket("udp6", addr)
   215  
   216  			var p receivedPacket
   217  			Eventually(packetChan).Should(Receive(&p))
   218  			Expect(p.rcvTime).To(BeTemporally("~", time.Now(), scaleDuration(20*time.Millisecond)))
   219  			Expect(p.data).To(Equal([]byte("foobar")))
   220  			Expect(p.remoteAddr).To(Equal(sentFrom))
   221  			Expect(p.info).To(Not(BeNil()))
   222  			Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip))
   223  		})
   224  
   225  		It("reads packet info on a connection that supports both IPv4 and IPv6", func() {
   226  			conn, packetChan := runServer("udp", ":0")
   227  			defer conn.Close()
   228  			port := conn.LocalAddr().(*net.UDPAddr).Port
   229  
   230  			// IPv4
   231  			ip4 := net.ParseIP("127.0.0.1")
   232  			sendPacket("udp4", &net.UDPAddr{IP: ip4, Port: port})
   233  
   234  			var p receivedPacket
   235  			Eventually(packetChan).Should(Receive(&p))
   236  			Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeTrue())
   237  			Expect(p.info).To(Not(BeNil()))
   238  			Expect(p.info.addr.Is4()).To(BeTrue())
   239  			ip := p.info.addr.As4()
   240  			Expect(net.IP(ip[:])).To(Equal(ip4.To4()))
   241  
   242  			// IPv6
   243  			ip6 := net.ParseIP("::1")
   244  			sendPacket("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: port})
   245  
   246  			Eventually(packetChan).Should(Receive(&p))
   247  			Expect(utils.IsIPv4(p.remoteAddr.(*net.UDPAddr).IP)).To(BeFalse())
   248  			Expect(p.info).To(Not(BeNil()))
   249  			Expect(net.IP(p.info.addr.AsSlice())).To(Equal(ip6))
   250  		})
   251  	})
   252  
   253  	Context("Batch Reading", func() {
   254  		var batchConn *MockBatchConn
   255  
   256  		BeforeEach(func() {
   257  			batchConn = NewMockBatchConn(mockCtrl)
   258  		})
   259  
   260  		It("reads multiple messages in one batch", func() {
   261  			const numMsgRead = batchSize/2 + 1
   262  			var counter int
   263  			batchConn.EXPECT().ReadBatch(gomock.Any(), gomock.Any()).DoAndReturn(func(ms []ipv4.Message, flags int) (int, error) {
   264  				Expect(ms).To(HaveLen(batchSize))
   265  				for i := 0; i < numMsgRead; i++ {
   266  					Expect(ms[i].Buffers).To(HaveLen(1))
   267  					Expect(ms[i].Buffers[0]).To(HaveLen(protocol.MaxPacketBufferSize))
   268  					data := []byte(fmt.Sprintf("message %d", counter))
   269  					counter++
   270  					ms[i].Buffers[0] = data
   271  					ms[i].N = len(data)
   272  				}
   273  				return numMsgRead, nil
   274  			}).Times(2)
   275  
   276  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   277  			Expect(err).ToNot(HaveOccurred())
   278  			udpConn, err := net.ListenUDP("udp", addr)
   279  			Expect(err).ToNot(HaveOccurred())
   280  			oobConn, err := newConn(udpConn, true)
   281  			Expect(err).ToNot(HaveOccurred())
   282  			oobConn.batchConn = batchConn
   283  
   284  			for i := 0; i < batchSize+1; i++ {
   285  				p, err := oobConn.ReadPacket()
   286  				Expect(err).ToNot(HaveOccurred())
   287  				Expect(string(p.data)).To(Equal(fmt.Sprintf("message %d", i)))
   288  			}
   289  		})
   290  	})
   291  
   292  	Context("sending ECN-marked packets", func() {
   293  		It("sets the ECN control message", func() {
   294  			addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   295  			Expect(err).ToNot(HaveOccurred())
   296  			udpConn, err := net.ListenUDP("udp", addr)
   297  			Expect(err).ToNot(HaveOccurred())
   298  			c := &oobRecordingConn{UDPConn: udpConn}
   299  			oobConn, err := newConn(c, true)
   300  			Expect(err).ToNot(HaveOccurred())
   301  
   302  			oob := make([]byte, 0, 123)
   303  			oobConn.WritePacket([]byte("foobar"), addr, oob, 0, protocol.ECNCE)
   304  			Expect(c.oobs).To(HaveLen(1))
   305  			oobMsg := c.oobs[0]
   306  			Expect(oobMsg).ToNot(BeEmpty())
   307  			Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
   308  			expected := appendIPv4ECNMsg([]byte{}, protocol.ECNCE)
   309  			Expect(oobMsg).To(Equal(expected))
   310  		})
   311  	})
   312  
   313  	if platformSupportsGSO {
   314  		Context("GSO", func() {
   315  			It("appends the GSO control message", func() {
   316  				addr, err := net.ResolveUDPAddr("udp", "localhost:0")
   317  				Expect(err).ToNot(HaveOccurred())
   318  				udpConn, err := net.ListenUDP("udp", addr)
   319  				Expect(err).ToNot(HaveOccurred())
   320  				c := &oobRecordingConn{UDPConn: udpConn}
   321  				oobConn, err := newConn(c, true)
   322  				Expect(err).ToNot(HaveOccurred())
   323  				Expect(oobConn.capabilities().GSO).To(BeTrue())
   324  
   325  				oob := make([]byte, 0, 123)
   326  				oobConn.WritePacket([]byte("foobar"), addr, oob, 3, protocol.ECNCE)
   327  				Expect(c.oobs).To(HaveLen(1))
   328  				oobMsg := c.oobs[0]
   329  				Expect(oobMsg).ToNot(BeEmpty())
   330  				Expect(oobMsg).To(HaveCap(cap(oob))) // check that it appended to oob
   331  				expected := appendUDPSegmentSizeMsg([]byte{}, 3)
   332  				// Check that the first control message is the OOB control message.
   333  				Expect(oobMsg[:len(expected)]).To(Equal(expected))
   334  			})
   335  		})
   336  	}
   337  })