github.com/castai/kvisor@v1.7.1-0.20240516114728-b3572a2607b5/tools/codegen/generator.go (about)

     1  package main
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"strings"
     7  	"unicode"
     8  
     9  	"github.com/iancoleman/strcase"
    10  )
    11  
    12  func generateTypes(events []eventDefinition) (string, error) {
    13  	sink := &strings.Builder{}
    14  
    15  	_, err := sink.WriteString(generateTypesFileHeader())
    16  	if err != nil {
    17  		return "", err
    18  	}
    19  
    20  	_, err = sink.WriteString("\n")
    21  	if err != nil {
    22  		return "", err
    23  	}
    24  
    25  	for i, definition := range events {
    26  		err = generateStruct(sink, definition)
    27  		if err != nil {
    28  			return "", err
    29  		}
    30  
    31  		if i < len(events)-1 {
    32  			_, err = sink.WriteString("\n")
    33  			if err != nil {
    34  				return "", err
    35  			}
    36  		}
    37  	}
    38  
    39  	return sink.String(), nil
    40  }
    41  
    42  func generateParsers(targetPackage string, events []eventDefinition) (string, error) {
    43  	sink := &strings.Builder{}
    44  
    45  	_, err := sink.WriteString(generateParsersFileHeader(targetPackage))
    46  	if err != nil {
    47  		return "", err
    48  	}
    49  
    50  	_, err = sink.WriteString("\n")
    51  	if err != nil {
    52  		return "", err
    53  	}
    54  
    55  	for _, definition := range events {
    56  		err = generatePerEventParserFunction(sink, definition)
    57  		if err != nil {
    58  			return "", err
    59  		}
    60  		_, err = sink.WriteString("\n")
    61  		if err != nil {
    62  			return "", err
    63  		}
    64  	}
    65  
    66  	err = generateParserFunction(sink, events)
    67  	if err != nil {
    68  		return "", err
    69  	}
    70  
    71  	return sink.String(), nil
    72  }
    73  
    74  func generateTypesFileHeader() string {
    75  	return fmt.Sprintf(`// Code generated by tools/codegen; DO NOT EDIT.
    76  
    77  package types
    78  
    79  type Args interface {
    80    args()
    81  }
    82  
    83  // internalArgs is a marker type to distinguish Args interface from basically any
    84  type internalArgs struct{}
    85  
    86  func (c internalArgs) args() {}
    87  `)
    88  }
    89  
    90  func generateParsersFileHeader(targetPackage string) string {
    91  	return fmt.Sprintf(`// Code generated by tools/codegen; DO NOT EDIT.
    92  
    93  package %s
    94  
    95  import (
    96    "errors"
    97  
    98    "github.com/castai/kvisor/pkg/ebpftracer/events"
    99    "github.com/castai/kvisor/pkg/ebpftracer/types"
   100  )
   101  
   102  var (
   103    ErrUnknownArgsType error = errors.New("unknown args type")
   104  )
   105  
   106  // eventMaxByteSliceBufferSize is used to determine the max slice size allowed for different
   107  // event types. For example, most events have a max size of 4096, but for network related events
   108  // there is no max size (this is represented as -1).
   109  func eventMaxByteSliceBufferSize(id events.ID) int {
   110    // For non network event, we have a max byte slice size of 4096
   111    if id < events.NetPacketBase || id > events.MaxNetID {
   112      return 4096
   113    }
   114  
   115    // Network events do not have a max buffer size.
   116    return -1
   117  }
   118  `, targetPackage)
   119  }
   120  
   121  func generateStruct(sink *strings.Builder, definition eventDefinition) error {
   122  	sink.WriteString(fmt.Sprintf(`type %s struct {
   123    internalArgs
   124  `, generateArgName(definition)))
   125  
   126  	if len(definition.params) > 0 {
   127  		_, err := sink.WriteRune('\n')
   128  		if err != nil {
   129  			return err
   130  		}
   131  	}
   132  
   133  	for _, p := range definition.params {
   134  		paramLine, err := generateParamStructField(p)
   135  		if err != nil {
   136  			return err
   137  		}
   138  
   139  		_, err = sink.WriteString(paramLine)
   140  		if err != nil {
   141  			return err
   142  		}
   143  
   144  		_, err = sink.WriteString("\n")
   145  		if err != nil {
   146  			return err
   147  		}
   148  	}
   149  	_, err := sink.WriteString("}\n")
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func generateParamStructField(p param) (string, error) {
   158  	goType, err := toGolangType(p.paramType)
   159  	if err != nil {
   160  		return "", err
   161  	}
   162  
   163  	return fmt.Sprintf("  %s %s", generateParamName(p), goType), nil
   164  }
   165  
   166  func capitalize(str string) string {
   167  	runes := []rune(str)
   168  	runes[0] = unicode.ToUpper(runes[0])
   169  	return string(runes)
   170  }
   171  
   172  func generateParamName(p param) string {
   173  	// This should transform parameters called e.g. `src_old` to `SrcOld`
   174  	return strcase.ToCamel(p.name)
   175  }
   176  
   177  func generateArgName(definition eventDefinition) string {
   178  	return fmt.Sprintf("%sArgs", capitalize(definition.event))
   179  }
   180  
   181  func toGolangType(t ArgType) (string, error) {
   182  	switch t {
   183  	case noneT:
   184  		return "", errors.New("cannot handle erorr type none!")
   185  	case u8T:
   186  		return "uint8", nil
   187  	case u16T:
   188  		return "uint16", nil
   189  	case intT:
   190  		return "int32", nil
   191  	case uintT, devT, modeT:
   192  		return "uint32", nil
   193  	case longT:
   194  		return "int64", nil
   195  	case ulongT, offT, sizeT:
   196  		return "uint64", nil
   197  	case boolT:
   198  		return "bool", nil
   199  	case pointerT:
   200  		return "uintptr", nil
   201  	case sockAddrT:
   202  		return "Sockaddr", nil
   203  	case credT:
   204  		return "SlimCred", nil
   205  	case strT:
   206  		return "string", nil
   207  	case strArrT, argsArrT:
   208  		return "[]string", nil
   209  	case bytesT:
   210  		return "[]byte", nil
   211  	case intArr2T:
   212  		return "[2]int32", nil
   213  	case uint64ArrT:
   214  		return "[]uint64", nil
   215  	case timespecT:
   216  		return "float64", nil
   217  	case tupleT:
   218  		return "AddrTuple", nil
   219  	case protoDNST:
   220  		return "*ProtoDNS", nil
   221  	}
   222  
   223  	return "", fmt.Errorf("unknown event type: %d", t)
   224  }
   225  
   226  func indent(str string, indentSize int) string {
   227  	if len(str) == 0 {
   228  		return str
   229  	}
   230  
   231  	indent := strings.Repeat(" ", indentSize)
   232  
   233  	var result strings.Builder
   234  
   235  	for _, line := range strings.Split(str, "\n") {
   236  		if len(strings.TrimSpace(line)) > 0 {
   237  			_, err := result.WriteString(indent)
   238  			if err != nil {
   239  				panic(err)
   240  			}
   241  
   242  			_, err = result.WriteString(line)
   243  			if err != nil {
   244  				panic(err)
   245  			}
   246  		}
   247  
   248  		_, err := result.WriteRune('\n')
   249  		if err != nil {
   250  			panic(err)
   251  		}
   252  	}
   253  
   254  	return result.String()
   255  }
   256  
   257  func generatePerEventParserFunction(sink *strings.Builder, definition eventDefinition) error {
   258  	eventName := generateArgName(definition)
   259  
   260  	_, err := sink.WriteString(fmt.Sprintf(`func Parse%s(decoder *Decoder) (types.%s, error) {
   261  `,
   262  		eventName, eventName))
   263  	if err != nil {
   264  		return err
   265  	}
   266  
   267  	if len(definition.params) == 0 {
   268  		_, err = sink.WriteString(fmt.Sprintf(`  return types.%s{}, nil
   269  }
   270  `, eventName))
   271  		if err != nil {
   272  			return err
   273  		}
   274  		return nil
   275  	}
   276  
   277  	_, err = sink.WriteString(fmt.Sprintf(`  var result types.%s
   278    var err error
   279  
   280  `, eventName))
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	_, err = sink.WriteString(generateParseNumArgsCode(definition))
   286  	if err != nil {
   287  		return err
   288  	}
   289  
   290  	_, err = sink.WriteString("\n")
   291  	if err != nil {
   292  		return err
   293  	}
   294  
   295  	_, err = sink.WriteString(`  for arg := 0; arg < int(numArgs); arg++ {
   296  `)
   297  	if err != nil {
   298  		return err
   299  	}
   300  
   301  	_, err = sink.WriteString(indent(generateCurrentArgCode(definition), 2))
   302  	if err != nil {
   303  		return err
   304  	}
   305  
   306  	_, err = sink.WriteString(indent(`  switch currArg {`, 2))
   307  	if err != nil {
   308  		return err
   309  	}
   310  
   311  	for i, p := range definition.params {
   312  		_, err = sink.WriteString(indent(fmt.Sprintf(`  case %d:`, i), 2))
   313  		if err != nil {
   314  			return err
   315  		}
   316  
   317  		line, err := getDecoderCode(definition, p)
   318  		if err != nil {
   319  			return err
   320  		}
   321  
   322  		_, err = sink.WriteString(indent(line, 4))
   323  		if err != nil {
   324  			return err
   325  		}
   326  	}
   327  
   328  	_, err = sink.WriteString(`    }
   329    }
   330  `)
   331  	if err != nil {
   332  		return err
   333  	}
   334  
   335  	_, err = sink.WriteString(`  return result, nil
   336  }
   337  `)
   338  	if err != nil {
   339  		return err
   340  	}
   341  
   342  	return nil
   343  }
   344  func generateParseNumArgsCode(definition eventDefinition) string {
   345  	return fmt.Sprintf(`  var numArgs uint8
   346    err = decoder.DecodeUint8(&numArgs)
   347  %s
   348  `, generateDecoderErrorCheck(definition))
   349  }
   350  
   351  func generateCurrentArgCode(definition eventDefinition) string {
   352  	return fmt.Sprintf(`  var currArg uint8
   353    err = decoder.DecodeUint8(&currArg)
   354  %s
   355  `, generateDecoderErrorCheck(definition))
   356  }
   357  
   358  func getDecoderCode(definition eventDefinition, p param) (string, error) {
   359  	paramName := generateParamName(p)
   360  	switch p.paramType {
   361  	case noneT:
   362  		return "", errors.New("cannot handle erorr type none!")
   363  	case u8T:
   364  		return fmt.Sprintf(`  err = decoder.DecodeUint8(&result.%s)
   365  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   366  	case u16T:
   367  		return fmt.Sprintf(`  err = decoder.DecodeUint16(&result.%s)
   368  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   369  	case intT:
   370  		return fmt.Sprintf(`  err = decoder.DecodeInt32(&result.%s)
   371  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   372  	case uintT, devT, modeT:
   373  		return fmt.Sprintf(`  err = decoder.DecodeUint32(&result.%s)
   374  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   375  	case longT:
   376  		return fmt.Sprintf(`  err = decoder.DecodeInt64(&result.%s)
   377  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   378  	case ulongT, offT, sizeT:
   379  		return fmt.Sprintf(`  err = decoder.DecodeUint64(&result.%s)
   380  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   381  	case boolT:
   382  		return fmt.Sprintf(`  err = decoder.DecodeBool(&result.%s)
   383  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   384  	case pointerT:
   385  		return fmt.Sprintf(`  var data%s uint64
   386    err = decoder.DecodeUint64(&data%s)
   387  %s
   388    result.%s = uintptr(data%s)`, paramName, paramName, generateDecoderErrorCheck(definition), paramName, paramName), nil
   389  	case sockAddrT:
   390  		return fmt.Sprintf(`  result.%s, err = decoder.ReadSockaddrFromBuff()
   391  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   392  	case credT:
   393  		return fmt.Sprintf(`  err = decoder.DecodeSlimCred(&result.%s)
   394  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   395  	case strT:
   396  		return fmt.Sprintf(`  result.%s, err = decoder.ReadStringFromBuff()
   397  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   398  	case strArrT:
   399  		return fmt.Sprintf(`  result.%s, err = decoder.ReadStringArrayFromBuff()
   400  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   401  	case argsArrT:
   402  		return fmt.Sprintf(`  result.%s, err = decoder.ReadArgsArrayFromBuff()
   403  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   404  	case bytesT:
   405  		return fmt.Sprintf(`  result.%s, err = decoder.ReadMaxByteSliceFromBuff(eventMaxByteSliceBufferSize(events.%s))
   406  %s`, paramName, definition.event, generateDecoderErrorCheck(definition)), nil
   407  	case intArr2T:
   408  		return fmt.Sprintf(`  err = decoder.DecodeIntArray(result.%s[:], 2)
   409  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   410  	case uint64ArrT:
   411  		return fmt.Sprintf(`  err = decoder.DecodeUint64Array(&result.%s)
   412  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   413  	case timespecT:
   414  		return fmt.Sprintf(`  result.%s, err = decoder.ReadTimespec()
   415  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   416  	case tupleT:
   417  		return fmt.Sprintf(`  result.%s, err = decoder.ReadAddrTuple()
   418  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   419  	case protoDNST:
   420  		return fmt.Sprintf(`  result.%s, err = decoder.ReadProtoDNS()
   421  %s`, paramName, generateDecoderErrorCheck(definition)), nil
   422  	}
   423  
   424  	return "", fmt.Errorf("unknown event type: %d", p.paramType)
   425  }
   426  
   427  func generateDecoderErrorCheck(definition eventDefinition) string {
   428  	return fmt.Sprintf(`  if err != nil {
   429      return types.%s{}, err
   430    }`, generateArgName(definition))
   431  }
   432  
   433  func generateParserFunction(sink *strings.Builder, definitions []eventDefinition) error {
   434  	_, err := sink.WriteString(`func ParseArgs(decoder *Decoder, event events.ID) (types.Args, error) {
   435    switch event {
   436  `)
   437  	if err != nil {
   438  		return err
   439  	}
   440  
   441  	for _, definition := range definitions {
   442  		sink.WriteString(fmt.Sprintf(`  case events.%s:
   443      return Parse%s(decoder)
   444  `, definition.event, generateArgName(definition)))
   445  	}
   446  
   447  	_, err = sink.WriteString(`  }
   448  
   449    return nil, ErrUnknownArgsType
   450  }
   451  `)
   452  	if err != nil {
   453  		return err
   454  	}
   455  	return nil
   456  }