github.com/cilium/ebpf@v0.10.0/cmd/bpf2go/main_test.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"os/exec"
     9  	"path/filepath"
    10  	"runtime"
    11  	"sort"
    12  	"strings"
    13  	"testing"
    14  
    15  	qt "github.com/frankban/quicktest"
    16  	"github.com/google/go-cmp/cmp"
    17  )
    18  
    19  func TestRun(t *testing.T) {
    20  	clangBin := clangBin(t)
    21  	dir := mustWriteTempFile(t, "test.c", minimalSocketFilter)
    22  
    23  	cwd, err := os.Getwd()
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  
    28  	modRoot := filepath.Clean(filepath.Join(cwd, "../.."))
    29  	if _, err := os.Stat(filepath.Join(modRoot, "go.mod")); os.IsNotExist(err) {
    30  		t.Fatal("No go.mod file in", modRoot)
    31  	}
    32  
    33  	tmpDir, err := os.MkdirTemp("", "bpf2go-module-*")
    34  	if err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	defer os.RemoveAll(tmpDir)
    38  
    39  	execInModule := func(name string, args ...string) {
    40  		t.Helper()
    41  
    42  		cmd := exec.Command(name, args...)
    43  		cmd.Dir = tmpDir
    44  		if out, err := cmd.CombinedOutput(); err != nil {
    45  			if out := string(out); out != "" {
    46  				t.Log(out)
    47  			}
    48  			t.Fatalf("Can't execute %s: %v", name, args)
    49  		}
    50  	}
    51  
    52  	execInModule("go", "mod", "init", "bpf2go-test")
    53  
    54  	execInModule("go", "mod", "edit",
    55  		// Require the module. The version doesn't matter due to the replace
    56  		// below.
    57  		fmt.Sprintf("-require=%s@v0.0.0", ebpfModule),
    58  		// Replace the module with the current version.
    59  		fmt.Sprintf("-replace=%s=%s", ebpfModule, modRoot),
    60  	)
    61  
    62  	err = run(io.Discard, "foo", tmpDir, []string{
    63  		"-cc", clangBin,
    64  		"bar",
    65  		filepath.Join(dir, "test.c"),
    66  	})
    67  
    68  	if err != nil {
    69  		t.Fatal("Can't run:", err)
    70  	}
    71  
    72  	for _, arch := range []string{
    73  		"amd64", // little-endian
    74  		"s390x", // big-endian
    75  	} {
    76  		t.Run(arch, func(t *testing.T) {
    77  			goBin := exec.Command("go", "build", "-mod=mod")
    78  			goBin.Dir = tmpDir
    79  			goBin.Env = append(os.Environ(),
    80  				"GOOS=linux",
    81  				"GOARCH="+arch,
    82  			)
    83  			out, err := goBin.CombinedOutput()
    84  			if err != nil {
    85  				if out := string(out); out != "" {
    86  					t.Log(out)
    87  				}
    88  				t.Error("Can't compile package:", err)
    89  			}
    90  		})
    91  	}
    92  }
    93  
    94  func TestHelp(t *testing.T) {
    95  	var stdout bytes.Buffer
    96  	err := run(&stdout, "", "", []string{"-help"})
    97  	if err != nil {
    98  		t.Fatal("Can't execute -help")
    99  	}
   100  
   101  	if stdout.Len() == 0 {
   102  		t.Error("-help doesn't write to stdout")
   103  	}
   104  }
   105  
   106  func TestDisableStripping(t *testing.T) {
   107  	dir := mustWriteTempFile(t, "test.c", minimalSocketFilter)
   108  
   109  	err := run(io.Discard, "foo", dir, []string{
   110  		"-cc", clangBin(t),
   111  		"-strip", "binary-that-certainly-doesnt-exist",
   112  		"-no-strip",
   113  		"bar",
   114  		filepath.Join(dir, "test.c"),
   115  	})
   116  
   117  	if err != nil {
   118  		t.Fatal("Can't run with stripping disabled:", err)
   119  	}
   120  }
   121  
   122  func TestCollectTargets(t *testing.T) {
   123  	clangArches := make(map[string][]string)
   124  	linuxArchesLE := make(map[string][]string)
   125  	linuxArchesBE := make(map[string][]string)
   126  	for arch, archTarget := range targetByGoArch {
   127  		clangArches[archTarget.clang] = append(clangArches[archTarget.clang], arch)
   128  		if archTarget.clang == "bpfel" {
   129  			linuxArchesLE[archTarget.linux] = append(linuxArchesLE[archTarget.linux], arch)
   130  			continue
   131  		}
   132  		linuxArchesBE[archTarget.linux] = append(linuxArchesBE[archTarget.linux], arch)
   133  	}
   134  	for i := range clangArches {
   135  		sort.Strings(clangArches[i])
   136  	}
   137  	for i := range linuxArchesLE {
   138  		sort.Strings(linuxArchesLE[i])
   139  	}
   140  	for i := range linuxArchesBE {
   141  		sort.Strings(linuxArchesBE[i])
   142  	}
   143  
   144  	nativeTarget := make(map[target][]string)
   145  	for arch, archTarget := range targetByGoArch {
   146  		if arch == runtime.GOARCH {
   147  			if archTarget.clang == "bpfel" {
   148  				nativeTarget[archTarget] = linuxArchesLE[archTarget.linux]
   149  			} else {
   150  				nativeTarget[archTarget] = linuxArchesBE[archTarget.linux]
   151  			}
   152  			break
   153  		}
   154  	}
   155  
   156  	tests := []struct {
   157  		targets []string
   158  		want    map[target][]string
   159  	}{
   160  		{
   161  			[]string{"bpf", "bpfel", "bpfeb"},
   162  			map[target][]string{
   163  				{"bpf", ""}:   nil,
   164  				{"bpfel", ""}: clangArches["bpfel"],
   165  				{"bpfeb", ""}: clangArches["bpfeb"],
   166  			},
   167  		},
   168  		{
   169  			[]string{"amd64", "386"},
   170  			map[target][]string{
   171  				{"bpfel", "x86"}: linuxArchesLE["x86"],
   172  			},
   173  		},
   174  		{
   175  			[]string{"amd64", "arm64be"},
   176  			map[target][]string{
   177  				{"bpfeb", "arm64"}: linuxArchesBE["arm64"],
   178  				{"bpfel", "x86"}:   linuxArchesLE["x86"],
   179  			},
   180  		},
   181  		{
   182  			[]string{"native"},
   183  			nativeTarget,
   184  		},
   185  	}
   186  
   187  	for _, test := range tests {
   188  		name := strings.Join(test.targets, ",")
   189  		t.Run(name, func(t *testing.T) {
   190  			have, err := collectTargets(test.targets)
   191  			if err != nil {
   192  				t.Fatal(err)
   193  			}
   194  
   195  			if diff := cmp.Diff(test.want, have); diff != "" {
   196  				t.Errorf("Result mismatch (-want +got):\n%s", diff)
   197  			}
   198  		})
   199  	}
   200  }
   201  
   202  func TestCollectTargetsErrors(t *testing.T) {
   203  	tests := []struct {
   204  		name   string
   205  		target string
   206  	}{
   207  		{"unknown", "frood"},
   208  		{"no linux target", "mips64p32le"},
   209  	}
   210  
   211  	for _, test := range tests {
   212  		t.Run(test.name, func(t *testing.T) {
   213  			_, err := collectTargets([]string{test.target})
   214  			if err == nil {
   215  				t.Fatal("Function did not return an error")
   216  			}
   217  			t.Log("Error message:", err)
   218  		})
   219  	}
   220  }
   221  
   222  func TestConvertGOARCH(t *testing.T) {
   223  	tmp := mustWriteTempFile(t, "test.c",
   224  		`
   225  #ifndef __TARGET_ARCH_x86
   226  #error __TARGET_ARCH_x86 is not defined
   227  #endif`,
   228  	)
   229  
   230  	b2g := bpf2go{
   231  		pkg:              "test",
   232  		stdout:           io.Discard,
   233  		ident:            "test",
   234  		cc:               clangBin(t),
   235  		disableStripping: true,
   236  		sourceFile:       tmp + "/test.c",
   237  		outputDir:        tmp,
   238  	}
   239  
   240  	if err := b2g.convert(targetByGoArch["amd64"], nil); err != nil {
   241  		t.Fatal("Can't target GOARCH:", err)
   242  	}
   243  }
   244  
   245  func TestCTypes(t *testing.T) {
   246  	var ct cTypes
   247  	valid := []string{
   248  		"abcdefghijklmnopqrstuvqxyABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_",
   249  		"y",
   250  	}
   251  	for _, value := range valid {
   252  		if err := ct.Set(value); err != nil {
   253  			t.Fatalf("Set returned an error for %q: %s", value, err)
   254  		}
   255  	}
   256  	qt.Assert(t, ct, qt.ContentEquals, cTypes(valid))
   257  
   258  	for _, value := range []string{
   259  		"",
   260  		" ",
   261  		" frood",
   262  		"foo\nbar",
   263  		".",
   264  		",",
   265  		"+",
   266  		"-",
   267  	} {
   268  		ct = nil
   269  		if err := ct.Set(value); err == nil {
   270  			t.Fatalf("Set did not return an error for %q", value)
   271  		}
   272  	}
   273  
   274  	ct = nil
   275  	qt.Assert(t, ct.Set("foo"), qt.IsNil)
   276  	qt.Assert(t, ct.Set("foo"), qt.IsNotNil)
   277  }
   278  
   279  func clangBin(t *testing.T) string {
   280  	t.Helper()
   281  
   282  	if testing.Short() {
   283  		t.Skip("Not compiling with -short")
   284  	}
   285  
   286  	// Use a recent clang version for local development, but allow CI to run
   287  	// against oldest supported clang.
   288  	clang := "clang-14"
   289  	if minVersion := os.Getenv("CI_MIN_CLANG_VERSION"); minVersion != "" {
   290  		clang = fmt.Sprintf("clang-%s", minVersion)
   291  	}
   292  
   293  	t.Log("Testing against", clang)
   294  	return clang
   295  }