github.com/cilium/ebpf@v0.10.0/internal/cmd/gentypes/main.go (about)

     1  // Program gentypes reads a compressed vmlinux .BTF section and generates
     2  // syscall bindings from it.
     3  //
     4  // Output is written to "types.go".
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"sort"
    13  	"strings"
    14  
    15  	"github.com/cilium/ebpf/btf"
    16  	"github.com/cilium/ebpf/internal"
    17  	"github.com/cilium/ebpf/internal/sys"
    18  )
    19  
    20  type syscallRetval int
    21  
    22  const (
    23  	retError syscallRetval = iota
    24  	retFd
    25  )
    26  
    27  func main() {
    28  	if err := run(os.Args[1:]); err != nil {
    29  		fmt.Fprintln(os.Stderr, "Error:", err)
    30  		os.Exit(1)
    31  	}
    32  }
    33  
    34  func run(args []string) error {
    35  	if len(args) != 1 {
    36  		return fmt.Errorf("expect location of compressed vmlinux .BTF as argument")
    37  	}
    38  
    39  	raw, err := internal.ReadAllCompressed(args[0])
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	spec, err := btf.LoadSpecFromReader(bytes.NewReader(raw))
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	output, err := generateTypes(spec)
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	w, err := os.Create("types.go")
    55  	if err != nil {
    56  		return err
    57  	}
    58  	defer w.Close()
    59  
    60  	return internal.WriteFormatted(output, w)
    61  }
    62  
    63  func generateTypes(spec *btf.Spec) ([]byte, error) {
    64  	objName := &btf.Array{Nelems: 16, Type: &btf.Int{Encoding: btf.Char, Size: 1}}
    65  	linkID := &btf.Int{Size: 4}
    66  	btfID := &btf.Int{Size: 4}
    67  	pointer := &btf.Int{Size: 8}
    68  	logLevel := &btf.Int{Size: 4}
    69  	mapFlags := &btf.Int{Size: 4}
    70  
    71  	// Pre-declare handwritten types so that generated types can refer to them.
    72  	var (
    73  		_ sys.ObjName
    74  		_ sys.LinkID
    75  		_ sys.BTFID
    76  		_ sys.Pointer
    77  		_ sys.LogLevel
    78  		_ sys.MapFlags
    79  	)
    80  
    81  	gf := &btf.GoFormatter{
    82  		Names: map[btf.Type]string{
    83  			objName:  "ObjName",
    84  			linkID:   "LinkID",
    85  			btfID:    "BTFID",
    86  			pointer:  "Pointer",
    87  			logLevel: "LogLevel",
    88  			mapFlags: "MapFlags",
    89  		},
    90  		Identifier: internal.Identifier,
    91  		EnumIdentifier: func(name, element string) string {
    92  			return element
    93  		},
    94  	}
    95  
    96  	w := bytes.NewBuffer(nil)
    97  	w.WriteString(`// Code generated by internal/cmd/gentypes; DO NOT EDIT.
    98  
    99  package sys
   100  
   101  import (
   102  	"unsafe"
   103  )
   104  
   105  `)
   106  
   107  	enums := []struct {
   108  		goType string
   109  		cType  string
   110  	}{
   111  		{"Cmd", "bpf_cmd"},
   112  		{"MapType", "bpf_map_type"},
   113  		{"ProgType", "bpf_prog_type"},
   114  		{"AttachType", "bpf_attach_type"},
   115  		{"LinkType", "bpf_link_type"},
   116  		{"StatsType", "bpf_stats_type"},
   117  		{"SkAction", "sk_action"},
   118  		{"StackBuildIdStatus", "bpf_stack_build_id_status"},
   119  		{"FunctionId", "bpf_func_id"},
   120  		{"AdjRoomMode", "bpf_adj_room_mode"},
   121  		{"HdrStartOff", "bpf_hdr_start_off"},
   122  		{"RetCode", "bpf_ret_code"},
   123  		{"XdpAction", "xdp_action"},
   124  	}
   125  
   126  	sort.Slice(enums, func(i, j int) bool {
   127  		return enums[i].goType < enums[j].goType
   128  	})
   129  
   130  	enumTypes := make(map[string]btf.Type)
   131  	for _, o := range enums {
   132  		fmt.Println("enum", o.goType)
   133  
   134  		var t *btf.Enum
   135  		if err := spec.TypeByName(o.cType, &t); err != nil {
   136  			return nil, err
   137  		}
   138  
   139  		// Add the enum as a predeclared type so that generated structs
   140  		// refer to the Go types.
   141  		if name := gf.Names[t]; name != "" {
   142  			return nil, fmt.Errorf("type %q is already declared as %s", o.cType, name)
   143  		}
   144  		gf.Names[t] = o.goType
   145  		enumTypes[o.goType] = t
   146  
   147  		decl, err := gf.TypeDeclaration(o.goType, t)
   148  		if err != nil {
   149  			return nil, fmt.Errorf("generate %q: %w", o.goType, err)
   150  		}
   151  
   152  		w.WriteString(decl)
   153  		w.WriteRune('\n')
   154  	}
   155  
   156  	// Assorted structs
   157  
   158  	structs := []struct {
   159  		goType  string
   160  		cType   string
   161  		patches []patch
   162  	}{
   163  		{
   164  			"ProgInfo", "bpf_prog_info",
   165  			[]patch{
   166  				replace(objName, "name"),
   167  				replace(pointer, "xlated_prog_insns"),
   168  				replace(pointer, "map_ids"),
   169  			},
   170  		},
   171  		{
   172  			"MapInfo", "bpf_map_info",
   173  			[]patch{
   174  				replace(objName, "name"),
   175  				replace(mapFlags, "map_flags"),
   176  			},
   177  		},
   178  		{
   179  			"BtfInfo", "bpf_btf_info",
   180  			[]patch{
   181  				replace(pointer, "btf", "name"),
   182  				replace(btfID, "id"),
   183  			},
   184  		},
   185  		{
   186  			"LinkInfo", "bpf_link_info",
   187  			[]patch{
   188  				replace(enumTypes["LinkType"], "type"),
   189  				replace(linkID, "id"),
   190  				name(3, "extra"),
   191  				replaceWithBytes("extra"),
   192  			},
   193  		},
   194  		{"FuncInfo", "bpf_func_info", nil},
   195  		{"LineInfo", "bpf_line_info", nil},
   196  		{"XdpMd", "xdp_md", nil},
   197  		{
   198  			"SkLookup", "bpf_sk_lookup",
   199  			[]patch{
   200  				choose(0, "cookie"),
   201  				replaceWithBytes("remote_ip4", "remote_ip6", "local_ip4", "local_ip6"),
   202  			},
   203  		},
   204  	}
   205  
   206  	sort.Slice(structs, func(i, j int) bool {
   207  		return structs[i].goType < structs[j].goType
   208  	})
   209  
   210  	for _, s := range structs {
   211  		fmt.Println("struct", s.goType)
   212  
   213  		var t *btf.Struct
   214  		if err := spec.TypeByName(s.cType, &t); err != nil {
   215  			return nil, err
   216  		}
   217  
   218  		if err := outputPatchedStruct(gf, w, s.goType, t, s.patches); err != nil {
   219  			return nil, fmt.Errorf("output %q: %w", s.goType, err)
   220  		}
   221  	}
   222  
   223  	// Attrs
   224  
   225  	attrs := []struct {
   226  		goType  string
   227  		ret     syscallRetval
   228  		cType   string
   229  		cmd     string
   230  		patches []patch
   231  	}{
   232  		{
   233  			"MapCreate", retFd, "map_create", "BPF_MAP_CREATE",
   234  			[]patch{
   235  				replace(objName, "map_name"),
   236  				replace(enumTypes["MapType"], "map_type"),
   237  				replace(mapFlags, "map_flags"),
   238  			},
   239  		},
   240  		{
   241  			"MapLookupElem", retError, "map_elem", "BPF_MAP_LOOKUP_ELEM",
   242  			[]patch{choose(2, "value"), replace(pointer, "key", "value")},
   243  		},
   244  		{
   245  			"MapLookupAndDeleteElem", retError, "map_elem", "BPF_MAP_LOOKUP_AND_DELETE_ELEM",
   246  			[]patch{choose(2, "value"), replace(pointer, "key", "value")},
   247  		},
   248  		{
   249  			"MapUpdateElem", retError, "map_elem", "BPF_MAP_UPDATE_ELEM",
   250  			[]patch{choose(2, "value"), replace(pointer, "key", "value")},
   251  		},
   252  		{
   253  			"MapDeleteElem", retError, "map_elem", "BPF_MAP_DELETE_ELEM",
   254  			[]patch{choose(2, "value"), replace(pointer, "key", "value")},
   255  		},
   256  		{
   257  			"MapGetNextKey", retError, "map_elem", "BPF_MAP_GET_NEXT_KEY",
   258  			[]patch{
   259  				choose(2, "next_key"), replace(pointer, "key", "next_key"),
   260  				truncateAfter("next_key"),
   261  			},
   262  		},
   263  		{
   264  			"MapFreeze", retError, "map_elem", "BPF_MAP_FREEZE",
   265  			[]patch{truncateAfter("map_fd")},
   266  		},
   267  		{
   268  			"MapLookupBatch", retError, "map_elem_batch", "BPF_MAP_LOOKUP_BATCH",
   269  			[]patch{replace(pointer, "in_batch", "out_batch", "keys", "values")},
   270  		},
   271  		{
   272  			"MapLookupAndDeleteBatch", retError, "map_elem_batch", "BPF_MAP_LOOKUP_AND_DELETE_BATCH",
   273  			[]patch{replace(pointer, "in_batch", "out_batch", "keys", "values")},
   274  		},
   275  		{
   276  			"MapUpdateBatch", retError, "map_elem_batch", "BPF_MAP_UPDATE_BATCH",
   277  			[]patch{replace(pointer, "in_batch", "out_batch", "keys", "values")},
   278  		},
   279  		{
   280  			"MapDeleteBatch", retError, "map_elem_batch", "BPF_MAP_DELETE_BATCH",
   281  			[]patch{replace(pointer, "in_batch", "out_batch", "keys", "values")},
   282  		},
   283  		{
   284  			"ProgLoad", retFd, "prog_load", "BPF_PROG_LOAD",
   285  			[]patch{
   286  				replace(objName, "prog_name"),
   287  				replace(enumTypes["ProgType"], "prog_type"),
   288  				replace(enumTypes["AttachType"], "expected_attach_type"),
   289  				replace(logLevel, "log_level"),
   290  				replace(pointer,
   291  					"insns",
   292  					"license",
   293  					"log_buf",
   294  					"func_info",
   295  					"line_info",
   296  					"fd_array",
   297  					"core_relos",
   298  				),
   299  				choose(20, "attach_btf_obj_fd"),
   300  			},
   301  		},
   302  		{
   303  			"ProgBindMap", retError, "prog_bind_map", "BPF_PROG_BIND_MAP",
   304  			nil,
   305  		},
   306  		{
   307  			"ObjPin", retError, "obj_pin", "BPF_OBJ_PIN",
   308  			[]patch{replace(pointer, "pathname")},
   309  		},
   310  		{
   311  			"ObjGet", retFd, "obj_pin", "BPF_OBJ_GET",
   312  			[]patch{replace(pointer, "pathname")},
   313  		},
   314  		{
   315  			"ProgAttach", retError, "prog_attach", "BPF_PROG_ATTACH",
   316  			nil,
   317  		},
   318  		{
   319  			"ProgDetach", retError, "prog_attach", "BPF_PROG_DETACH",
   320  			[]patch{truncateAfter("attach_type")},
   321  		},
   322  		{
   323  			"ProgRun", retError, "prog_run", "BPF_PROG_TEST_RUN",
   324  			[]patch{replace(pointer, "data_in", "data_out", "ctx_in", "ctx_out")},
   325  		},
   326  		{
   327  			"ProgGetNextId", retError, "obj_next_id", "BPF_PROG_GET_NEXT_ID",
   328  			[]patch{
   329  				choose(0, "start_id"), rename("start_id", "id"),
   330  				truncateAfter("next_id"),
   331  			},
   332  		},
   333  		{
   334  			"MapGetNextId", retError, "obj_next_id", "BPF_MAP_GET_NEXT_ID",
   335  			[]patch{
   336  				choose(0, "start_id"), rename("start_id", "id"),
   337  				truncateAfter("next_id"),
   338  			},
   339  		},
   340  		{
   341  			"BtfGetNextId", retError, "obj_next_id", "BPF_BTF_GET_NEXT_ID",
   342  			[]patch{
   343  				choose(0, "start_id"), rename("start_id", "id"),
   344  				replace(btfID, "id", "next_id"),
   345  				truncateAfter("next_id"),
   346  			},
   347  		},
   348  		// These piggy back on the obj_next_id decl, but only support the
   349  		// first field...
   350  		{
   351  			"BtfGetFdById", retFd, "obj_next_id", "BPF_BTF_GET_FD_BY_ID",
   352  			[]patch{choose(0, "start_id"), rename("start_id", "id"), truncateAfter("id")},
   353  		},
   354  		{
   355  			"MapGetFdById", retFd, "obj_next_id", "BPF_MAP_GET_FD_BY_ID",
   356  			[]patch{choose(0, "start_id"), rename("start_id", "id"), truncateAfter("id")},
   357  		},
   358  		{
   359  			"ProgGetFdById", retFd, "obj_next_id", "BPF_PROG_GET_FD_BY_ID",
   360  			[]patch{choose(0, "start_id"), rename("start_id", "id"), truncateAfter("id")},
   361  		},
   362  		{
   363  			"ObjGetInfoByFd", retError, "info_by_fd", "BPF_OBJ_GET_INFO_BY_FD",
   364  			[]patch{replace(pointer, "info")},
   365  		},
   366  		{
   367  			"RawTracepointOpen", retFd, "raw_tracepoint_open", "BPF_RAW_TRACEPOINT_OPEN",
   368  			[]patch{replace(pointer, "name")},
   369  		},
   370  		{
   371  			"BtfLoad", retFd, "btf_load", "BPF_BTF_LOAD",
   372  			[]patch{replace(pointer, "btf", "btf_log_buf")},
   373  		},
   374  		{
   375  			"LinkCreate", retFd, "link_create", "BPF_LINK_CREATE",
   376  			[]patch{replace(enumTypes["AttachType"], "attach_type")},
   377  		},
   378  		{
   379  			"LinkCreateIter", retFd, "link_create", "BPF_LINK_CREATE",
   380  			[]patch{
   381  				chooseNth(4, 1),
   382  				replace(enumTypes["AttachType"], "attach_type"),
   383  				flattenAnon,
   384  				replace(pointer, "iter_info"),
   385  			},
   386  		},
   387  		{
   388  			"LinkCreatePerfEvent", retFd, "link_create", "BPF_LINK_CREATE",
   389  			[]patch{
   390  				chooseNth(4, 2),
   391  				replace(enumTypes["AttachType"], "attach_type"),
   392  				flattenAnon,
   393  			},
   394  		},
   395  		{
   396  			"LinkCreateKprobeMulti", retFd, "link_create", "BPF_LINK_CREATE",
   397  			[]patch{
   398  				chooseNth(4, 3),
   399  				replace(enumTypes["AttachType"], "attach_type"),
   400  				modify(func(m *btf.Member) error {
   401  					return rename("flags", "kprobe_multi_flags")(m.Type.(*btf.Struct))
   402  				}, "kprobe_multi"),
   403  				flattenAnon,
   404  				replace(pointer, "cookies"),
   405  				replace(pointer, "addrs"),
   406  				replace(pointer, "syms"),
   407  				rename("cnt", "count"),
   408  			},
   409  		},
   410  		{
   411  			"LinkUpdate", retError, "link_update", "BPF_LINK_UPDATE",
   412  			nil,
   413  		},
   414  		{
   415  			"EnableStats", retFd, "enable_stats", "BPF_ENABLE_STATS",
   416  			nil,
   417  		},
   418  		{
   419  			"IterCreate", retFd, "iter_create", "BPF_ITER_CREATE",
   420  			nil,
   421  		},
   422  		{
   423  			"ProgQuery", retError, "prog_query", "BPF_PROG_QUERY",
   424  			[]patch{
   425  				replace(enumTypes["AttachType"], "attach_type"),
   426  				replace(pointer, "prog_ids"),
   427  				rename("prog_cnt", "prog_count"),
   428  			},
   429  		},
   430  	}
   431  
   432  	sort.Slice(attrs, func(i, j int) bool {
   433  		return attrs[i].goType < attrs[j].goType
   434  	})
   435  
   436  	var bpfAttr *btf.Union
   437  	if err := spec.TypeByName("bpf_attr", &bpfAttr); err != nil {
   438  		return nil, err
   439  	}
   440  	attrTypes, err := splitUnion(bpfAttr, types{
   441  		{"map_create", "map_type"},
   442  		{"map_elem", "map_fd"},
   443  		{"map_elem_batch", "batch"},
   444  		{"prog_load", "prog_type"},
   445  		{"obj_pin", "pathname"},
   446  		{"prog_attach", "target_fd"},
   447  		{"prog_run", "test"},
   448  		{"obj_next_id", ""},
   449  		{"info_by_fd", "info"},
   450  		{"prog_query", "query"},
   451  		{"raw_tracepoint_open", "raw_tracepoint"},
   452  		{"btf_load", "btf"},
   453  		{"task_fd_query", "task_fd_query"},
   454  		{"link_create", "link_create"},
   455  		{"link_update", "link_update"},
   456  		{"link_detach", "link_detach"},
   457  		{"enable_stats", "enable_stats"},
   458  		{"iter_create", "iter_create"},
   459  		{"prog_bind_map", "prog_bind_map"},
   460  	})
   461  	if err != nil {
   462  		return nil, fmt.Errorf("splitting bpf_attr: %w", err)
   463  	}
   464  
   465  	for _, s := range attrs {
   466  		fmt.Println("attr", s.goType)
   467  
   468  		t := attrTypes[s.cType]
   469  		if t == nil {
   470  			return nil, fmt.Errorf("unknown attr %q", s.cType)
   471  		}
   472  
   473  		goAttrType := s.goType + "Attr"
   474  		if err := outputPatchedStruct(gf, w, goAttrType, t, s.patches); err != nil {
   475  			return nil, fmt.Errorf("output %q: %w", goAttrType, err)
   476  		}
   477  
   478  		switch s.ret {
   479  		case retError:
   480  			fmt.Fprintf(w, "func %s(attr *%s) error { _, err := BPF(%s, unsafe.Pointer(attr), unsafe.Sizeof(*attr)); return err }\n\n", s.goType, goAttrType, s.cmd)
   481  		case retFd:
   482  			fmt.Fprintf(w, "func %s(attr *%s) (*FD, error) { fd, err := BPF(%s, unsafe.Pointer(attr), unsafe.Sizeof(*attr)); if err != nil { return nil, err }; return NewFD(int(fd)) }\n\n", s.goType, goAttrType, s.cmd)
   483  		}
   484  	}
   485  
   486  	// Link info type specific
   487  
   488  	linkInfoExtraTypes := []struct {
   489  		goType  string
   490  		cType   string
   491  		patches []patch
   492  	}{
   493  		{"CgroupLinkInfo", "cgroup", []patch{replace(enumTypes["AttachType"], "attach_type")}},
   494  		{"IterLinkInfo", "iter", []patch{replace(pointer, "target_name"), truncateAfter("target_name_len")}},
   495  		{"NetNsLinkInfo", "netns", []patch{replace(enumTypes["AttachType"], "attach_type")}},
   496  		{"RawTracepointLinkInfo", "raw_tracepoint", []patch{replace(pointer, "tp_name")}},
   497  		{"TracingLinkInfo", "tracing", []patch{replace(enumTypes["AttachType"], "attach_type")}},
   498  		{"XDPLinkInfo", "xdp", nil},
   499  	}
   500  
   501  	sort.Slice(linkInfoExtraTypes, func(i, j int) bool {
   502  		return linkInfoExtraTypes[i].goType < linkInfoExtraTypes[j].goType
   503  	})
   504  
   505  	var bpfLinkInfo *btf.Struct
   506  	if err := spec.TypeByName("bpf_link_info", &bpfLinkInfo); err != nil {
   507  		return nil, err
   508  	}
   509  
   510  	member := bpfLinkInfo.Members[len(bpfLinkInfo.Members)-1]
   511  	bpfLinkInfoUnion, ok := member.Type.(*btf.Union)
   512  	if !ok {
   513  		return nil, fmt.Errorf("there is not type-specific union")
   514  	}
   515  
   516  	linkInfoTypes, err := splitUnion(bpfLinkInfoUnion, types{
   517  		{"raw_tracepoint", "raw_tracepoint"},
   518  		{"tracing", "tracing"},
   519  		{"cgroup", "cgroup"},
   520  		{"iter", "iter"},
   521  		{"netns", "netns"},
   522  		{"xdp", "xdp"},
   523  	})
   524  	if err != nil {
   525  		return nil, fmt.Errorf("splitting linkInfo: %w", err)
   526  	}
   527  
   528  	for _, s := range linkInfoExtraTypes {
   529  		t := linkInfoTypes[s.cType]
   530  		if err := outputPatchedStruct(gf, w, s.goType, t, s.patches); err != nil {
   531  			return nil, fmt.Errorf("output %q: %w", s.goType, err)
   532  		}
   533  	}
   534  
   535  	return w.Bytes(), nil
   536  }
   537  
   538  func outputPatchedStruct(gf *btf.GoFormatter, w *bytes.Buffer, id string, s *btf.Struct, patches []patch) error {
   539  	s = btf.Copy(s, nil).(*btf.Struct)
   540  
   541  	for i, p := range patches {
   542  		if err := p(s); err != nil {
   543  			return fmt.Errorf("patch %d: %w", i, err)
   544  		}
   545  	}
   546  
   547  	decl, err := gf.TypeDeclaration(id, s)
   548  	if err != nil {
   549  		return err
   550  	}
   551  
   552  	w.WriteString(decl)
   553  	w.WriteString("\n\n")
   554  	return nil
   555  }
   556  
   557  type types []struct {
   558  	name                string
   559  	cFieldOrFirstMember string
   560  }
   561  
   562  func splitUnion(union *btf.Union, types types) (map[string]*btf.Struct, error) {
   563  	structs := make(map[string]*btf.Struct)
   564  
   565  	for i, t := range types {
   566  		member := union.Members[i]
   567  		s, ok := member.Type.(*btf.Struct)
   568  		if !ok {
   569  			return nil, fmt.Errorf("%q: %s is not a struct", t.name, member.Type)
   570  		}
   571  
   572  		if member.Name == "" {
   573  			// This is an anonymous struct, check the name of the first member instead.
   574  			if name := s.Members[0].Name; name != t.cFieldOrFirstMember {
   575  				return nil, fmt.Errorf("first field of %q is %q, not %q", t.name, name, t.cFieldOrFirstMember)
   576  			}
   577  		} else if member.Name != t.cFieldOrFirstMember {
   578  			return nil, fmt.Errorf("name for %q is %q, not %q", t.name, member.Name, t.cFieldOrFirstMember)
   579  		}
   580  
   581  		structs[t.name] = s
   582  	}
   583  
   584  	return structs, nil
   585  }
   586  
   587  type patch func(*btf.Struct) error
   588  
   589  func modify(fn func(*btf.Member) error, members ...string) patch {
   590  	return func(s *btf.Struct) error {
   591  		want := make(map[string]bool)
   592  		for _, name := range members {
   593  			want[name] = true
   594  		}
   595  
   596  		for i, m := range s.Members {
   597  			if want[m.Name] {
   598  				if err := fn(&s.Members[i]); err != nil {
   599  					return err
   600  				}
   601  				delete(want, m.Name)
   602  			}
   603  		}
   604  
   605  		if len(want) == 0 {
   606  			return nil
   607  		}
   608  
   609  		var missing []string
   610  		for name := range want {
   611  			missing = append(missing, name)
   612  		}
   613  		sort.Strings(missing)
   614  
   615  		return fmt.Errorf("missing members: %v", strings.Join(missing, ", "))
   616  	}
   617  }
   618  
   619  func modifyNth(fn func(*btf.Member) error, indices ...int) patch {
   620  	return func(s *btf.Struct) error {
   621  		for _, i := range indices {
   622  			if i >= len(s.Members) {
   623  				return fmt.Errorf("index %d is out of bounds", i)
   624  			}
   625  
   626  			if err := fn(&s.Members[i]); err != nil {
   627  				return fmt.Errorf("member #%d: %w", i, err)
   628  			}
   629  		}
   630  		return nil
   631  	}
   632  }
   633  
   634  func replace(t btf.Type, members ...string) patch {
   635  	return modify(func(m *btf.Member) error {
   636  		m.Type = t
   637  		return nil
   638  	}, members...)
   639  }
   640  
   641  func choose(member int, name string) patch {
   642  	return modifyNth(func(m *btf.Member) error {
   643  		union, ok := m.Type.(*btf.Union)
   644  		if !ok {
   645  			return fmt.Errorf("member %d is %s, not a union", member, m.Type)
   646  		}
   647  
   648  		for _, um := range union.Members {
   649  			if um.Name == name {
   650  				m.Name = um.Name
   651  				m.Type = um.Type
   652  				return nil
   653  			}
   654  		}
   655  
   656  		return fmt.Errorf("%s has no member %q", union, name)
   657  	}, member)
   658  }
   659  
   660  func chooseNth(member int, n int) patch {
   661  	return modifyNth(func(m *btf.Member) error {
   662  		union, ok := m.Type.(*btf.Union)
   663  		if !ok {
   664  			return fmt.Errorf("member %d is %s, not a union", member, m.Type)
   665  		}
   666  
   667  		if n >= len(union.Members) {
   668  			return fmt.Errorf("member %d is out of bounds", n)
   669  		}
   670  
   671  		um := union.Members[n]
   672  		m.Name = um.Name
   673  		m.Type = um.Type
   674  		return nil
   675  	}, member)
   676  }
   677  
   678  func flattenAnon(s *btf.Struct) error {
   679  	for i := range s.Members {
   680  		m := &s.Members[i]
   681  
   682  		cs, ok := m.Type.(*btf.Struct)
   683  		if !ok || cs.TypeName() != "" {
   684  			continue
   685  		}
   686  
   687  		for j := range cs.Members {
   688  			cs.Members[j].Offset += m.Offset
   689  		}
   690  
   691  		newMembers := make([]btf.Member, 0, len(s.Members)+len(cs.Members)-1)
   692  		newMembers = append(newMembers, s.Members[:i]...)
   693  		newMembers = append(newMembers, cs.Members...)
   694  		newMembers = append(newMembers, s.Members[i+1:]...)
   695  
   696  		s.Members = newMembers
   697  	}
   698  
   699  	return nil
   700  }
   701  
   702  func truncateAfter(name string) patch {
   703  	return func(s *btf.Struct) error {
   704  		for i, m := range s.Members {
   705  			if m.Name != name {
   706  				continue
   707  			}
   708  
   709  			size, err := btf.Sizeof(m.Type)
   710  			if err != nil {
   711  				return err
   712  			}
   713  
   714  			s.Members = s.Members[:i+1]
   715  			s.Size = m.Offset.Bytes() + uint32(size)
   716  			return nil
   717  		}
   718  
   719  		return fmt.Errorf("no member %q", name)
   720  	}
   721  }
   722  
   723  func rename(from, to string) patch {
   724  	return func(s *btf.Struct) error {
   725  		for i, m := range s.Members {
   726  			if m.Name == from {
   727  				s.Members[i].Name = to
   728  				return nil
   729  			}
   730  		}
   731  		return fmt.Errorf("no member named %q", from)
   732  	}
   733  }
   734  
   735  func name(member int, name string) patch {
   736  	return modifyNth(func(m *btf.Member) error {
   737  		if m.Name != "" {
   738  			return fmt.Errorf("member already has name %q", m.Name)
   739  		}
   740  
   741  		m.Name = name
   742  		return nil
   743  	}, member)
   744  }
   745  
   746  func replaceWithBytes(members ...string) patch {
   747  	return modify(func(m *btf.Member) error {
   748  		if m.BitfieldSize != 0 {
   749  			return errors.New("replaceWithBytes: member is a bitfield")
   750  		}
   751  
   752  		size, err := btf.Sizeof(m.Type)
   753  		if err != nil {
   754  			return fmt.Errorf("replaceWithBytes: size of %s: %w", m.Type, err)
   755  		}
   756  
   757  		m.Type = &btf.Array{
   758  			Type:   &btf.Int{Size: 1},
   759  			Nelems: uint32(size),
   760  		}
   761  
   762  		return nil
   763  	}, members...)
   764  }