github.com/cilium/ebpf@v0.15.1-0.20240517100537-8079b37aa138/btf/types_test.go (about)

     1  package btf
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/go-quicktest/qt"
    12  	"github.com/google/go-cmp/cmp"
    13  )
    14  
    15  func TestSizeof(t *testing.T) {
    16  	testcases := []struct {
    17  		size int
    18  		typ  Type
    19  	}{
    20  		{0, (*Void)(nil)},
    21  		{1, &Int{Size: 1}},
    22  		{8, &Enum{Size: 8}},
    23  		{0, &Array{Type: &Pointer{Target: (*Void)(nil)}, Nelems: 0}},
    24  		{12, &Array{Type: &Enum{Size: 4}, Nelems: 3}},
    25  	}
    26  
    27  	for _, tc := range testcases {
    28  		name := fmt.Sprint(tc.typ)
    29  		t.Run(name, func(t *testing.T) {
    30  			have, err := Sizeof(tc.typ)
    31  			if err != nil {
    32  				t.Fatal("Can't calculate size:", err)
    33  			}
    34  			if have != tc.size {
    35  				t.Errorf("Expected size %d, got %d", tc.size, have)
    36  			}
    37  		})
    38  	}
    39  }
    40  
    41  func TestCopy(t *testing.T) {
    42  	_ = Copy((*Void)(nil))
    43  
    44  	in := &Int{Size: 4}
    45  	out := Copy(in)
    46  
    47  	in.Size = 8
    48  	if size := out.(*Int).Size; size != 4 {
    49  		t.Error("Copy doesn't make a copy, expected size 4, got", size)
    50  	}
    51  
    52  	t.Run("cyclical", func(t *testing.T) {
    53  		_ = Copy(newCyclicalType(2))
    54  	})
    55  
    56  	t.Run("identity", func(t *testing.T) {
    57  		u16 := &Int{Size: 2}
    58  
    59  		out := Copy(&Struct{
    60  			Members: []Member{
    61  				{Name: "a", Type: u16},
    62  				{Name: "b", Type: u16},
    63  			},
    64  		})
    65  
    66  		outStruct := out.(*Struct)
    67  		qt.Assert(t, qt.Equals(outStruct.Members[0].Type, outStruct.Members[1].Type))
    68  	})
    69  }
    70  
    71  func TestAs(t *testing.T) {
    72  	i := &Int{}
    73  	ptr := &Pointer{i}
    74  	td := &Typedef{Type: ptr}
    75  	cst := &Const{td}
    76  	vol := &Volatile{cst}
    77  
    78  	// It's possible to retrieve qualifiers and Typedefs.
    79  	haveVol, ok := As[*Volatile](vol)
    80  	qt.Assert(t, qt.IsTrue(ok))
    81  	qt.Assert(t, qt.Equals(haveVol, vol))
    82  
    83  	haveTd, ok := As[*Typedef](vol)
    84  	qt.Assert(t, qt.IsTrue(ok))
    85  	qt.Assert(t, qt.Equals(haveTd, td))
    86  
    87  	haveCst, ok := As[*Const](vol)
    88  	qt.Assert(t, qt.IsTrue(ok))
    89  	qt.Assert(t, qt.Equals(haveCst, cst))
    90  
    91  	// Make sure we don't skip Pointer.
    92  	haveI, ok := As[*Int](vol)
    93  	qt.Assert(t, qt.IsFalse(ok))
    94  	qt.Assert(t, qt.IsNil(haveI))
    95  
    96  	// Make sure we can always retrieve Pointer.
    97  	for _, typ := range []Type{
    98  		td, cst, vol, ptr,
    99  	} {
   100  		have, ok := As[*Pointer](typ)
   101  		qt.Assert(t, qt.IsTrue(ok))
   102  		qt.Assert(t, qt.Equals(have, ptr))
   103  	}
   104  }
   105  
   106  func BenchmarkCopy(b *testing.B) {
   107  	typ := newCyclicalType(10)
   108  
   109  	b.ReportAllocs()
   110  	b.ResetTimer()
   111  
   112  	for i := 0; i < b.N; i++ {
   113  		Copy(typ)
   114  	}
   115  }
   116  
   117  // The following are valid Types.
   118  //
   119  // There currently is no better way to document which
   120  // types implement an interface.
   121  func ExampleType_validTypes() {
   122  	var _ Type = &Void{}
   123  	var _ Type = &Int{}
   124  	var _ Type = &Pointer{}
   125  	var _ Type = &Array{}
   126  	var _ Type = &Struct{}
   127  	var _ Type = &Union{}
   128  	var _ Type = &Enum{}
   129  	var _ Type = &Fwd{}
   130  	var _ Type = &Typedef{}
   131  	var _ Type = &Volatile{}
   132  	var _ Type = &Const{}
   133  	var _ Type = &Restrict{}
   134  	var _ Type = &Func{}
   135  	var _ Type = &FuncProto{}
   136  	var _ Type = &Var{}
   137  	var _ Type = &Datasec{}
   138  	var _ Type = &Float{}
   139  }
   140  
   141  func TestType(t *testing.T) {
   142  	types := []func() Type{
   143  		func() Type { return &Void{} },
   144  		func() Type { return &Int{Size: 2} },
   145  		func() Type { return &Pointer{Target: &Void{}} },
   146  		func() Type { return &Array{Type: &Int{}} },
   147  		func() Type {
   148  			return &Struct{
   149  				Members: []Member{{Type: &Void{}}},
   150  			}
   151  		},
   152  		func() Type {
   153  			return &Union{
   154  				Members: []Member{{Type: &Void{}}},
   155  			}
   156  		},
   157  		func() Type { return &Enum{} },
   158  		func() Type { return &Fwd{Name: "thunk"} },
   159  		func() Type { return &Typedef{Type: &Void{}} },
   160  		func() Type { return &Volatile{Type: &Void{}} },
   161  		func() Type { return &Const{Type: &Void{}} },
   162  		func() Type { return &Restrict{Type: &Void{}} },
   163  		func() Type { return &Func{Name: "foo", Type: &Void{}} },
   164  		func() Type {
   165  			return &FuncProto{
   166  				Params: []FuncParam{{Name: "bar", Type: &Void{}}},
   167  				Return: &Void{},
   168  			}
   169  		},
   170  		func() Type { return &Var{Type: &Void{}} },
   171  		func() Type {
   172  			return &Datasec{
   173  				Vars: []VarSecinfo{{Type: &Void{}}},
   174  			}
   175  		},
   176  		func() Type { return &Float{} },
   177  		func() Type { return &declTag{Type: &Void{}} },
   178  		func() Type { return &typeTag{Type: &Void{}} },
   179  		func() Type { return &cycle{&Void{}} },
   180  	}
   181  
   182  	compareTypes := cmp.Comparer(func(a, b *Type) bool {
   183  		return a == b
   184  	})
   185  
   186  	for _, fn := range types {
   187  		typ := fn()
   188  		t.Run(fmt.Sprintf("%T", typ), func(t *testing.T) {
   189  			t.Logf("%v", typ)
   190  
   191  			if typ == typ.copy() {
   192  				t.Error("Copy doesn't copy")
   193  			}
   194  
   195  			var a []*Type
   196  			children(typ, func(t *Type) bool { a = append(a, t); return true })
   197  
   198  			if _, ok := typ.(*cycle); !ok {
   199  				if n := countChildren(t, reflect.TypeOf(typ)); len(a) < n {
   200  					t.Errorf("walkType visited %d children, expected at least %d", len(a), n)
   201  				}
   202  			}
   203  
   204  			var b []*Type
   205  			children(typ, func(t *Type) bool { b = append(b, t); return true })
   206  
   207  			if diff := cmp.Diff(a, b, compareTypes); diff != "" {
   208  				t.Errorf("Walk mismatch (-want +got):\n%s", diff)
   209  			}
   210  		})
   211  	}
   212  }
   213  
   214  func TestTagMarshaling(t *testing.T) {
   215  	for _, typ := range []Type{
   216  		&declTag{&Struct{Members: []Member{}}, "foo", -1},
   217  		&typeTag{&Int{}, "foo"},
   218  	} {
   219  		t.Run(fmt.Sprint(typ), func(t *testing.T) {
   220  			s := specFromTypes(t, []Type{typ})
   221  
   222  			have, err := s.TypeByID(1)
   223  			qt.Assert(t, qt.IsNil(err))
   224  
   225  			qt.Assert(t, qt.DeepEquals(have, typ))
   226  		})
   227  	}
   228  }
   229  
   230  func countChildren(t *testing.T, typ reflect.Type) int {
   231  	if typ.Kind() != reflect.Pointer {
   232  		t.Fatal("Expected pointer, got", typ.Kind())
   233  	}
   234  
   235  	typ = typ.Elem()
   236  	if typ.Kind() != reflect.Struct {
   237  		t.Fatal("Expected struct, got", typ.Kind())
   238  	}
   239  
   240  	var n int
   241  	for i := 0; i < typ.NumField(); i++ {
   242  		if typ.Field(i).Type == reflect.TypeOf((*Type)(nil)).Elem() {
   243  			n++
   244  		}
   245  	}
   246  
   247  	return n
   248  }
   249  
   250  type testFormattableType struct {
   251  	name  string
   252  	extra []interface{}
   253  }
   254  
   255  var _ formattableType = (*testFormattableType)(nil)
   256  
   257  func (tft *testFormattableType) TypeName() string { return tft.name }
   258  func (tft *testFormattableType) Format(fs fmt.State, verb rune) {
   259  	formatType(fs, verb, tft, tft.extra...)
   260  }
   261  
   262  func TestFormatType(t *testing.T) {
   263  	t1 := &testFormattableType{"", []interface{}{"extra"}}
   264  	t1Addr := fmt.Sprintf("%#p", t1)
   265  	goType := reflect.TypeOf(t1).Elem().Name()
   266  
   267  	t2 := &testFormattableType{"foo", []interface{}{t1}}
   268  
   269  	t3 := &testFormattableType{extra: []interface{}{""}}
   270  
   271  	tests := []struct {
   272  		t        formattableType
   273  		fmt      string
   274  		contains []string
   275  		omits    []string
   276  	}{
   277  		// %s doesn't contain address or extra.
   278  		{t1, "%s", []string{goType}, []string{t1Addr, "extra"}},
   279  		// %+s doesn't contain extra.
   280  		{t1, "%+s", []string{goType, t1Addr}, []string{"extra"}},
   281  		// %v does contain extra.
   282  		{t1, "%v", []string{goType, "extra"}, []string{t1Addr}},
   283  		// %+v does contain address.
   284  		{t1, "%+v", []string{goType, "extra", t1Addr}, nil},
   285  		// %v doesn't print nested types' extra.
   286  		{t2, "%v", []string{goType, t2.name}, []string{"extra"}},
   287  		// %1v does print nested types' extra.
   288  		{t2, "%1v", []string{goType, t2.name, "extra"}, nil},
   289  		// empty strings in extra don't emit anything.
   290  		{t3, "%v", []string{"[]"}, nil},
   291  	}
   292  
   293  	for _, test := range tests {
   294  		t.Run(test.fmt, func(t *testing.T) {
   295  			str := fmt.Sprintf(test.fmt, test.t)
   296  			t.Log(str)
   297  
   298  			for _, want := range test.contains {
   299  				qt.Assert(t, qt.StringContains(str, want))
   300  			}
   301  
   302  			for _, notWant := range test.omits {
   303  				qt.Assert(t, qt.Not(qt.StringContains(str, notWant)))
   304  			}
   305  		})
   306  	}
   307  }
   308  
   309  func newCyclicalType(n int) Type {
   310  	ptr := &Pointer{}
   311  	prev := Type(ptr)
   312  	for i := 0; i < n; i++ {
   313  		switch i % 5 {
   314  		case 0:
   315  			prev = &Struct{
   316  				Members: []Member{
   317  					{Type: prev},
   318  				},
   319  			}
   320  
   321  		case 1:
   322  			prev = &Const{Type: prev}
   323  		case 2:
   324  			prev = &Volatile{Type: prev}
   325  		case 3:
   326  			prev = &Typedef{Type: prev}
   327  		case 4:
   328  			prev = &Array{Type: prev, Index: &Int{Size: 1}}
   329  		}
   330  	}
   331  	ptr.Target = prev
   332  	return ptr
   333  }
   334  
   335  func TestUnderlyingType(t *testing.T) {
   336  	wrappers := []struct {
   337  		name string
   338  		fn   func(Type) Type
   339  	}{
   340  		{"const", func(t Type) Type { return &Const{Type: t} }},
   341  		{"volatile", func(t Type) Type { return &Volatile{Type: t} }},
   342  		{"restrict", func(t Type) Type { return &Restrict{Type: t} }},
   343  		{"typedef", func(t Type) Type { return &Typedef{Type: t} }},
   344  		{"type tag", func(t Type) Type { return &typeTag{Type: t} }},
   345  	}
   346  
   347  	for _, test := range wrappers {
   348  		t.Run(test.name+" cycle", func(t *testing.T) {
   349  			root := &Volatile{}
   350  			root.Type = test.fn(root)
   351  
   352  			got, ok := UnderlyingType(root).(*cycle)
   353  			qt.Assert(t, qt.IsTrue(ok))
   354  			qt.Assert(t, qt.Equals[Type](got.root, root))
   355  		})
   356  	}
   357  
   358  	for _, test := range wrappers {
   359  		t.Run(test.name, func(t *testing.T) {
   360  			want := &Int{}
   361  			got := UnderlyingType(test.fn(want))
   362  			qt.Assert(t, qt.Equals[Type](got, want))
   363  		})
   364  	}
   365  }
   366  
   367  func TestInflateLegacyBitfield(t *testing.T) {
   368  	const offset = 3
   369  	const size = 5
   370  
   371  	var rawInt rawType
   372  	rawInt.SetKind(kindInt)
   373  	rawInt.SetSize(4)
   374  	var data btfInt
   375  	data.SetOffset(offset)
   376  	data.SetBits(size)
   377  	rawInt.data = &data
   378  
   379  	var (
   380  		before bytes.Buffer
   381  		after  bytes.Buffer
   382  	)
   383  
   384  	var beforeInt rawType
   385  	beforeInt.SetKind(kindStruct)
   386  	beforeInt.SetVlen(1)
   387  	beforeInt.data = []btfMember{{Type: 2}}
   388  
   389  	if err := beforeInt.Marshal(&before, binary.LittleEndian); err != nil {
   390  		t.Fatal(err)
   391  	}
   392  	if err := rawInt.Marshal(&before, binary.LittleEndian); err != nil {
   393  		t.Fatal(err)
   394  	}
   395  
   396  	afterInt := beforeInt
   397  	afterInt.data = []btfMember{{Type: 1}}
   398  
   399  	if err := rawInt.Marshal(&after, binary.LittleEndian); err != nil {
   400  		t.Fatal(err)
   401  	}
   402  	if err := afterInt.Marshal(&after, binary.LittleEndian); err != nil {
   403  		t.Fatal(err)
   404  	}
   405  
   406  	emptyStrings := newStringTable("")
   407  
   408  	for _, test := range []struct {
   409  		name   string
   410  		reader io.Reader
   411  	}{
   412  		{"struct before int", &before},
   413  		{"struct after int", &after},
   414  	} {
   415  		t.Run(test.name, func(t *testing.T) {
   416  			types, err := readAndInflateTypes(test.reader, binary.LittleEndian, 2, emptyStrings, nil)
   417  			if err != nil {
   418  				fmt.Println(before.Bytes())
   419  				t.Fatal(err)
   420  			}
   421  
   422  			for _, typ := range types {
   423  				s, ok := typ.(*Struct)
   424  				if !ok {
   425  					continue
   426  				}
   427  
   428  				i := s.Members[0]
   429  				if i.BitfieldSize != size {
   430  					t.Errorf("Expected bitfield size %d, got %d", size, i.BitfieldSize)
   431  				}
   432  
   433  				if i.Offset != offset {
   434  					t.Errorf("Expected offset %d, got %d", offset, i.Offset)
   435  				}
   436  
   437  				return
   438  			}
   439  
   440  			t.Fatal("No Struct returned from inflateRawTypes")
   441  		})
   442  	}
   443  }
   444  
   445  func BenchmarkWalk(b *testing.B) {
   446  	types := []Type{
   447  		&Void{},
   448  		&Int{},
   449  		&Pointer{},
   450  		&Array{},
   451  		&Struct{Members: make([]Member, 2)},
   452  		&Union{Members: make([]Member, 2)},
   453  		&Enum{},
   454  		&Fwd{},
   455  		&Typedef{},
   456  		&Volatile{},
   457  		&Const{},
   458  		&Restrict{},
   459  		&Func{},
   460  		&FuncProto{Params: make([]FuncParam, 2)},
   461  		&Var{},
   462  		&Datasec{Vars: make([]VarSecinfo, 2)},
   463  	}
   464  
   465  	for _, typ := range types {
   466  		b.Run(fmt.Sprint(typ), func(b *testing.B) {
   467  			b.ReportAllocs()
   468  
   469  			for i := 0; i < b.N; i++ {
   470  				var dq typeDeque
   471  				children(typ, func(child *Type) bool {
   472  					dq.Push(child)
   473  					return true
   474  				})
   475  			}
   476  		})
   477  	}
   478  }
   479  
   480  func BenchmarkUnderlyingType(b *testing.B) {
   481  	b.Run("no unwrapping", func(b *testing.B) {
   482  		v := &Int{}
   483  		b.ReportAllocs()
   484  		b.ResetTimer()
   485  
   486  		for i := 0; i < b.N; i++ {
   487  			UnderlyingType(v)
   488  		}
   489  	})
   490  
   491  	b.Run("single unwrapping", func(b *testing.B) {
   492  		v := &Typedef{Type: &Int{}}
   493  		b.ReportAllocs()
   494  		b.ResetTimer()
   495  
   496  		for i := 0; i < b.N; i++ {
   497  			UnderlyingType(v)
   498  		}
   499  	})
   500  }
   501  
   502  // As can be used to strip qualifiers from a Type.
   503  func ExampleAs() {
   504  	a := &Volatile{Type: &Pointer{Target: &Typedef{Name: "foo", Type: &Int{Size: 2}}}}
   505  	fmt.Println(As[*Pointer](a))
   506  	// Output: Pointer[target=Typedef:"foo"] true
   507  }