github.com/DataDog/datadog-agent/pkg/security/secl@v0.55.0-devel.0.20240517055856-10c4965fea94/compiler/generators/accessors/accessors.go (about) 1 // Unless explicitly stated otherwise all files in this repository are licensed 2 // under the Apache License Version 2.0. 3 // This product includes software developed at Datadog (https://www.datadoghq.com/). 4 // Copyright 2016-present Datadog, Inc. 5 6 // Package main holds main related files 7 package main 8 9 import ( 10 "bufio" 11 "bytes" 12 _ "embed" 13 "flag" 14 "fmt" 15 "go/ast" 16 "log" 17 "os" 18 "os/exec" 19 "path" 20 "reflect" 21 "slices" 22 "strconv" 23 "strings" 24 "text/template" 25 "unicode" 26 27 "github.com/Masterminds/sprig/v3" 28 "github.com/davecgh/go-spew/spew" 29 "github.com/fatih/structtag" 30 "golang.org/x/text/cases" 31 "golang.org/x/text/language" 32 "golang.org/x/tools/go/packages" 33 34 "github.com/DataDog/datadog-agent/pkg/security/secl/compiler/generators/accessors/common" 35 "github.com/DataDog/datadog-agent/pkg/security/secl/compiler/generators/accessors/doc" 36 ) 37 38 const ( 39 pkgPrefix = "github.com/DataDog/datadog-agent/pkg/security/secl" 40 ) 41 42 var ( 43 modelFile string 44 typesFile string 45 pkgname string 46 output string 47 verbose bool 48 docOutput string 49 fieldHandlersOutput string 50 fieldAccessorsOutput string 51 buildTags string 52 ) 53 54 // AstFiles defines ast files 55 type AstFiles struct { 56 files []*ast.File 57 } 58 59 // LookupSymbol lookups symbol 60 func (af *AstFiles) LookupSymbol(symbol string) *ast.Object { 61 for _, file := range af.files { 62 if obj := file.Scope.Lookup(symbol); obj != nil { 63 return obj 64 } 65 } 66 return nil 67 } 68 69 // GetSpecs gets specs 70 func (af *AstFiles) GetSpecs() []ast.Spec { 71 var specs []ast.Spec 72 73 for _, file := range af.files { 74 for _, decl := range file.Decls { 75 decl, ok := decl.(*ast.GenDecl) 76 if !ok || decl.Doc == nil { 77 continue 78 } 79 80 var genaccessors bool 81 for _, document := range decl.Doc.List { 82 if strings.Contains(document.Text, "genaccessors") { 83 genaccessors = true 84 break 85 } 86 } 87 88 if !genaccessors { 89 continue 90 } 91 92 specs = append(specs, decl.Specs...) 93 } 94 } 95 96 return specs 97 } 98 99 func origTypeToBasicType(kind string) string { 100 switch kind { 101 case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64": 102 return "int" 103 } 104 return kind 105 } 106 107 func isNetType(kind string) bool { 108 return kind == "net.IPNet" 109 } 110 111 func isBasicType(kind string) bool { 112 switch kind { 113 case "string", "bool", "int", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "net.IPNet": 114 return true 115 } 116 return false 117 } 118 119 func isBasicTypeForGettersOnly(kind string) bool { 120 if isBasicType(kind) { 121 return true 122 } 123 124 switch kind { 125 case "time.Time": 126 return true 127 } 128 return false 129 } 130 131 func qualifiedType(module *common.Module, kind string) string { 132 switch kind { 133 case "int", "string", "bool": 134 return kind 135 default: 136 return module.SourcePkgPrefix + kind 137 } 138 } 139 140 // handleBasic adds fields of "basic" type to list of exposed SECL fields of the module 141 func handleBasic(module *common.Module, field seclField, name, alias, aliasPrefix, prefix, kind, event, opOverrides, commentText, containerStructName string, iterator *common.StructField, isArray bool) { 142 if verbose { 143 fmt.Printf("handleBasic name: %s, kind: %s, alias: %s, isArray: %v\n", name, kind, alias, isArray) 144 } 145 146 if prefix != "" { 147 name = prefix + "." + name 148 } 149 150 if aliasPrefix != "" { 151 alias = aliasPrefix + "." + alias 152 } 153 154 basicType := origTypeToBasicType(kind) 155 newStructField := &common.StructField{ 156 Name: name, 157 BasicType: basicType, 158 ReturnType: basicType, 159 IsArray: strings.HasPrefix(kind, "[]") || isArray, 160 Event: event, 161 OrigType: kind, 162 Iterator: iterator, 163 CommentText: commentText, 164 OpOverrides: opOverrides, 165 Struct: containerStructName, 166 Alias: alias, 167 AliasPrefix: aliasPrefix, 168 GettersOnly: field.gettersOnly, 169 } 170 171 module.Fields[alias] = newStructField 172 173 if _, ok := module.EventTypes[event]; !ok { 174 module.EventTypes[event] = common.NewEventTypeMetada() 175 } 176 177 if field.lengthField { 178 name = name + ".length" 179 aliasPrefix = alias 180 alias = alias + ".length" 181 182 newStructField := &common.StructField{ 183 Name: name, 184 BasicType: "int", 185 ReturnType: "int", 186 OrigType: "int", 187 IsArray: isArray, 188 IsLength: true, 189 Event: event, 190 Iterator: iterator, 191 CommentText: doc.SECLDocForLength, 192 OpOverrides: opOverrides, 193 Struct: "string", 194 Alias: alias, 195 AliasPrefix: aliasPrefix, 196 GettersOnly: field.gettersOnly, 197 } 198 199 module.Fields[alias] = newStructField 200 } 201 } 202 203 // handleEmbedded adds embedded fields to list of exposed SECL fields of the module 204 func handleEmbedded(module *common.Module, name, prefix, event string, fieldTypeExpr ast.Expr) { 205 if verbose { 206 log.Printf("handleEmbedded name: %s", name) 207 } 208 209 if prefix != "" { 210 name = fmt.Sprintf("%s.%s", prefix, name) 211 } 212 213 fieldType, isPointer, isArray := getFieldIdentName(fieldTypeExpr) 214 215 // maintain a list of all the fields 216 module.AllFields[name] = &common.StructField{ 217 Name: name, 218 Event: event, 219 OrigType: qualifiedType(module, fieldType), 220 IsOrigTypePtr: isPointer, 221 IsArray: isArray, 222 } 223 } 224 225 // handleNonEmbedded adds non-embedded fields to list of all possible (but not necessarily exposed) SECL fields of the module 226 func handleNonEmbedded(module *common.Module, field seclField, prefixedFieldName, event, fieldType string, isPointer, isArray bool) { 227 module.AllFields[prefixedFieldName] = &common.StructField{ 228 Name: prefixedFieldName, 229 Event: event, 230 OrigType: qualifiedType(module, fieldType), 231 IsOrigTypePtr: isPointer, 232 IsArray: isArray, 233 Check: field.check, 234 } 235 } 236 237 // handleIterator adds iterator to list of exposed SECL iterators of the module 238 func handleIterator(module *common.Module, field seclField, fieldType, iterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides string, isPointer, isArray bool) *common.StructField { 239 alias := field.name 240 if aliasPrefix != "" { 241 alias = aliasPrefix + "." + field.name 242 } 243 244 module.Iterators[alias] = &common.StructField{ 245 Name: prefixedFieldName, 246 ReturnType: qualifiedType(module, iterator), 247 Event: event, 248 OrigType: qualifiedType(module, fieldType), 249 IsOrigTypePtr: isPointer, 250 IsArray: isArray, 251 Weight: field.weight, 252 CommentText: fieldCommentText, 253 OpOverrides: opOverrides, 254 Helper: field.helper, 255 SkipADResolution: field.skipADResolution, 256 Check: field.check, 257 } 258 259 return module.Iterators[alias] 260 } 261 262 // handleFieldWithHandler adds non-embedded fields with handlers to list of exposed SECL fields and event types of the module 263 func handleFieldWithHandler(module *common.Module, field seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, containerStructName, event, fieldCommentText, opOverrides, handler string, isPointer, isArray bool, fieldIterator *common.StructField) { 264 alias := field.name 265 266 if aliasPrefix != "" { 267 alias = aliasPrefix + "." + alias 268 } 269 270 if event == "" { 271 log.Printf("event type not specified for field: %s", prefixedFieldName) 272 } 273 274 newStructField := &common.StructField{ 275 Prefix: prefix, 276 Name: prefixedFieldName, 277 BasicType: origTypeToBasicType(fieldType), 278 Struct: containerStructName, 279 Handler: handler, 280 ReturnType: origTypeToBasicType(fieldType), 281 Event: event, 282 OrigType: fieldType, 283 Iterator: fieldIterator, 284 IsArray: isArray, 285 Weight: field.weight, 286 CommentText: fieldCommentText, 287 OpOverrides: opOverrides, 288 Helper: field.helper, 289 SkipADResolution: field.skipADResolution, 290 IsOrigTypePtr: isPointer, 291 Check: field.check, 292 Alias: alias, 293 AliasPrefix: aliasPrefix, 294 GettersOnly: field.gettersOnly, 295 } 296 297 module.Fields[alias] = newStructField 298 299 if field.lengthField { 300 var lengthField = *module.Fields[alias] 301 lengthField.IsLength = true 302 lengthField.Name += ".length" 303 lengthField.OrigType = "int" 304 lengthField.BasicType = "int" 305 lengthField.ReturnType = "int" 306 lengthField.Struct = "string" 307 lengthField.AliasPrefix = alias 308 lengthField.Alias = alias + ".length" 309 lengthField.CommentText = doc.SECLDocForLength 310 311 module.Fields[lengthField.Alias] = &lengthField 312 } 313 314 if _, ok := module.EventTypes[event]; !ok { 315 module.EventTypes[event] = common.NewEventTypeMetada(alias) 316 } else { 317 module.EventTypes[event].Fields = append(module.EventTypes[event].Fields, alias) 318 } 319 } 320 321 func getFieldName(expr ast.Expr) string { 322 switch expr := expr.(type) { 323 case *ast.Ident: 324 return expr.Name 325 case *ast.StarExpr: 326 return getFieldName(expr.X) 327 case *ast.ArrayType: 328 return getFieldName(expr.Elt) 329 case *ast.SelectorExpr: 330 return getFieldName(expr.X) + "." + getFieldName(expr.Sel) 331 default: 332 return "" 333 } 334 } 335 336 func getFieldIdentName(expr ast.Expr) (name string, isPointer bool, isArray bool) { 337 switch expr.(type) { 338 case *ast.StarExpr: 339 isPointer = true 340 case *ast.ArrayType: 341 isArray = true 342 } 343 344 return getFieldName(expr), isPointer, isArray 345 } 346 347 type seclField struct { 348 name string 349 iterator string 350 handler string 351 helper bool // mark the handler as just a helper and not a real resolver. Won't be called by ResolveFields 352 skipADResolution bool 353 lengthField bool 354 weight int64 355 check string 356 exposedAtEventRootOnly bool // fields that should only be exposed at the root of an event, i.e. `parent` should not be exposed for an `ancestor` of a process 357 containerStructName string 358 gettersOnly bool // a field that is not exposed via SECL, but still has an accessor generated 359 } 360 361 func parseFieldDef(def string) (seclField, error) { 362 def = strings.TrimSpace(def) 363 alias, options, splitted := strings.Cut(def, ",") 364 365 field := seclField{name: alias} 366 367 if alias == "-" { 368 return field, nil 369 } 370 371 // arguments 372 if splitted { 373 for _, el := range strings.Split(options, ",") { 374 kv := strings.Split(el, ":") 375 376 key, value := kv[0], kv[1] 377 378 switch key { 379 case "handler": 380 field.handler = value 381 case "weight": 382 weight, err := strconv.ParseInt(value, 10, 64) 383 if err != nil { 384 return field, err 385 } 386 field.weight = weight 387 case "iterator": 388 field.iterator = value 389 case "check": 390 field.check = value 391 case "opts": 392 for _, opt := range strings.Split(value, "|") { 393 switch opt { 394 case "helper": 395 field.helper = true 396 case "length": 397 field.lengthField = true 398 case "skip_ad": 399 field.skipADResolution = true 400 case "exposed_at_event_root_only": 401 field.exposedAtEventRootOnly = true 402 case "getters_only": 403 field.gettersOnly = true 404 field.exposedAtEventRootOnly = true 405 } 406 } 407 } 408 } 409 } 410 411 return field, nil 412 } 413 414 // handleSpecRecursive is a recursive function that walks through the fields of a module 415 func handleSpecRecursive(module *common.Module, astFiles *AstFiles, spec interface{}, prefix, aliasPrefix, event string, iterator *common.StructField, dejavu map[string]bool) { 416 if verbose { 417 fmt.Printf("handleSpec spec: %+v, prefix: %s, aliasPrefix %s, event %s, iterator %+v\n", spec, prefix, aliasPrefix, event, iterator) 418 } 419 420 var typeSpec *ast.TypeSpec 421 var structType *ast.StructType 422 var ok bool 423 if typeSpec, ok = spec.(*ast.TypeSpec); !ok { 424 return 425 } 426 if structType, ok = typeSpec.Type.(*ast.StructType); !ok { 427 log.Printf("Don't know what to do with %s (%s)", typeSpec.Name, spew.Sdump(typeSpec)) 428 return 429 } 430 431 for _, field := range structType.Fields.List { 432 fieldCommentText := field.Comment.Text() 433 fieldIterator := iterator 434 435 var tag reflect.StructTag 436 if field.Tag != nil { 437 tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1]) 438 } 439 440 if e, ok := tag.Lookup("event"); ok { 441 event = e 442 if _, ok = module.EventTypes[e]; !ok { 443 module.EventTypes[e] = common.NewEventTypeMetada() 444 dejavu = make(map[string]bool) // clear dejavu map when it's a new event type 445 } 446 if e != "*" { 447 module.EventTypes[e].Doc = fieldCommentText 448 } 449 } 450 451 if isEmbedded := len(field.Names) == 0; isEmbedded { // embedded as in a struct embedded in another struct 452 if fieldTag, found := tag.Lookup("field"); found && fieldTag == "-" { 453 continue 454 } 455 456 ident, _ := field.Type.(*ast.Ident) 457 if ident == nil { 458 if starExpr, ok := field.Type.(*ast.StarExpr); ok { 459 ident, _ = starExpr.X.(*ast.Ident) 460 } 461 } 462 463 if ident != nil { 464 name := ident.Name 465 if prefix != "" { 466 name = prefix + "." + ident.Name 467 } 468 469 embedded := astFiles.LookupSymbol(ident.Name) 470 if embedded != nil { 471 handleEmbedded(module, ident.Name, prefix, event, field.Type) 472 handleSpecRecursive(module, astFiles, embedded.Decl, name, aliasPrefix, event, fieldIterator, dejavu) 473 } else { 474 log.Printf("failed to resolve symbol for identifier %+v in %s", ident.Name, pkgname) 475 } 476 } 477 } else { 478 fieldBasename := field.Names[0].Name 479 if !unicode.IsUpper(rune(fieldBasename[0])) { 480 continue 481 } 482 483 if dejavu[fieldBasename] { 484 continue 485 } 486 487 var opOverrides string 488 var fields []seclField 489 var gettersOnlyFields []seclField 490 if tags, err := structtag.Parse(string(tag)); err == nil && len(tags.Tags()) != 0 { 491 opOverrides, fields, gettersOnlyFields = parseTags(tags, typeSpec.Name.Name) 492 493 if opOverrides == "" && fields == nil && gettersOnlyFields == nil { 494 continue 495 } 496 } else { 497 fields = append(fields, seclField{name: fieldBasename}) 498 } 499 500 fieldType, isPointer, isArray := getFieldIdentName(field.Type) 501 502 prefixedFieldName := fieldBasename 503 if prefix != "" { 504 prefixedFieldName = fmt.Sprintf("%s.%s", prefix, fieldBasename) 505 } 506 507 for _, seclField := range fields { 508 handleNonEmbedded(module, seclField, prefixedFieldName, event, fieldType, isPointer, isArray) 509 510 if seclFieldIterator := seclField.iterator; seclFieldIterator != "" { 511 fieldIterator = handleIterator(module, seclField, fieldType, seclFieldIterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides, isPointer, isArray) 512 } 513 514 if handler := seclField.handler; handler != "" { 515 516 handleFieldWithHandler(module, seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, seclField.containerStructName, event, fieldCommentText, opOverrides, handler, isPointer, isArray, fieldIterator) 517 518 delete(dejavu, fieldBasename) 519 continue 520 } 521 522 if verbose { 523 log.Printf("Don't know what to do with %s: %s", fieldBasename, spew.Sdump(field.Type)) 524 } 525 526 dejavu[fieldBasename] = true 527 528 if len(fieldType) == 0 { 529 continue 530 } 531 532 if isNetType((fieldType)) { 533 if !slices.Contains(module.Imports, "net") { 534 module.Imports = append(module.Imports, "net") 535 } 536 } 537 538 alias := seclField.name 539 if isBasicType(fieldType) { 540 handleBasic(module, seclField, fieldBasename, alias, aliasPrefix, prefix, fieldType, event, opOverrides, fieldCommentText, seclField.containerStructName, fieldIterator, isArray) 541 } else { 542 spec := astFiles.LookupSymbol(fieldType) 543 if spec != nil { 544 newPrefix, newAliasPrefix := fieldBasename, alias 545 546 if prefix != "" { 547 newPrefix = prefix + "." + fieldBasename 548 } 549 550 if aliasPrefix != "" { 551 newAliasPrefix = aliasPrefix + "." + alias 552 } 553 554 handleSpecRecursive(module, astFiles, spec.Decl, newPrefix, newAliasPrefix, event, fieldIterator, dejavu) 555 } else { 556 log.Printf("failed to resolve symbol for type %+v in %s", fieldType, pkgname) 557 } 558 } 559 560 if !seclField.exposedAtEventRootOnly { 561 delete(dejavu, fieldBasename) 562 } 563 } 564 for _, seclField := range gettersOnlyFields { 565 handleNonEmbedded(module, seclField, prefixedFieldName, event, fieldType, isPointer, isArray) 566 567 if seclFieldIterator := seclField.iterator; seclFieldIterator != "" { 568 fieldIterator = handleIterator(module, seclField, fieldType, seclFieldIterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides, isPointer, isArray) 569 } 570 571 if handler := seclField.handler; handler != "" { 572 handleFieldWithHandler(module, seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, seclField.containerStructName, event, fieldCommentText, opOverrides, handler, isPointer, isArray, fieldIterator) 573 574 delete(dejavu, fieldBasename) 575 continue 576 } 577 578 if verbose { 579 log.Printf("Don't know what to do with %s: %s", fieldBasename, spew.Sdump(field.Type)) 580 } 581 582 dejavu[fieldBasename] = true 583 584 if len(fieldType) == 0 { 585 continue 586 } 587 588 alias := seclField.name 589 if isBasicTypeForGettersOnly(fieldType) { 590 handleBasic(module, seclField, fieldBasename, alias, aliasPrefix, prefix, fieldType, event, opOverrides, fieldCommentText, seclField.containerStructName, fieldIterator, isArray) 591 } else { 592 spec := astFiles.LookupSymbol(fieldType) 593 if spec != nil { 594 newPrefix, newAliasPrefix := fieldBasename, alias 595 596 if prefix != "" { 597 newPrefix = prefix + "." + fieldBasename 598 } 599 600 if aliasPrefix != "" { 601 newAliasPrefix = aliasPrefix + "." + alias 602 } 603 604 handleSpecRecursive(module, astFiles, spec.Decl, newPrefix, newAliasPrefix, event, fieldIterator, dejavu) 605 } else { 606 log.Printf("failed to resolve symbol for type %+v in %s", fieldType, pkgname) 607 } 608 } 609 610 if !seclField.exposedAtEventRootOnly { 611 delete(dejavu, fieldBasename) 612 } 613 } 614 } 615 } 616 } 617 618 func parseTags(tags *structtag.Tags, containerStructName string) (string, []seclField, []seclField) { 619 var opOverrides string 620 var fields []seclField 621 var gettersOnlyFields []seclField 622 623 for _, tag := range tags.Tags() { 624 switch tag.Key { 625 case "field": 626 fieldDefs := strings.Split(tag.Value(), ";") 627 for _, fieldDef := range fieldDefs { 628 field, err := parseFieldDef(fieldDef) 629 if err != nil { 630 log.Panicf("unable to parse field definition: %s", err) 631 } 632 633 if field.name == "-" { 634 return "", nil, nil 635 } 636 637 field.containerStructName = containerStructName 638 639 if field.gettersOnly { 640 gettersOnlyFields = append(gettersOnlyFields, field) 641 } else { 642 fields = append(fields, field) 643 } 644 } 645 646 case "op_override": 647 opOverrides = tag.Value() 648 } 649 } 650 651 return opOverrides, fields, gettersOnlyFields 652 } 653 654 func newAstFiles(cfg *packages.Config, files ...string) (*AstFiles, error) { 655 var astFiles AstFiles 656 657 for _, file := range files { 658 pkgs, err := packages.Load(cfg, file) 659 if err != nil { 660 return nil, err 661 } 662 663 if len(pkgs) == 0 || len(pkgs[0].Syntax) == 0 { 664 return nil, fmt.Errorf("failed to get syntax from parse file %s", file) 665 } 666 667 astFiles.files = append(astFiles.files, pkgs[0].Syntax[0]) 668 } 669 670 return &astFiles, nil 671 } 672 673 func parseFile(modelFile string, typesFile string, pkgName string) (*common.Module, error) { 674 cfg := packages.Config{ 675 Mode: packages.NeedSyntax | packages.NeedTypes | packages.NeedImports, 676 BuildFlags: []string{"-mod=mod", fmt.Sprintf("-tags=%s", buildTags)}, 677 } 678 679 astFiles, err := newAstFiles(&cfg, modelFile, typesFile) 680 if err != nil { 681 return nil, err 682 } 683 684 moduleName := path.Base(path.Dir(output)) 685 if moduleName == "." { 686 moduleName = path.Base(pkgName) 687 } 688 689 module := &common.Module{ 690 Name: moduleName, 691 SourcePkg: pkgName, 692 TargetPkg: pkgName, 693 BuildTags: formatBuildTags(buildTags), 694 Fields: make(map[string]*common.StructField), 695 AllFields: make(map[string]*common.StructField), 696 Iterators: make(map[string]*common.StructField), 697 EventTypes: make(map[string]*common.EventTypeMetadata), 698 } 699 700 // If the target package is different from the model package 701 if module.Name != path.Base(pkgName) { 702 module.SourcePkgPrefix = path.Base(pkgName) + "." 703 module.TargetPkg = path.Clean(path.Join(pkgName, path.Dir(output))) 704 } 705 706 for _, spec := range astFiles.GetSpecs() { 707 handleSpecRecursive(module, astFiles, spec, "", "", "", nil, make(map[string]bool)) 708 } 709 710 return module, nil 711 } 712 713 func formatBuildTags(buildTags string) []string { 714 splittedBuildTags := strings.Split(buildTags, ",") 715 var formattedBuildTags []string 716 for _, tag := range splittedBuildTags { 717 if tag != "" { 718 formattedBuildTags = append(formattedBuildTags, fmt.Sprintf("go:build %s", tag)) 719 } 720 } 721 return formattedBuildTags 722 } 723 724 func newField(allFields map[string]*common.StructField, field *common.StructField) string { 725 var fieldPath, result string 726 for _, node := range strings.Split(field.Name, ".") { 727 if fieldPath != "" { 728 fieldPath += "." + node 729 } else { 730 fieldPath = node 731 } 732 733 if field, ok := allFields[fieldPath]; ok { 734 if field.IsOrigTypePtr { 735 result += fmt.Sprintf("if ev.%s == nil { ev.%s = &%s{} }\n", field.Name, field.Name, field.OrigType) 736 } 737 } 738 } 739 740 return result 741 } 742 743 func generatePrefixNilChecks(allFields map[string]*common.StructField, returnType string, field *common.StructField) string { 744 var fieldPath, result string 745 for _, node := range strings.Split(field.Name, ".") { 746 if fieldPath != "" { 747 fieldPath += "." + node 748 } else { 749 fieldPath = node 750 } 751 752 if field, ok := allFields[fieldPath]; ok { 753 if field.IsOrigTypePtr { 754 result += fmt.Sprintf("if ev.%s == nil { return %s }\n", field.Name, getDefaultValueOfType(returnType)) 755 } 756 } 757 } 758 759 return result 760 } 761 762 func split(r rune) bool { 763 return r == '.' || r == '_' 764 } 765 766 func pascalCaseFieldName(fieldName string) string { 767 chunks := strings.FieldsFunc(fieldName, split) 768 caser := cases.Title(language.English, cases.NoLower) 769 770 for idx, chunk := range chunks { 771 newChunk := chunk 772 chunks[idx] = caser.String(newChunk) 773 } 774 775 return strings.Join(chunks, "") 776 } 777 778 func getDefaultValueOfType(returnType string) string { 779 baseType, isArray := strings.CutPrefix(returnType, "[]") 780 781 if baseType == "int" { 782 if isArray { 783 return "[]int{}" 784 } 785 return "0" 786 } else if baseType == "int64" { 787 if isArray { 788 return "[]int64{}" 789 } 790 return "int64(0)" 791 } else if baseType == "uint16" { 792 if isArray { 793 return "[]uint16{}" 794 } 795 return "uint16(0)" 796 } else if baseType == "uint32" { 797 if isArray { 798 return "[]uint32{}" 799 } 800 return "uint32(0)" 801 } else if baseType == "uint64" { 802 if isArray { 803 return "[]uint64{}" 804 } 805 return "uint64(0)" 806 } else if baseType == "bool" { 807 if isArray { 808 return "[]bool{}" 809 } 810 return "false" 811 } else if baseType == "net.IPNet" { 812 if isArray { 813 return "&eval.CIDRValues{}" 814 } 815 return "net.IPNet{}" 816 } else if baseType == "time.Time" { 817 if isArray { 818 return "[]time.Time{}" 819 } 820 return "time.Time{}" 821 } else if isArray { 822 return "[]string{}" 823 } 824 return `""` 825 } 826 827 func needScrubbed(fieldName string) bool { 828 loweredFieldName := strings.ToLower(fieldName) 829 if (strings.Contains(loweredFieldName, "argv") && !strings.Contains(loweredFieldName, "argv0")) && !strings.Contains(loweredFieldName, "module") { 830 return true 831 } 832 return false 833 } 834 835 func addSuffixToFuncPrototype(suffix string, prototype string) string { 836 chunks := strings.SplitN(prototype, "(", 3) 837 chunks = append(chunks[:1], append([]string{suffix, "("}, chunks[1:]...)...) 838 839 return strings.Join(chunks, "") 840 } 841 842 func getFieldHandler(allFields map[string]*common.StructField, field *common.StructField) string { 843 if field.Handler == "" || field.Iterator != nil || field.Helper { 844 return "" 845 } 846 847 if field.Prefix == "" { 848 return fmt.Sprintf("ev.FieldHandlers.%s(ev)", field.Handler) 849 } 850 851 ptr := "&" 852 if allFields[field.Prefix].IsOrigTypePtr { 853 ptr = "" 854 } 855 856 return fmt.Sprintf("ev.FieldHandlers.%s(ev, %sev.%s)", field.Handler, ptr, field.Prefix) 857 } 858 859 func fieldADPrint(field *common.StructField, handler string) string { 860 if field.SkipADResolution { 861 return fmt.Sprintf("if !forADs { _ = %s }", handler) 862 } 863 return fmt.Sprintf("_ = %s", handler) 864 } 865 866 func getHolder(allFields map[string]*common.StructField, field *common.StructField) *common.StructField { 867 idx := strings.LastIndex(field.Name, ".") 868 if idx == -1 { 869 return nil 870 } 871 name := field.Name[:idx] 872 return allFields[name] 873 } 874 875 func getChecks(allFields map[string]*common.StructField, field *common.StructField) []string { 876 var checks []string 877 878 name := field.Name 879 for name != "" { 880 field := allFields[name] 881 if field == nil { 882 break 883 } 884 885 if field.Check != "" { 886 if holder := getHolder(allFields, field); holder != nil { 887 check := fmt.Sprintf(`%s.%s`, holder.Name, field.Check) 888 checks = append([]string{check}, checks...) 889 } 890 } 891 892 idx := strings.LastIndex(name, ".") 893 if idx == -1 { 894 break 895 } 896 name = name[:idx] 897 } 898 899 return checks 900 } 901 902 func getHandlers(allFields map[string]*common.StructField) map[string]string { 903 handlers := make(map[string]string) 904 905 for _, field := range allFields { 906 if field.Handler != "" && !field.IsLength { 907 returnType := field.ReturnType 908 if field.IsArray { 909 returnType = "[]" + returnType 910 } 911 912 var handler string 913 if field.Prefix == "" { 914 handler = fmt.Sprintf("%s(ev *Event) %s", field.Handler, returnType) 915 } else { 916 handler = fmt.Sprintf("%s(ev *Event, e *%s) %s", field.Handler, field.Struct, returnType) 917 } 918 919 if _, exists := handlers[handler]; exists { 920 continue 921 } 922 923 var name string 924 if field.Prefix == "" { 925 name = "ev." + field.Name 926 } else { 927 name = "e" + strings.TrimPrefix(field.Name, field.Prefix) 928 } 929 930 if field.ReturnType == "int" { 931 if field.IsArray { 932 handlers[handler] = fmt.Sprintf("{ var result []int; for _, value := range %s { result = append(result, int(value)) }; return result }", name) 933 } else { 934 handlers[handler] = fmt.Sprintf("{ return int(%s) }", name) 935 } 936 } else { 937 handlers[handler] = fmt.Sprintf("{ return %s }", name) 938 } 939 } 940 } 941 942 return handlers 943 } 944 945 var funcMap = map[string]interface{}{ 946 "TrimPrefix": strings.TrimPrefix, 947 "TrimSuffix": strings.TrimSuffix, 948 "HasPrefix": strings.HasPrefix, 949 "NewField": newField, 950 "GeneratePrefixNilChecks": generatePrefixNilChecks, 951 "GetFieldHandler": getFieldHandler, 952 "FieldADPrint": fieldADPrint, 953 "GetChecks": getChecks, 954 "GetHandlers": getHandlers, 955 "PascalCaseFieldName": pascalCaseFieldName, 956 "GetDefaultValueOfType": getDefaultValueOfType, 957 "NeedScrubbed": needScrubbed, 958 "AddSuffixToFuncPrototype": addSuffixToFuncPrototype, 959 } 960 961 //go:embed accessors.tmpl 962 var accessorsTemplateCode string 963 964 //go:embed field_handlers.tmpl 965 var fieldHandlersTemplate string 966 967 //go:embed field_accessors.tmpl 968 var perFieldAccessorsTemplate string 969 970 func main() { 971 module, err := parseFile(modelFile, typesFile, pkgname) 972 if err != nil { 973 panic(err) 974 } 975 976 if len(fieldHandlersOutput) > 0 { 977 if err = GenerateContent(fieldHandlersOutput, module, fieldHandlersTemplate); err != nil { 978 panic(err) 979 } 980 } 981 982 if docOutput != "" { 983 os.Remove(docOutput) 984 if err := doc.GenerateDocJSON(module, path.Dir(modelFile), docOutput); err != nil { 985 panic(err) 986 } 987 } 988 989 os.Remove(output) 990 if err := GenerateContent(output, module, accessorsTemplateCode); err != nil { 991 panic(err) 992 } 993 994 if err := GenerateContent(fieldAccessorsOutput, module, perFieldAccessorsTemplate); err != nil { 995 panic(err) 996 } 997 } 998 999 // GenerateContent generates with the given template 1000 func GenerateContent(output string, module *common.Module, tmplCode string) error { 1001 tmpl := template.Must(template.New("header").Funcs(funcMap).Funcs(sprig.TxtFuncMap()).Parse(tmplCode)) 1002 1003 buffer := bytes.Buffer{} 1004 if err := tmpl.Execute(&buffer, module); err != nil { 1005 return err 1006 } 1007 1008 cleaned := removeEmptyLines(&buffer) 1009 1010 tmpfile, err := os.CreateTemp(path.Dir(output), "secl-helpers") 1011 if err != nil { 1012 return err 1013 } 1014 1015 if _, err := tmpfile.WriteString(cleaned); err != nil { 1016 return err 1017 } 1018 1019 if err := tmpfile.Close(); err != nil { 1020 return err 1021 } 1022 1023 cmd := exec.Command("gofmt", "-s", "-w", tmpfile.Name()) 1024 if output, err := cmd.CombinedOutput(); err != nil { 1025 log.Fatal(string(output)) 1026 return err 1027 } 1028 1029 return os.Rename(tmpfile.Name(), output) 1030 } 1031 1032 func removeEmptyLines(input *bytes.Buffer) string { 1033 scanner := bufio.NewScanner(input) 1034 builder := strings.Builder{} 1035 inGoCode := false 1036 1037 for scanner.Scan() { 1038 trimmed := strings.TrimSpace(scanner.Text()) 1039 1040 if strings.HasPrefix(trimmed, "package") { 1041 inGoCode = true 1042 } 1043 1044 if len(trimmed) != 0 || !inGoCode { 1045 builder.WriteString(trimmed) 1046 builder.WriteRune('\n') 1047 } 1048 } 1049 1050 return builder.String() 1051 } 1052 1053 func init() { 1054 flag.BoolVar(&verbose, "verbose", false, "Be verbose") 1055 flag.StringVar(&docOutput, "doc", "", "Generate documentation JSON") 1056 flag.StringVar(&fieldHandlersOutput, "field-handlers", "field_handlers_unix.go", "Field handlers output file") 1057 flag.StringVar(&modelFile, "input", os.Getenv("GOFILE"), "Go file to generate decoders from") 1058 flag.StringVar(&typesFile, "types-file", os.Getenv("TYPESFILE"), "Go type file to use with the model file") 1059 flag.StringVar(&pkgname, "package", pkgPrefix+"/"+os.Getenv("GOPACKAGE"), "Go package name") 1060 flag.StringVar(&buildTags, "tags", "unix", "build tags used for parsing") 1061 flag.StringVar(&fieldAccessorsOutput, "field-accessors-output", "field_accessors_unix.go", "Generated per-field accessors output file") 1062 flag.StringVar(&output, "output", "accessors_unix.go", "Go generated file") 1063 flag.Parse() 1064 }