github.com/dylandreimerink/gobpfld@v0.6.1-0.20220205171531-e79c330ad608/cmd/examples/xsk_multi_sock/main.go (about)

     1  package main
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"net"
     7  	"os"
     8  	"os/signal"
     9  	"sync"
    10  
    11  	"github.com/dylandreimerink/gobpfld"
    12  	"github.com/dylandreimerink/gobpfld/bpftypes"
    13  	"github.com/dylandreimerink/gobpfld/ebpf"
    14  	"github.com/vishvananda/netlink"
    15  	"golang.org/x/sys/unix"
    16  
    17  	_ "net/http/pprof"
    18  )
    19  
    20  var (
    21  	ifname     = flag.String("ifname", "", "name of the network interface to bind to")
    22  	ip         = flag.String("ip", "", "the ipv4 ip we will use as ping target")
    23  	concurrent = flag.Bool("concurrent", false, "enable concurrent reading and processing of packets")
    24  )
    25  
    26  func main() {
    27  	flag.Parse()
    28  
    29  	if ifname == nil || *ifname == "" {
    30  		fmt.Fprint(os.Stderr, "flag 'ifname' is required\n")
    31  		os.Exit(1)
    32  	}
    33  
    34  	if ip == nil || *ip == "" {
    35  		fmt.Fprint(os.Stderr, "flag 'ip' is required\n")
    36  		os.Exit(1)
    37  	}
    38  
    39  	linkip := net.ParseIP(*ip).To4()
    40  	if linkip.Equal(net.IPv4zero) {
    41  		fmt.Fprint(os.Stderr, "flag 'ip' contains an invalid IPv4\n")
    42  		os.Exit(1)
    43  	}
    44  
    45  	linkName := *ifname
    46  	link, err := netlink.LinkByName(linkName)
    47  	if err != nil {
    48  		fmt.Fprintf(os.Stderr, "get link by name: %s\n", err.Error())
    49  		os.Exit(1)
    50  	}
    51  	linkMAC := link.Attrs().HardwareAddr
    52  
    53  	queues, err := gobpfld.GetNetDevQueueCount(linkName)
    54  	if err != nil {
    55  		fmt.Fprintf(os.Stderr, "get link queue count: %s\n", err.Error())
    56  		os.Exit(1)
    57  	}
    58  
    59  	sockets := make([]*gobpfld.XSKSocket, queues)
    60  
    61  	for i := 0; i < queues; i++ {
    62  		xsksock, err := gobpfld.NewXSKSocket(gobpfld.XSKSettings{
    63  			NetDevIfIndex: link.Attrs().Index,
    64  			QueueID:       i,
    65  		})
    66  		if err != nil {
    67  			fmt.Fprintf(os.Stderr, "new socket: %s\n", err.Error())
    68  			os.Exit(1)
    69  		}
    70  
    71  		// Set the read timeout to 100ms so we can stop the program even if there is nothing to read.
    72  		// Set the write timeout to infinity since waiting for writes almost never happens(for this example),
    73  		// and retry logic is harder to implement.
    74  		// The default behavour is to never block, this allows for busy polling which has lower latency but
    75  		//  higher CPU usage.
    76  		xsksock.SetReadTimeout(100)
    77  		xsksock.SetWriteTimeout(-1)
    78  
    79  		sockets[i] = xsksock
    80  	}
    81  
    82  	// Generate a program which will bypass all traffic to userspace
    83  	program := &gobpfld.ProgramXDP{
    84  		AbstractBPFProgram: gobpfld.AbstractBPFProgram{
    85  			Name:        gobpfld.MustNewObjName("xsk_bypass"),
    86  			ProgramType: bpftypes.BPF_PROG_TYPE_XDP,
    87  			License:     "GPL",
    88  			Maps: map[string]gobpfld.BPFMap{
    89  				"xskmap": &gobpfld.XSKMap{
    90  					AbstractMap: gobpfld.AbstractMap{
    91  						Name: gobpfld.MustNewObjName("xskmap"),
    92  						Definition: gobpfld.BPFMapDef{
    93  							Type:       bpftypes.BPF_MAP_TYPE_XSKMAP,
    94  							KeySize:    4, // SizeOf(uint32)
    95  							ValueSize:  4, // SizeOf(uint32)
    96  							MaxEntries: uint32(queues),
    97  						},
    98  					},
    99  				},
   100  			},
   101  			MapFDLocations: map[string][]uint64{
   102  				"xskmap": {
   103  					// LoadConstant64bit is the 2nd instruction in this program. So the first byte of the
   104  					// 2nd instruction is the width of a instruction * 1 to skip the first 1 instruction
   105  					uint64(ebpf.BPFInstSize) * 1,
   106  				},
   107  			},
   108  			// Instructions for this program:
   109  			// int xsk_bypass(struct xdp_md *ctx)
   110  			// {
   111  			// 	return bpf_redirect_map(&xsks_map, ctx->rx_queue_index, XDP_PASS);
   112  			// }
   113  			//
   114  			// NOTE this program only works in linux kernel >= 5.3
   115  			// https://elixir.bootlin.com/linux/v5.12.2/source/tools/lib/bpf/xsk.c#L416
   116  			Instructions: ebpf.MustEncode([]ebpf.Instruction{
   117  				// load ((xdp_md) ctx)->rx_queue_index into R2 (used as 2nd parameter)
   118  				/* r2 = *(u32 *)(r1 + 16) */
   119  				&ebpf.LoadMemory{
   120  					Dest:   ebpf.BPF_REG_2,
   121  					Src:    ebpf.BPF_REG_1,
   122  					Size:   ebpf.BPF_W,
   123  					Offset: 16,
   124  				},
   125  				// Set R1(first parameter) to the address of the xskmap.
   126  				// Which will be set during loading
   127  				/* r1 = xskmap[] */
   128  				&ebpf.LoadConstant64bit{
   129  					Dest: ebpf.BPF_REG_1,
   130  				},
   131  				// Move XDP_PASS into R3 (third argument)
   132  				/* r3 = XDP_PASS */
   133  				&ebpf.Mov64{
   134  					Dest:  ebpf.BPF_REG_3,
   135  					Value: ebpf.XDP_PASS,
   136  				},
   137  				/* call bpf_redirect_map */
   138  				&ebpf.CallHelper{
   139  					Function: 51,
   140  				},
   141  				&ebpf.Exit{}, // exit
   142  			}),
   143  		},
   144  	}
   145  
   146  	xskmap := program.Maps["xskmap"].(*gobpfld.XSKMap)
   147  	log, err := program.Load(gobpfld.ProgXDPLoadOpts{
   148  		VerifierLogLevel: bpftypes.BPFLogLevelBasic,
   149  	})
   150  
   151  	program.DecodeToReader(os.Stdout)
   152  	fmt.Fprintln(os.Stderr, log)
   153  	if err != nil {
   154  		fmt.Fprintf(os.Stderr, "error while loading program: %s\n", err.Error())
   155  		os.Exit(1)
   156  	}
   157  
   158  	// Add all sockets to the xskmap, index by the queue number.
   159  	for i := uint32(0); i < uint32(queues); i++ {
   160  		err = xskmap.Set(i, sockets[i])
   161  		if err != nil {
   162  			fmt.Fprintf(os.Stderr, "error while setting xsksock in map: %s\n", err.Error())
   163  			os.Exit(1)
   164  		}
   165  	}
   166  
   167  	sigChan := make(chan os.Signal, 1)
   168  	signal.Notify(sigChan, os.Interrupt, unix.SIGTERM, unix.SIGINT)
   169  
   170  	err = program.Attach(gobpfld.ProgXDPAttachOpts{
   171  		InterfaceName: linkName,
   172  		Replace:       true,
   173  	})
   174  	if err != nil {
   175  		fmt.Fprintf(os.Stderr, "error while attaching program to loopback device: %s\n", err.Error())
   176  		os.Exit(1)
   177  	}
   178  
   179  	if *concurrent {
   180  		wg := &sync.WaitGroup{}
   181  		done := make(chan struct{})
   182  
   183  		// In concurrent mode we just create a routine for every queue we have. This offers beter performance over
   184  		// the MultiWriter because on a multi core processor multiple frames can be handled at the same time.
   185  		// It also requires more setup and manual TX balancing.
   186  		//
   187  		// The 'ethtool' utility can be used to configure NIC's to stear traffic to specific RX queues based on
   188  		// rules. Even tho we have an XDP program which uses a map to select a XSK, the kernel will not allow
   189  		// the XDP program to pick a socket which is not bound to that specific queue. The NIC/driver is leading.
   190  		for i := 0; i < queues; i++ {
   191  			fmt.Println("scheduled listener for queue ", i)
   192  
   193  			wg.Add(1)
   194  			go func(queue int, wg *sync.WaitGroup) {
   195  				defer wg.Done()
   196  
   197  				for {
   198  					select {
   199  					case <-sigChan:
   200  						close(done)
   201  						return
   202  					case <-done:
   203  						return
   204  					default:
   205  						sock := sockets[queue]
   206  						lease, err := sock.ReadLease()
   207  						if err != nil {
   208  							fmt.Fprintf(os.Stderr, "read lease: %s", err.Error())
   209  						}
   210  
   211  						if lease == nil {
   212  							continue
   213  						}
   214  
   215  						// Seperator to distinguish between frames
   216  						fmt.Println("------")
   217  						fmt.Printf("received frame on queue: %d\n", queue)
   218  
   219  						err = HandleFrame(lease, linkMAC, linkip)
   220  						if err != nil {
   221  							fmt.Fprintf(os.Stderr, "echo reply: %s", err.Error())
   222  						}
   223  					}
   224  				}
   225  			}(i, wg)
   226  		}
   227  
   228  		// Wait for all routines to stop before exiting the program
   229  		wg.Wait()
   230  	} else {
   231  		// If we will not be using multiple goroutines, we need to bundle the socket for every queue into one
   232  		// using a multi socket. A multi socket has the same functions available but balances between all
   233  		// sockets.
   234  
   235  		multiSock, err := gobpfld.NewXSKMultiSocket(sockets...)
   236  		if err != nil {
   237  			fmt.Fprintf(os.Stderr, "new xsk multi socket: %s\n", err.Error())
   238  		} else {
   239  			// We have to do this for the multi sock since it's timeout overrules that of any
   240  			// underlaying socket.
   241  			// Read the comments at xsksock.Set{Read|Write}Timeout for an explination of the values.
   242  			multiSock.SetReadTimeout(100)
   243  			multiSock.SetWriteTimeout(-1)
   244  
   245  			done := false
   246  			for !done {
   247  				select {
   248  				case <-sigChan:
   249  					done = true
   250  
   251  				default:
   252  					// Seperator to distinguish between frames
   253  					fmt.Println("------")
   254  					lease, err := multiSock.ReadLease()
   255  					if err != nil {
   256  						fmt.Fprintf(os.Stderr, "read lease: %s\n", err.Error())
   257  						continue
   258  					}
   259  
   260  					if lease == nil {
   261  						continue
   262  					}
   263  
   264  					err = HandleFrame(lease, linkMAC, linkip)
   265  					if err != nil {
   266  						fmt.Fprintf(os.Stderr, "echo reply: %s", err.Error())
   267  					}
   268  				}
   269  			}
   270  		}
   271  	}
   272  
   273  	fmt.Println("Detaching XPD program and stopping")
   274  
   275  	err = program.XDPLinkDetach(gobpfld.BPFProgramXDPLinkDetachSettings{
   276  		All: true,
   277  	})
   278  	if err != nil {
   279  		fmt.Fprintf(os.Stderr, "error while detaching program: %s\n", err.Error())
   280  		os.Exit(1)
   281  	}
   282  
   283  	os.Exit(0)
   284  }
   285  
   286  func HandleFrame(lease *gobpfld.XSKLease, linkMac net.HardwareAddr, linkIP net.IP) error {
   287  	var err error
   288  
   289  	// Swap the MAC addresses
   290  	fmt.Printf("Src MAC: %X, Dst MAC: %X, EthType: %X\n", lease.Data[0:6], lease.Data[6:12], lease.Data[12:14])
   291  	swapMac := make([]byte, 6)
   292  	copy(swapMac, lease.Data[0:6])
   293  	copy(lease.Data[0:6], lease.Data[6:12])
   294  	copy(lease.Data[6:12], swapMac)
   295  
   296  	// If EtherType == 0x0806(ARP)
   297  	if lease.Data[12] == 0x08 && lease.Data[13] == 0x06 {
   298  		return HandleARP(lease, linkMac, linkIP)
   299  	}
   300  
   301  	// If EtherType == 0x0800(IPv4)
   302  	if lease.Data[12] == 0x08 && lease.Data[13] == 0x00 {
   303  		return EchoReply(lease, linkIP)
   304  	}
   305  
   306  	err = lease.Release()
   307  	if err != nil {
   308  		return fmt.Errorf("release lease: %w", err)
   309  	}
   310  	return nil
   311  }
   312  
   313  // Since we have no network stack we also need to respond to ARP messages
   314  func HandleARP(lease *gobpfld.XSKLease, linkMac net.HardwareAddr, linkIP net.IP) error {
   315  	// 14-20 left unchanged
   316  
   317  	// Change opcode from request to reply
   318  	lease.Data[21] = 2
   319  
   320  	// Copy the MAC of our link into the ethernet SRC
   321  	copy(lease.Data[6:12], linkMac)
   322  	// Copy the MAC of our link into the ARP target (will be swapped to source)
   323  	copy(lease.Data[32:38], linkMac)
   324  
   325  	// Ignore ARP requests whic hare not for us, otherwise we are ARP spoofing
   326  	if !net.IP(lease.Data[38:42]).Equal(linkIP) {
   327  		err := lease.Release()
   328  		if err != nil {
   329  			return fmt.Errorf("release lease: %w", err)
   330  		}
   331  		return nil
   332  	}
   333  
   334  	fmt.Println("respond to arp")
   335  
   336  	// Swap sender and target fields
   337  	swap := make([]byte, 10)
   338  	copy(swap, lease.Data[22:32])
   339  	copy(lease.Data[22:32], lease.Data[32:42])
   340  	copy(lease.Data[32:42], swap)
   341  
   342  	err := lease.Write()
   343  	if err != nil {
   344  		return fmt.Errorf("release lease: %w", err)
   345  	}
   346  	return nil
   347  }
   348  
   349  func EchoReply(lease *gobpfld.XSKLease, linkIP net.IP) error {
   350  	var err error
   351  
   352  	// Ignore IP traffic that is not for us
   353  	if !net.IP(lease.Data[30:34]).Equal(linkIP) {
   354  		err = lease.Release()
   355  		if err != nil {
   356  			return fmt.Errorf("release lease: %w", err)
   357  		}
   358  		return nil
   359  	}
   360  
   361  	fmt.Printf("IPv4 Src: %X, Dst: %X, Proto: %X\n", lease.Data[26:30], lease.Data[30:34], lease.Data[23])
   362  	swapIP := make([]byte, 4)
   363  	copy(swapIP, lease.Data[26:30])
   364  	copy(lease.Data[26:30], lease.Data[30:34])
   365  	copy(lease.Data[30:34], swapIP)
   366  
   367  	fmt.Printf("IPv4 Checksum in: %X\n", lease.Data[24:26])
   368  
   369  	// Zero the checksum
   370  	lease.Data[24] = 0x00
   371  	lease.Data[25] = 0x00
   372  
   373  	// Calculate header checksum
   374  	// https://en.wikipedia.org/wiki/IPv4_header_checksum
   375  	var sum uint32
   376  	for i := 14; i < 34; i += 2 {
   377  		sum += uint32(lease.Data[i]) << 8
   378  		sum += uint32(lease.Data[i+1])
   379  	}
   380  	for {
   381  		// Break when sum is less or equals to 0xFFFF
   382  		if sum <= 65535 {
   383  			break
   384  		}
   385  		// Add carry to the sum
   386  		sum = (sum >> 16) + uint32(uint16(sum))
   387  	}
   388  	checkSum := ^uint16(sum)
   389  	lease.Data[24] = byte(checkSum >> 8)
   390  	lease.Data[25] = byte(checkSum & 0xFF)
   391  
   392  	fmt.Printf("IPv4 Checksum out: %X\n", lease.Data[24:26])
   393  
   394  	// If Protocol != 0x01(ICMPv4)
   395  	if lease.Data[23] != 0x01 {
   396  		err = lease.Release()
   397  		if err != nil {
   398  			return fmt.Errorf("release lease: %w", err)
   399  		}
   400  		return nil
   401  	}
   402  
   403  	fmt.Printf("ICMPv4 Type: %X, Code: %X\n", lease.Data[34], lease.Data[35])
   404  	// Set type to 0 to get 0,0 = Echo Reply
   405  	lease.Data[34] = 0
   406  
   407  	fmt.Printf("ICMPv4 Checksum in: %X\n", lease.Data[36:38])
   408  
   409  	// clear icmp checksum
   410  	lease.Data[36] = 0x00
   411  	lease.Data[37] = 0x00
   412  
   413  	// Calculate ICMP checksum
   414  	sum = 0
   415  	for i := 34; i < len(lease.Data); i += 2 {
   416  		sum += uint32(lease.Data[i]) << 8
   417  		sum += uint32(lease.Data[i+1])
   418  	}
   419  	for {
   420  		// Break when sum is less or equals to 0xFFFF
   421  		if sum <= 65535 {
   422  			break
   423  		}
   424  		// Add carry to the sum
   425  		sum = (sum >> 16) + uint32(uint16(sum))
   426  	}
   427  	checkSum = ^uint16(sum)
   428  	lease.Data[36] = byte(checkSum >> 8)
   429  	lease.Data[37] = byte(checkSum & 0xFF)
   430  
   431  	fmt.Printf("ICMPv4 Checksum out: %X\n", lease.Data[36:38])
   432  
   433  	// Now that we have converted the request into a reply in the same memory buffer
   434  	// we can just write this buffer back to the network interface
   435  	err = lease.Write()
   436  	if err != nil {
   437  		return fmt.Errorf("write lease: %w", err)
   438  	}
   439  	return nil
   440  }