github.com/castai/kvisor@v1.7.1-0.20240516114728-b3572a2607b5/tools/codegen/generator.go (about) 1 package main 2 3 import ( 4 "errors" 5 "fmt" 6 "strings" 7 "unicode" 8 9 "github.com/iancoleman/strcase" 10 ) 11 12 func generateTypes(events []eventDefinition) (string, error) { 13 sink := &strings.Builder{} 14 15 _, err := sink.WriteString(generateTypesFileHeader()) 16 if err != nil { 17 return "", err 18 } 19 20 _, err = sink.WriteString("\n") 21 if err != nil { 22 return "", err 23 } 24 25 for i, definition := range events { 26 err = generateStruct(sink, definition) 27 if err != nil { 28 return "", err 29 } 30 31 if i < len(events)-1 { 32 _, err = sink.WriteString("\n") 33 if err != nil { 34 return "", err 35 } 36 } 37 } 38 39 return sink.String(), nil 40 } 41 42 func generateParsers(targetPackage string, events []eventDefinition) (string, error) { 43 sink := &strings.Builder{} 44 45 _, err := sink.WriteString(generateParsersFileHeader(targetPackage)) 46 if err != nil { 47 return "", err 48 } 49 50 _, err = sink.WriteString("\n") 51 if err != nil { 52 return "", err 53 } 54 55 for _, definition := range events { 56 err = generatePerEventParserFunction(sink, definition) 57 if err != nil { 58 return "", err 59 } 60 _, err = sink.WriteString("\n") 61 if err != nil { 62 return "", err 63 } 64 } 65 66 err = generateParserFunction(sink, events) 67 if err != nil { 68 return "", err 69 } 70 71 return sink.String(), nil 72 } 73 74 func generateTypesFileHeader() string { 75 return fmt.Sprintf(`// Code generated by tools/codegen; DO NOT EDIT. 76 77 package types 78 79 type Args interface { 80 args() 81 } 82 83 // internalArgs is a marker type to distinguish Args interface from basically any 84 type internalArgs struct{} 85 86 func (c internalArgs) args() {} 87 `) 88 } 89 90 func generateParsersFileHeader(targetPackage string) string { 91 return fmt.Sprintf(`// Code generated by tools/codegen; DO NOT EDIT. 92 93 package %s 94 95 import ( 96 "errors" 97 98 "github.com/castai/kvisor/pkg/ebpftracer/events" 99 "github.com/castai/kvisor/pkg/ebpftracer/types" 100 ) 101 102 var ( 103 ErrUnknownArgsType error = errors.New("unknown args type") 104 ) 105 106 // eventMaxByteSliceBufferSize is used to determine the max slice size allowed for different 107 // event types. For example, most events have a max size of 4096, but for network related events 108 // there is no max size (this is represented as -1). 109 func eventMaxByteSliceBufferSize(id events.ID) int { 110 // For non network event, we have a max byte slice size of 4096 111 if id < events.NetPacketBase || id > events.MaxNetID { 112 return 4096 113 } 114 115 // Network events do not have a max buffer size. 116 return -1 117 } 118 `, targetPackage) 119 } 120 121 func generateStruct(sink *strings.Builder, definition eventDefinition) error { 122 sink.WriteString(fmt.Sprintf(`type %s struct { 123 internalArgs 124 `, generateArgName(definition))) 125 126 if len(definition.params) > 0 { 127 _, err := sink.WriteRune('\n') 128 if err != nil { 129 return err 130 } 131 } 132 133 for _, p := range definition.params { 134 paramLine, err := generateParamStructField(p) 135 if err != nil { 136 return err 137 } 138 139 _, err = sink.WriteString(paramLine) 140 if err != nil { 141 return err 142 } 143 144 _, err = sink.WriteString("\n") 145 if err != nil { 146 return err 147 } 148 } 149 _, err := sink.WriteString("}\n") 150 if err != nil { 151 return err 152 } 153 154 return nil 155 } 156 157 func generateParamStructField(p param) (string, error) { 158 goType, err := toGolangType(p.paramType) 159 if err != nil { 160 return "", err 161 } 162 163 return fmt.Sprintf(" %s %s", generateParamName(p), goType), nil 164 } 165 166 func capitalize(str string) string { 167 runes := []rune(str) 168 runes[0] = unicode.ToUpper(runes[0]) 169 return string(runes) 170 } 171 172 func generateParamName(p param) string { 173 // This should transform parameters called e.g. `src_old` to `SrcOld` 174 return strcase.ToCamel(p.name) 175 } 176 177 func generateArgName(definition eventDefinition) string { 178 return fmt.Sprintf("%sArgs", capitalize(definition.event)) 179 } 180 181 func toGolangType(t ArgType) (string, error) { 182 switch t { 183 case noneT: 184 return "", errors.New("cannot handle erorr type none!") 185 case u8T: 186 return "uint8", nil 187 case u16T: 188 return "uint16", nil 189 case intT: 190 return "int32", nil 191 case uintT, devT, modeT: 192 return "uint32", nil 193 case longT: 194 return "int64", nil 195 case ulongT, offT, sizeT: 196 return "uint64", nil 197 case boolT: 198 return "bool", nil 199 case pointerT: 200 return "uintptr", nil 201 case sockAddrT: 202 return "Sockaddr", nil 203 case credT: 204 return "SlimCred", nil 205 case strT: 206 return "string", nil 207 case strArrT, argsArrT: 208 return "[]string", nil 209 case bytesT: 210 return "[]byte", nil 211 case intArr2T: 212 return "[2]int32", nil 213 case uint64ArrT: 214 return "[]uint64", nil 215 case timespecT: 216 return "float64", nil 217 case tupleT: 218 return "AddrTuple", nil 219 case protoDNST: 220 return "*ProtoDNS", nil 221 } 222 223 return "", fmt.Errorf("unknown event type: %d", t) 224 } 225 226 func indent(str string, indentSize int) string { 227 if len(str) == 0 { 228 return str 229 } 230 231 indent := strings.Repeat(" ", indentSize) 232 233 var result strings.Builder 234 235 for _, line := range strings.Split(str, "\n") { 236 if len(strings.TrimSpace(line)) > 0 { 237 _, err := result.WriteString(indent) 238 if err != nil { 239 panic(err) 240 } 241 242 _, err = result.WriteString(line) 243 if err != nil { 244 panic(err) 245 } 246 } 247 248 _, err := result.WriteRune('\n') 249 if err != nil { 250 panic(err) 251 } 252 } 253 254 return result.String() 255 } 256 257 func generatePerEventParserFunction(sink *strings.Builder, definition eventDefinition) error { 258 eventName := generateArgName(definition) 259 260 _, err := sink.WriteString(fmt.Sprintf(`func Parse%s(decoder *Decoder) (types.%s, error) { 261 `, 262 eventName, eventName)) 263 if err != nil { 264 return err 265 } 266 267 if len(definition.params) == 0 { 268 _, err = sink.WriteString(fmt.Sprintf(` return types.%s{}, nil 269 } 270 `, eventName)) 271 if err != nil { 272 return err 273 } 274 return nil 275 } 276 277 _, err = sink.WriteString(fmt.Sprintf(` var result types.%s 278 var err error 279 280 `, eventName)) 281 if err != nil { 282 return err 283 } 284 285 _, err = sink.WriteString(generateParseNumArgsCode(definition)) 286 if err != nil { 287 return err 288 } 289 290 _, err = sink.WriteString("\n") 291 if err != nil { 292 return err 293 } 294 295 _, err = sink.WriteString(` for arg := 0; arg < int(numArgs); arg++ { 296 `) 297 if err != nil { 298 return err 299 } 300 301 _, err = sink.WriteString(indent(generateCurrentArgCode(definition), 2)) 302 if err != nil { 303 return err 304 } 305 306 _, err = sink.WriteString(indent(` switch currArg {`, 2)) 307 if err != nil { 308 return err 309 } 310 311 for i, p := range definition.params { 312 _, err = sink.WriteString(indent(fmt.Sprintf(` case %d:`, i), 2)) 313 if err != nil { 314 return err 315 } 316 317 line, err := getDecoderCode(definition, p) 318 if err != nil { 319 return err 320 } 321 322 _, err = sink.WriteString(indent(line, 4)) 323 if err != nil { 324 return err 325 } 326 } 327 328 _, err = sink.WriteString(` } 329 } 330 `) 331 if err != nil { 332 return err 333 } 334 335 _, err = sink.WriteString(` return result, nil 336 } 337 `) 338 if err != nil { 339 return err 340 } 341 342 return nil 343 } 344 func generateParseNumArgsCode(definition eventDefinition) string { 345 return fmt.Sprintf(` var numArgs uint8 346 err = decoder.DecodeUint8(&numArgs) 347 %s 348 `, generateDecoderErrorCheck(definition)) 349 } 350 351 func generateCurrentArgCode(definition eventDefinition) string { 352 return fmt.Sprintf(` var currArg uint8 353 err = decoder.DecodeUint8(&currArg) 354 %s 355 `, generateDecoderErrorCheck(definition)) 356 } 357 358 func getDecoderCode(definition eventDefinition, p param) (string, error) { 359 paramName := generateParamName(p) 360 switch p.paramType { 361 case noneT: 362 return "", errors.New("cannot handle erorr type none!") 363 case u8T: 364 return fmt.Sprintf(` err = decoder.DecodeUint8(&result.%s) 365 %s`, paramName, generateDecoderErrorCheck(definition)), nil 366 case u16T: 367 return fmt.Sprintf(` err = decoder.DecodeUint16(&result.%s) 368 %s`, paramName, generateDecoderErrorCheck(definition)), nil 369 case intT: 370 return fmt.Sprintf(` err = decoder.DecodeInt32(&result.%s) 371 %s`, paramName, generateDecoderErrorCheck(definition)), nil 372 case uintT, devT, modeT: 373 return fmt.Sprintf(` err = decoder.DecodeUint32(&result.%s) 374 %s`, paramName, generateDecoderErrorCheck(definition)), nil 375 case longT: 376 return fmt.Sprintf(` err = decoder.DecodeInt64(&result.%s) 377 %s`, paramName, generateDecoderErrorCheck(definition)), nil 378 case ulongT, offT, sizeT: 379 return fmt.Sprintf(` err = decoder.DecodeUint64(&result.%s) 380 %s`, paramName, generateDecoderErrorCheck(definition)), nil 381 case boolT: 382 return fmt.Sprintf(` err = decoder.DecodeBool(&result.%s) 383 %s`, paramName, generateDecoderErrorCheck(definition)), nil 384 case pointerT: 385 return fmt.Sprintf(` var data%s uint64 386 err = decoder.DecodeUint64(&data%s) 387 %s 388 result.%s = uintptr(data%s)`, paramName, paramName, generateDecoderErrorCheck(definition), paramName, paramName), nil 389 case sockAddrT: 390 return fmt.Sprintf(` result.%s, err = decoder.ReadSockaddrFromBuff() 391 %s`, paramName, generateDecoderErrorCheck(definition)), nil 392 case credT: 393 return fmt.Sprintf(` err = decoder.DecodeSlimCred(&result.%s) 394 %s`, paramName, generateDecoderErrorCheck(definition)), nil 395 case strT: 396 return fmt.Sprintf(` result.%s, err = decoder.ReadStringFromBuff() 397 %s`, paramName, generateDecoderErrorCheck(definition)), nil 398 case strArrT: 399 return fmt.Sprintf(` result.%s, err = decoder.ReadStringArrayFromBuff() 400 %s`, paramName, generateDecoderErrorCheck(definition)), nil 401 case argsArrT: 402 return fmt.Sprintf(` result.%s, err = decoder.ReadArgsArrayFromBuff() 403 %s`, paramName, generateDecoderErrorCheck(definition)), nil 404 case bytesT: 405 return fmt.Sprintf(` result.%s, err = decoder.ReadMaxByteSliceFromBuff(eventMaxByteSliceBufferSize(events.%s)) 406 %s`, paramName, definition.event, generateDecoderErrorCheck(definition)), nil 407 case intArr2T: 408 return fmt.Sprintf(` err = decoder.DecodeIntArray(result.%s[:], 2) 409 %s`, paramName, generateDecoderErrorCheck(definition)), nil 410 case uint64ArrT: 411 return fmt.Sprintf(` err = decoder.DecodeUint64Array(&result.%s) 412 %s`, paramName, generateDecoderErrorCheck(definition)), nil 413 case timespecT: 414 return fmt.Sprintf(` result.%s, err = decoder.ReadTimespec() 415 %s`, paramName, generateDecoderErrorCheck(definition)), nil 416 case tupleT: 417 return fmt.Sprintf(` result.%s, err = decoder.ReadAddrTuple() 418 %s`, paramName, generateDecoderErrorCheck(definition)), nil 419 case protoDNST: 420 return fmt.Sprintf(` result.%s, err = decoder.ReadProtoDNS() 421 %s`, paramName, generateDecoderErrorCheck(definition)), nil 422 } 423 424 return "", fmt.Errorf("unknown event type: %d", p.paramType) 425 } 426 427 func generateDecoderErrorCheck(definition eventDefinition) string { 428 return fmt.Sprintf(` if err != nil { 429 return types.%s{}, err 430 }`, generateArgName(definition)) 431 } 432 433 func generateParserFunction(sink *strings.Builder, definitions []eventDefinition) error { 434 _, err := sink.WriteString(`func ParseArgs(decoder *Decoder, event events.ID) (types.Args, error) { 435 switch event { 436 `) 437 if err != nil { 438 return err 439 } 440 441 for _, definition := range definitions { 442 sink.WriteString(fmt.Sprintf(` case events.%s: 443 return Parse%s(decoder) 444 `, definition.event, generateArgName(definition))) 445 } 446 447 _, err = sink.WriteString(` } 448 449 return nil, ErrUnknownArgsType 450 } 451 `) 452 if err != nil { 453 return err 454 } 455 return nil 456 }