golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/bpf/vm_bpf_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bpf_test
     6  
     7  import (
     8  	"net"
     9  	"runtime"
    10  	"testing"
    11  	"time"
    12  
    13  	"golang.org/x/net/bpf"
    14  	"golang.org/x/net/ipv4"
    15  	"golang.org/x/net/ipv6"
    16  	"golang.org/x/net/nettest"
    17  	"golang.org/x/sys/cpu"
    18  )
    19  
    20  // A virtualMachine is a BPF virtual machine which can process an
    21  // input packet against a BPF program and render a verdict.
    22  type virtualMachine interface {
    23  	Run(in []byte) (int, error)
    24  }
    25  
// All BPF tests against both the Go VM and the OS VM assume a UDP socket.
// As a result, the entire contents of a UDP datagram, header included, are
// sent through the BPF program, but only the body after the UDP header is
// ever returned in the output.
    31  // testVM sets up a Go BPF VM, and if available, a native OS BPF VM
    32  // for integration testing.
    33  func testVM(t *testing.T, filter []bpf.Instruction) (virtualMachine, func(), error) {
    34  	goVM, err := bpf.NewVM(filter)
    35  	if err != nil {
    36  		// Some tests expect an error, so this error must be returned
    37  		// instead of fatally exiting the test
    38  		return nil, nil, err
    39  	}
    40  
    41  	mvm := &multiVirtualMachine{
    42  		goVM: goVM,
    43  
    44  		t: t,
    45  	}
    46  
    47  	// For linux with a little endian CPU, the Go VM and OS VM have exactly the
    48  	// same output for the same input program and packet. Compare both.
    49  	done := func() {}
    50  	if runtime.GOOS == "linux" && !cpu.IsBigEndian {
    51  		osVM, osVMDone := testOSVM(t, filter)
    52  		done = func() { osVMDone() }
    53  		mvm.osVM = osVM
    54  	}
    55  
    56  	return mvm, done, nil
    57  }
    58  
    59  // udpHeaderLen is the length of a UDP header.
    60  const udpHeaderLen = 8
    61  
    62  // A multiVirtualMachine is a virtualMachine which can call out to both the Go VM
    63  // and the native OS VM, if the OS VM is available.
    64  type multiVirtualMachine struct {
    65  	goVM virtualMachine
    66  	osVM virtualMachine
    67  
    68  	t *testing.T
    69  }
    70  
    71  func (mvm *multiVirtualMachine) Run(in []byte) (int, error) {
    72  	if len(in) < udpHeaderLen {
    73  		mvm.t.Fatalf("input must be at least length of UDP header (%d), got: %d",
    74  			udpHeaderLen, len(in))
    75  	}
    76  
    77  	// All tests have a UDP header as part of input, because the OS VM
    78  	// packets always will. For the Go VM, this output is trimmed before
    79  	// being sent back to tests.
    80  	goOut, goErr := mvm.goVM.Run(in)
    81  	if goOut >= udpHeaderLen {
    82  		goOut -= udpHeaderLen
    83  	}
    84  
    85  	// If Go output is larger than the size of the packet, packet filtering
    86  	// interop tests must trim the output bytes to the length of the packet.
    87  	// The BPF VM should not do this on its own, as other uses of it do
    88  	// not trim the output byte count.
    89  	trim := len(in) - udpHeaderLen
    90  	if goOut > trim {
    91  		goOut = trim
    92  	}
    93  
    94  	// When the OS VM is not available, process using the Go VM alone
    95  	if mvm.osVM == nil {
    96  		return goOut, goErr
    97  	}
    98  
    99  	// The OS VM will apply its own UDP header, so remove the pseudo header
   100  	// that the Go VM needs.
   101  	osOut, err := mvm.osVM.Run(in[udpHeaderLen:])
   102  	if err != nil {
   103  		mvm.t.Fatalf("error while running OS VM: %v", err)
   104  	}
   105  
   106  	// Verify both VMs return same number of bytes
   107  	var mismatch bool
   108  	if goOut != osOut {
   109  		mismatch = true
   110  		mvm.t.Logf("output byte count does not match:\n- go: %v\n- os: %v", goOut, osOut)
   111  	}
   112  
   113  	if mismatch {
   114  		mvm.t.Fatal("Go BPF and OS BPF packet outputs do not match")
   115  	}
   116  
   117  	return goOut, goErr
   118  }
   119  
   120  // An osVirtualMachine is a virtualMachine which uses the OS's BPF VM for
   121  // processing BPF programs.
   122  type osVirtualMachine struct {
   123  	l net.PacketConn
   124  	s net.Conn
   125  }
   126  
   127  // testOSVM creates a virtualMachine which uses the OS's BPF VM by injecting
   128  // packets into a UDP listener with a BPF program attached to it.
   129  func testOSVM(t *testing.T, filter []bpf.Instruction) (virtualMachine, func()) {
   130  	l, err := nettest.NewLocalPacketListener("udp")
   131  	if err != nil {
   132  		t.Fatalf("failed to open OS VM UDP listener: %v", err)
   133  	}
   134  
   135  	prog, err := bpf.Assemble(filter)
   136  	if err != nil {
   137  		t.Fatalf("failed to compile BPF program: %v", err)
   138  	}
   139  
   140  	ip := l.LocalAddr().(*net.UDPAddr).IP
   141  	if ip.To4() != nil && ip.To16() == nil {
   142  		err = ipv4.NewPacketConn(l).SetBPF(prog)
   143  	} else {
   144  		err = ipv6.NewPacketConn(l).SetBPF(prog)
   145  	}
   146  	if err != nil {
   147  		t.Fatalf("failed to attach BPF program to listener: %v", err)
   148  	}
   149  
   150  	s, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
   151  	if err != nil {
   152  		t.Fatalf("failed to dial connection to listener: %v", err)
   153  	}
   154  
   155  	done := func() {
   156  		_ = s.Close()
   157  		_ = l.Close()
   158  	}
   159  
   160  	return &osVirtualMachine{
   161  		l: l,
   162  		s: s,
   163  	}, done
   164  }
   165  
   166  // Run sends the input bytes into the OS's BPF VM and returns its verdict.
   167  func (vm *osVirtualMachine) Run(in []byte) (int, error) {
   168  	go func() {
   169  		_, _ = vm.s.Write(in)
   170  	}()
   171  
   172  	vm.l.SetDeadline(time.Now().Add(50 * time.Millisecond))
   173  
   174  	var b [512]byte
   175  	n, _, err := vm.l.ReadFrom(b[:])
   176  	if err != nil {
   177  		// A timeout indicates that BPF filtered out the packet, and thus,
   178  		// no input should be returned.
   179  		if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
   180  			return n, nil
   181  		}
   182  
   183  		return n, err
   184  	}
   185  
   186  	return n, nil
   187  }