github.com/DataDog/datadog-agent/pkg/security/secl@v0.55.0-devel.0.20240517055856-10c4965fea94/compiler/generators/accessors/accessors.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the Apache License Version 2.0.
     3  // This product includes software developed at Datadog (https://www.datadoghq.com/).
     4  // Copyright 2016-present Datadog, Inc.
     5  
     6  // Package main holds main related files
     7  package main
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	_ "embed"
    13  	"flag"
    14  	"fmt"
    15  	"go/ast"
    16  	"log"
    17  	"os"
    18  	"os/exec"
    19  	"path"
    20  	"reflect"
    21  	"slices"
    22  	"strconv"
    23  	"strings"
    24  	"text/template"
    25  	"unicode"
    26  
    27  	"github.com/Masterminds/sprig/v3"
    28  	"github.com/davecgh/go-spew/spew"
    29  	"github.com/fatih/structtag"
    30  	"golang.org/x/text/cases"
    31  	"golang.org/x/text/language"
    32  	"golang.org/x/tools/go/packages"
    33  
    34  	"github.com/DataDog/datadog-agent/pkg/security/secl/compiler/generators/accessors/common"
    35  	"github.com/DataDog/datadog-agent/pkg/security/secl/compiler/generators/accessors/doc"
    36  )
    37  
    38  const (
    39  	pkgPrefix = "github.com/DataDog/datadog-agent/pkg/security/secl"
    40  )
    41  
    42  var (
    43  	modelFile            string
    44  	typesFile            string
    45  	pkgname              string
    46  	output               string
    47  	verbose              bool
    48  	docOutput            string
    49  	fieldHandlersOutput  string
    50  	fieldAccessorsOutput string
    51  	buildTags            string
    52  )
    53  
    54  // AstFiles defines ast files
    55  type AstFiles struct {
    56  	files []*ast.File
    57  }
    58  
    59  // LookupSymbol lookups symbol
    60  func (af *AstFiles) LookupSymbol(symbol string) *ast.Object {
    61  	for _, file := range af.files {
    62  		if obj := file.Scope.Lookup(symbol); obj != nil {
    63  			return obj
    64  		}
    65  	}
    66  	return nil
    67  }
    68  
    69  // GetSpecs gets specs
    70  func (af *AstFiles) GetSpecs() []ast.Spec {
    71  	var specs []ast.Spec
    72  
    73  	for _, file := range af.files {
    74  		for _, decl := range file.Decls {
    75  			decl, ok := decl.(*ast.GenDecl)
    76  			if !ok || decl.Doc == nil {
    77  				continue
    78  			}
    79  
    80  			var genaccessors bool
    81  			for _, document := range decl.Doc.List {
    82  				if strings.Contains(document.Text, "genaccessors") {
    83  					genaccessors = true
    84  					break
    85  				}
    86  			}
    87  
    88  			if !genaccessors {
    89  				continue
    90  			}
    91  
    92  			specs = append(specs, decl.Specs...)
    93  		}
    94  	}
    95  
    96  	return specs
    97  }
    98  
    99  func origTypeToBasicType(kind string) string {
   100  	switch kind {
   101  	case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64":
   102  		return "int"
   103  	}
   104  	return kind
   105  }
   106  
   107  func isNetType(kind string) bool {
   108  	return kind == "net.IPNet"
   109  }
   110  
   111  func isBasicType(kind string) bool {
   112  	switch kind {
   113  	case "string", "bool", "int", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64", "net.IPNet":
   114  		return true
   115  	}
   116  	return false
   117  }
   118  
   119  func isBasicTypeForGettersOnly(kind string) bool {
   120  	if isBasicType(kind) {
   121  		return true
   122  	}
   123  
   124  	switch kind {
   125  	case "time.Time":
   126  		return true
   127  	}
   128  	return false
   129  }
   130  
   131  func qualifiedType(module *common.Module, kind string) string {
   132  	switch kind {
   133  	case "int", "string", "bool":
   134  		return kind
   135  	default:
   136  		return module.SourcePkgPrefix + kind
   137  	}
   138  }
   139  
   140  // handleBasic adds fields of "basic" type to list of exposed SECL fields of the module
   141  func handleBasic(module *common.Module, field seclField, name, alias, aliasPrefix, prefix, kind, event, opOverrides, commentText, containerStructName string, iterator *common.StructField, isArray bool) {
   142  	if verbose {
   143  		fmt.Printf("handleBasic name: %s, kind: %s, alias: %s, isArray: %v\n", name, kind, alias, isArray)
   144  	}
   145  
   146  	if prefix != "" {
   147  		name = prefix + "." + name
   148  	}
   149  
   150  	if aliasPrefix != "" {
   151  		alias = aliasPrefix + "." + alias
   152  	}
   153  
   154  	basicType := origTypeToBasicType(kind)
   155  	newStructField := &common.StructField{
   156  		Name:        name,
   157  		BasicType:   basicType,
   158  		ReturnType:  basicType,
   159  		IsArray:     strings.HasPrefix(kind, "[]") || isArray,
   160  		Event:       event,
   161  		OrigType:    kind,
   162  		Iterator:    iterator,
   163  		CommentText: commentText,
   164  		OpOverrides: opOverrides,
   165  		Struct:      containerStructName,
   166  		Alias:       alias,
   167  		AliasPrefix: aliasPrefix,
   168  		GettersOnly: field.gettersOnly,
   169  	}
   170  
   171  	module.Fields[alias] = newStructField
   172  
   173  	if _, ok := module.EventTypes[event]; !ok {
   174  		module.EventTypes[event] = common.NewEventTypeMetada()
   175  	}
   176  
   177  	if field.lengthField {
   178  		name = name + ".length"
   179  		aliasPrefix = alias
   180  		alias = alias + ".length"
   181  
   182  		newStructField := &common.StructField{
   183  			Name:        name,
   184  			BasicType:   "int",
   185  			ReturnType:  "int",
   186  			OrigType:    "int",
   187  			IsArray:     isArray,
   188  			IsLength:    true,
   189  			Event:       event,
   190  			Iterator:    iterator,
   191  			CommentText: doc.SECLDocForLength,
   192  			OpOverrides: opOverrides,
   193  			Struct:      "string",
   194  			Alias:       alias,
   195  			AliasPrefix: aliasPrefix,
   196  			GettersOnly: field.gettersOnly,
   197  		}
   198  
   199  		module.Fields[alias] = newStructField
   200  	}
   201  }
   202  
   203  // handleEmbedded adds embedded fields to list of exposed SECL fields of the module
   204  func handleEmbedded(module *common.Module, name, prefix, event string, fieldTypeExpr ast.Expr) {
   205  	if verbose {
   206  		log.Printf("handleEmbedded name: %s", name)
   207  	}
   208  
   209  	if prefix != "" {
   210  		name = fmt.Sprintf("%s.%s", prefix, name)
   211  	}
   212  
   213  	fieldType, isPointer, isArray := getFieldIdentName(fieldTypeExpr)
   214  
   215  	// maintain a list of all the fields
   216  	module.AllFields[name] = &common.StructField{
   217  		Name:          name,
   218  		Event:         event,
   219  		OrigType:      qualifiedType(module, fieldType),
   220  		IsOrigTypePtr: isPointer,
   221  		IsArray:       isArray,
   222  	}
   223  }
   224  
   225  // handleNonEmbedded adds non-embedded fields to list of all possible (but not necessarily exposed) SECL fields of the module
   226  func handleNonEmbedded(module *common.Module, field seclField, prefixedFieldName, event, fieldType string, isPointer, isArray bool) {
   227  	module.AllFields[prefixedFieldName] = &common.StructField{
   228  		Name:          prefixedFieldName,
   229  		Event:         event,
   230  		OrigType:      qualifiedType(module, fieldType),
   231  		IsOrigTypePtr: isPointer,
   232  		IsArray:       isArray,
   233  		Check:         field.check,
   234  	}
   235  }
   236  
   237  // handleIterator adds iterator to list of exposed SECL iterators of the module
   238  func handleIterator(module *common.Module, field seclField, fieldType, iterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides string, isPointer, isArray bool) *common.StructField {
   239  	alias := field.name
   240  	if aliasPrefix != "" {
   241  		alias = aliasPrefix + "." + field.name
   242  	}
   243  
   244  	module.Iterators[alias] = &common.StructField{
   245  		Name:             prefixedFieldName,
   246  		ReturnType:       qualifiedType(module, iterator),
   247  		Event:            event,
   248  		OrigType:         qualifiedType(module, fieldType),
   249  		IsOrigTypePtr:    isPointer,
   250  		IsArray:          isArray,
   251  		Weight:           field.weight,
   252  		CommentText:      fieldCommentText,
   253  		OpOverrides:      opOverrides,
   254  		Helper:           field.helper,
   255  		SkipADResolution: field.skipADResolution,
   256  		Check:            field.check,
   257  	}
   258  
   259  	return module.Iterators[alias]
   260  }
   261  
   262  // handleFieldWithHandler adds non-embedded fields with handlers to list of exposed SECL fields and event types of the module
   263  func handleFieldWithHandler(module *common.Module, field seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, containerStructName, event, fieldCommentText, opOverrides, handler string, isPointer, isArray bool, fieldIterator *common.StructField) {
   264  	alias := field.name
   265  
   266  	if aliasPrefix != "" {
   267  		alias = aliasPrefix + "." + alias
   268  	}
   269  
   270  	if event == "" {
   271  		log.Printf("event type not specified for field: %s", prefixedFieldName)
   272  	}
   273  
   274  	newStructField := &common.StructField{
   275  		Prefix:           prefix,
   276  		Name:             prefixedFieldName,
   277  		BasicType:        origTypeToBasicType(fieldType),
   278  		Struct:           containerStructName,
   279  		Handler:          handler,
   280  		ReturnType:       origTypeToBasicType(fieldType),
   281  		Event:            event,
   282  		OrigType:         fieldType,
   283  		Iterator:         fieldIterator,
   284  		IsArray:          isArray,
   285  		Weight:           field.weight,
   286  		CommentText:      fieldCommentText,
   287  		OpOverrides:      opOverrides,
   288  		Helper:           field.helper,
   289  		SkipADResolution: field.skipADResolution,
   290  		IsOrigTypePtr:    isPointer,
   291  		Check:            field.check,
   292  		Alias:            alias,
   293  		AliasPrefix:      aliasPrefix,
   294  		GettersOnly:      field.gettersOnly,
   295  	}
   296  
   297  	module.Fields[alias] = newStructField
   298  
   299  	if field.lengthField {
   300  		var lengthField = *module.Fields[alias]
   301  		lengthField.IsLength = true
   302  		lengthField.Name += ".length"
   303  		lengthField.OrigType = "int"
   304  		lengthField.BasicType = "int"
   305  		lengthField.ReturnType = "int"
   306  		lengthField.Struct = "string"
   307  		lengthField.AliasPrefix = alias
   308  		lengthField.Alias = alias + ".length"
   309  		lengthField.CommentText = doc.SECLDocForLength
   310  
   311  		module.Fields[lengthField.Alias] = &lengthField
   312  	}
   313  
   314  	if _, ok := module.EventTypes[event]; !ok {
   315  		module.EventTypes[event] = common.NewEventTypeMetada(alias)
   316  	} else {
   317  		module.EventTypes[event].Fields = append(module.EventTypes[event].Fields, alias)
   318  	}
   319  }
   320  
   321  func getFieldName(expr ast.Expr) string {
   322  	switch expr := expr.(type) {
   323  	case *ast.Ident:
   324  		return expr.Name
   325  	case *ast.StarExpr:
   326  		return getFieldName(expr.X)
   327  	case *ast.ArrayType:
   328  		return getFieldName(expr.Elt)
   329  	case *ast.SelectorExpr:
   330  		return getFieldName(expr.X) + "." + getFieldName(expr.Sel)
   331  	default:
   332  		return ""
   333  	}
   334  }
   335  
   336  func getFieldIdentName(expr ast.Expr) (name string, isPointer bool, isArray bool) {
   337  	switch expr.(type) {
   338  	case *ast.StarExpr:
   339  		isPointer = true
   340  	case *ast.ArrayType:
   341  		isArray = true
   342  	}
   343  
   344  	return getFieldName(expr), isPointer, isArray
   345  }
   346  
   347  type seclField struct {
   348  	name                   string
   349  	iterator               string
   350  	handler                string
   351  	helper                 bool // mark the handler as just a helper and not a real resolver. Won't be called by ResolveFields
   352  	skipADResolution       bool
   353  	lengthField            bool
   354  	weight                 int64
   355  	check                  string
   356  	exposedAtEventRootOnly bool // fields that should only be exposed at the root of an event, i.e. `parent` should not be exposed for an `ancestor` of a process
   357  	containerStructName    string
   358  	gettersOnly            bool //  a field that is not exposed via SECL, but still has an accessor generated
   359  }
   360  
   361  func parseFieldDef(def string) (seclField, error) {
   362  	def = strings.TrimSpace(def)
   363  	alias, options, splitted := strings.Cut(def, ",")
   364  
   365  	field := seclField{name: alias}
   366  
   367  	if alias == "-" {
   368  		return field, nil
   369  	}
   370  
   371  	// arguments
   372  	if splitted {
   373  		for _, el := range strings.Split(options, ",") {
   374  			kv := strings.Split(el, ":")
   375  
   376  			key, value := kv[0], kv[1]
   377  
   378  			switch key {
   379  			case "handler":
   380  				field.handler = value
   381  			case "weight":
   382  				weight, err := strconv.ParseInt(value, 10, 64)
   383  				if err != nil {
   384  					return field, err
   385  				}
   386  				field.weight = weight
   387  			case "iterator":
   388  				field.iterator = value
   389  			case "check":
   390  				field.check = value
   391  			case "opts":
   392  				for _, opt := range strings.Split(value, "|") {
   393  					switch opt {
   394  					case "helper":
   395  						field.helper = true
   396  					case "length":
   397  						field.lengthField = true
   398  					case "skip_ad":
   399  						field.skipADResolution = true
   400  					case "exposed_at_event_root_only":
   401  						field.exposedAtEventRootOnly = true
   402  					case "getters_only":
   403  						field.gettersOnly = true
   404  						field.exposedAtEventRootOnly = true
   405  					}
   406  				}
   407  			}
   408  		}
   409  	}
   410  
   411  	return field, nil
   412  }
   413  
// handleSpecRecursive walks the fields of the struct type declared by spec and
// registers them in module. Embedded structs and struct-typed fields are walked
// recursively with an extended Go path (prefix) and SECL alias (aliasPrefix).
// Leaf fields, handler-backed fields and iterators are recorded through the
// handle* helpers.
//
// Parameters:
//   - spec: the declaration to walk. Anything that is not an *ast.TypeSpec is
//     ignored, and a TypeSpec that is not a struct is only logged.
//   - prefix / aliasPrefix: dotted Go field path and SECL alias of the struct
//     being walked ("" at the top level).
//   - event: event type inherited from the caller. An `event:"..."` tag on a
//     field overrides it for that field and every following field of this
//     struct.
//   - iterator: iterator inherited from an enclosing iterable field, if any.
//   - dejavu: Go field names currently being expanded. A field whose name is
//     set is skipped, which stops recursion into a same-named field while it is
//     being expanded. Names of exposed_at_event_root_only fields stay set after
//     expansion. The set is replaced by a fresh one when an event type is seen
//     for the first time.
func handleSpecRecursive(module *common.Module, astFiles *AstFiles, spec interface{}, prefix, aliasPrefix, event string, iterator *common.StructField, dejavu map[string]bool) {
	if verbose {
		fmt.Printf("handleSpec spec: %+v, prefix: %s, aliasPrefix %s, event %s, iterator %+v\n", spec, prefix, aliasPrefix, event, iterator)
	}

	var typeSpec *ast.TypeSpec
	var structType *ast.StructType
	var ok bool
	if typeSpec, ok = spec.(*ast.TypeSpec); !ok {
		return
	}
	if structType, ok = typeSpec.Type.(*ast.StructType); !ok {
		log.Printf("Don't know what to do with %s (%s)", typeSpec.Name, spew.Sdump(typeSpec))
		return
	}

	for _, field := range structType.Fields.List {
		fieldCommentText := field.Comment.Text()
		// Each field starts from the inherited iterator; an `iterator` option
		// below replaces it for this field (and what is nested under it).
		fieldIterator := iterator

		var tag reflect.StructTag
		if field.Tag != nil {
			// Strip the backquotes (or double quotes) surrounding the tag literal.
			tag = reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1])
		}

		// An `event` tag switches the event type for this field and the
		// following ones; the field's trailing comment becomes the event doc.
		if e, ok := tag.Lookup("event"); ok {
			event = e
			if _, ok = module.EventTypes[e]; !ok {
				module.EventTypes[e] = common.NewEventTypeMetada()
				dejavu = make(map[string]bool) // clear dejavu map when it's a new event type
			}
			if e != "*" {
				module.EventTypes[e].Doc = fieldCommentText
			}
		}

		if isEmbedded := len(field.Names) == 0; isEmbedded { // embedded as in a struct embedded in another struct
			// `field:"-"` hides an embedded struct entirely.
			if fieldTag, found := tag.Lookup("field"); found && fieldTag == "-" {
				continue
			}

			// Only T and *T embeddings are resolved; qualified (pkg.T) ones are not.
			ident, _ := field.Type.(*ast.Ident)
			if ident == nil {
				if starExpr, ok := field.Type.(*ast.StarExpr); ok {
					ident, _ = starExpr.X.(*ast.Ident)
				}
			}

			if ident != nil {
				name := ident.Name
				if prefix != "" {
					name = prefix + "." + ident.Name
				}

				// The Go path gains the embedded type name but the SECL alias
				// prefix is unchanged, so the embedded fields are exposed as if
				// they were declared in this struct.
				embedded := astFiles.LookupSymbol(ident.Name)
				if embedded != nil {
					handleEmbedded(module, ident.Name, prefix, event, field.Type)
					handleSpecRecursive(module, astFiles, embedded.Decl, name, aliasPrefix, event, fieldIterator, dejavu)
				} else {
					log.Printf("failed to resolve symbol for identifier %+v in %s", ident.Name, pkgname)
				}
			}
		} else {
			// Named field: only the first name of the declaration is considered,
			// and unexported fields are skipped.
			fieldBasename := field.Names[0].Name
			if !unicode.IsUpper(rune(fieldBasename[0])) {
				continue
			}

			if dejavu[fieldBasename] {
				continue
			}

			// With parseable tags, the `field`/`op_override` tags drive what is
			// exposed, and a field whose tags yield nothing is skipped. Without
			// any (parseable) tag, the field is exposed under its Go name.
			var opOverrides string
			var fields []seclField
			var gettersOnlyFields []seclField
			if tags, err := structtag.Parse(string(tag)); err == nil && len(tags.Tags()) != 0 {
				opOverrides, fields, gettersOnlyFields = parseTags(tags, typeSpec.Name.Name)

				if opOverrides == "" && fields == nil && gettersOnlyFields == nil {
					continue
				}
			} else {
				fields = append(fields, seclField{name: fieldBasename})
			}

			fieldType, isPointer, isArray := getFieldIdentName(field.Type)

			prefixedFieldName := fieldBasename
			if prefix != "" {
				prefixedFieldName = fmt.Sprintf("%s.%s", prefix, fieldBasename)
			}

			// One Go field may be exposed under several SECL aliases.
			for _, seclField := range fields {
				handleNonEmbedded(module, seclField, prefixedFieldName, event, fieldType, isPointer, isArray)

				// NOTE: the reassigned fieldIterator carries over to the next
				// definitions of this field, including the getters-only loop below.
				if seclFieldIterator := seclField.iterator; seclFieldIterator != "" {
					fieldIterator = handleIterator(module, seclField, fieldType, seclFieldIterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides, isPointer, isArray)
				}

				// Handler-backed fields are exposed as-is and never expanded.
				if handler := seclField.handler; handler != "" {

					handleFieldWithHandler(module, seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, seclField.containerStructName, event, fieldCommentText, opOverrides, handler, isPointer, isArray, fieldIterator)

					delete(dejavu, fieldBasename)
					continue
				}

				if verbose {
					log.Printf("Don't know what to do with %s: %s", fieldBasename, spew.Sdump(field.Type))
				}

				// Mark the field as being expanded (see dejavu above).
				dejavu[fieldBasename] = true

				if len(fieldType) == 0 {
					continue
				}

				if isNetType((fieldType)) {
					if !slices.Contains(module.Imports, "net") {
						module.Imports = append(module.Imports, "net")
					}
				}

				// Leaf types are exposed directly; any other type is resolved
				// among the loaded files and walked with extended prefixes.
				alias := seclField.name
				if isBasicType(fieldType) {
					handleBasic(module, seclField, fieldBasename, alias, aliasPrefix, prefix, fieldType, event, opOverrides, fieldCommentText, seclField.containerStructName, fieldIterator, isArray)
				} else {
					spec := astFiles.LookupSymbol(fieldType)
					if spec != nil {
						newPrefix, newAliasPrefix := fieldBasename, alias

						if prefix != "" {
							newPrefix = prefix + "." + fieldBasename
						}

						if aliasPrefix != "" {
							newAliasPrefix = aliasPrefix + "." + alias
						}

						handleSpecRecursive(module, astFiles, spec.Decl, newPrefix, newAliasPrefix, event, fieldIterator, dejavu)
					} else {
						log.Printf("failed to resolve symbol for type %+v in %s", fieldType, pkgname)
					}
				}

				// exposed_at_event_root_only fields stay marked, so they are not
				// exposed again for the rest of the walk of this event.
				if !seclField.exposedAtEventRootOnly {
					delete(dejavu, fieldBasename)
				}
			}
			// Same processing for getters_only definitions, except that leaf
			// types also include time.Time (isBasicTypeForGettersOnly).
			// NOTE(review): unlike the loop above, this loop never adds "net" to
			// module.Imports for a net.IPNet field — confirm that is not needed.
			for _, seclField := range gettersOnlyFields {
				handleNonEmbedded(module, seclField, prefixedFieldName, event, fieldType, isPointer, isArray)

				if seclFieldIterator := seclField.iterator; seclFieldIterator != "" {
					fieldIterator = handleIterator(module, seclField, fieldType, seclFieldIterator, aliasPrefix, prefixedFieldName, event, fieldCommentText, opOverrides, isPointer, isArray)
				}

				if handler := seclField.handler; handler != "" {
					handleFieldWithHandler(module, seclField, aliasPrefix, prefix, prefixedFieldName, fieldType, seclField.containerStructName, event, fieldCommentText, opOverrides, handler, isPointer, isArray, fieldIterator)

					delete(dejavu, fieldBasename)
					continue
				}

				if verbose {
					log.Printf("Don't know what to do with %s: %s", fieldBasename, spew.Sdump(field.Type))
				}

				dejavu[fieldBasename] = true

				if len(fieldType) == 0 {
					continue
				}

				alias := seclField.name
				if isBasicTypeForGettersOnly(fieldType) {
					handleBasic(module, seclField, fieldBasename, alias, aliasPrefix, prefix, fieldType, event, opOverrides, fieldCommentText, seclField.containerStructName, fieldIterator, isArray)
				} else {
					spec := astFiles.LookupSymbol(fieldType)
					if spec != nil {
						newPrefix, newAliasPrefix := fieldBasename, alias

						if prefix != "" {
							newPrefix = prefix + "." + fieldBasename
						}

						if aliasPrefix != "" {
							newAliasPrefix = aliasPrefix + "." + alias
						}

						handleSpecRecursive(module, astFiles, spec.Decl, newPrefix, newAliasPrefix, event, fieldIterator, dejavu)
					} else {
						log.Printf("failed to resolve symbol for type %+v in %s", fieldType, pkgname)
					}
				}

				if !seclField.exposedAtEventRootOnly {
					delete(dejavu, fieldBasename)
				}
			}
		}
	}
}
   617  
   618  func parseTags(tags *structtag.Tags, containerStructName string) (string, []seclField, []seclField) {
   619  	var opOverrides string
   620  	var fields []seclField
   621  	var gettersOnlyFields []seclField
   622  
   623  	for _, tag := range tags.Tags() {
   624  		switch tag.Key {
   625  		case "field":
   626  			fieldDefs := strings.Split(tag.Value(), ";")
   627  			for _, fieldDef := range fieldDefs {
   628  				field, err := parseFieldDef(fieldDef)
   629  				if err != nil {
   630  					log.Panicf("unable to parse field definition: %s", err)
   631  				}
   632  
   633  				if field.name == "-" {
   634  					return "", nil, nil
   635  				}
   636  
   637  				field.containerStructName = containerStructName
   638  
   639  				if field.gettersOnly {
   640  					gettersOnlyFields = append(gettersOnlyFields, field)
   641  				} else {
   642  					fields = append(fields, field)
   643  				}
   644  			}
   645  
   646  		case "op_override":
   647  			opOverrides = tag.Value()
   648  		}
   649  	}
   650  
   651  	return opOverrides, fields, gettersOnlyFields
   652  }
   653  
   654  func newAstFiles(cfg *packages.Config, files ...string) (*AstFiles, error) {
   655  	var astFiles AstFiles
   656  
   657  	for _, file := range files {
   658  		pkgs, err := packages.Load(cfg, file)
   659  		if err != nil {
   660  			return nil, err
   661  		}
   662  
   663  		if len(pkgs) == 0 || len(pkgs[0].Syntax) == 0 {
   664  			return nil, fmt.Errorf("failed to get syntax from parse file %s", file)
   665  		}
   666  
   667  		astFiles.files = append(astFiles.files, pkgs[0].Syntax[0])
   668  	}
   669  
   670  	return &astFiles, nil
   671  }
   672  
   673  func parseFile(modelFile string, typesFile string, pkgName string) (*common.Module, error) {
   674  	cfg := packages.Config{
   675  		Mode:       packages.NeedSyntax | packages.NeedTypes | packages.NeedImports,
   676  		BuildFlags: []string{"-mod=mod", fmt.Sprintf("-tags=%s", buildTags)},
   677  	}
   678  
   679  	astFiles, err := newAstFiles(&cfg, modelFile, typesFile)
   680  	if err != nil {
   681  		return nil, err
   682  	}
   683  
   684  	moduleName := path.Base(path.Dir(output))
   685  	if moduleName == "." {
   686  		moduleName = path.Base(pkgName)
   687  	}
   688  
   689  	module := &common.Module{
   690  		Name:       moduleName,
   691  		SourcePkg:  pkgName,
   692  		TargetPkg:  pkgName,
   693  		BuildTags:  formatBuildTags(buildTags),
   694  		Fields:     make(map[string]*common.StructField),
   695  		AllFields:  make(map[string]*common.StructField),
   696  		Iterators:  make(map[string]*common.StructField),
   697  		EventTypes: make(map[string]*common.EventTypeMetadata),
   698  	}
   699  
   700  	// If the target package is different from the model package
   701  	if module.Name != path.Base(pkgName) {
   702  		module.SourcePkgPrefix = path.Base(pkgName) + "."
   703  		module.TargetPkg = path.Clean(path.Join(pkgName, path.Dir(output)))
   704  	}
   705  
   706  	for _, spec := range astFiles.GetSpecs() {
   707  		handleSpecRecursive(module, astFiles, spec, "", "", "", nil, make(map[string]bool))
   708  	}
   709  
   710  	return module, nil
   711  }
   712  
   713  func formatBuildTags(buildTags string) []string {
   714  	splittedBuildTags := strings.Split(buildTags, ",")
   715  	var formattedBuildTags []string
   716  	for _, tag := range splittedBuildTags {
   717  		if tag != "" {
   718  			formattedBuildTags = append(formattedBuildTags, fmt.Sprintf("go:build %s", tag))
   719  		}
   720  	}
   721  	return formattedBuildTags
   722  }
   723  
   724  func newField(allFields map[string]*common.StructField, field *common.StructField) string {
   725  	var fieldPath, result string
   726  	for _, node := range strings.Split(field.Name, ".") {
   727  		if fieldPath != "" {
   728  			fieldPath += "." + node
   729  		} else {
   730  			fieldPath = node
   731  		}
   732  
   733  		if field, ok := allFields[fieldPath]; ok {
   734  			if field.IsOrigTypePtr {
   735  				result += fmt.Sprintf("if ev.%s == nil { ev.%s = &%s{} }\n", field.Name, field.Name, field.OrigType)
   736  			}
   737  		}
   738  	}
   739  
   740  	return result
   741  }
   742  
   743  func generatePrefixNilChecks(allFields map[string]*common.StructField, returnType string, field *common.StructField) string {
   744  	var fieldPath, result string
   745  	for _, node := range strings.Split(field.Name, ".") {
   746  		if fieldPath != "" {
   747  			fieldPath += "." + node
   748  		} else {
   749  			fieldPath = node
   750  		}
   751  
   752  		if field, ok := allFields[fieldPath]; ok {
   753  			if field.IsOrigTypePtr {
   754  				result += fmt.Sprintf("if ev.%s == nil { return %s }\n", field.Name, getDefaultValueOfType(returnType))
   755  			}
   756  		}
   757  	}
   758  
   759  	return result
   760  }
   761  
   762  func split(r rune) bool {
   763  	return r == '.' || r == '_'
   764  }
   765  
   766  func pascalCaseFieldName(fieldName string) string {
   767  	chunks := strings.FieldsFunc(fieldName, split)
   768  	caser := cases.Title(language.English, cases.NoLower)
   769  
   770  	for idx, chunk := range chunks {
   771  		newChunk := chunk
   772  		chunks[idx] = caser.String(newChunk)
   773  	}
   774  
   775  	return strings.Join(chunks, "")
   776  }
   777  
   778  func getDefaultValueOfType(returnType string) string {
   779  	baseType, isArray := strings.CutPrefix(returnType, "[]")
   780  
   781  	if baseType == "int" {
   782  		if isArray {
   783  			return "[]int{}"
   784  		}
   785  		return "0"
   786  	} else if baseType == "int64" {
   787  		if isArray {
   788  			return "[]int64{}"
   789  		}
   790  		return "int64(0)"
   791  	} else if baseType == "uint16" {
   792  		if isArray {
   793  			return "[]uint16{}"
   794  		}
   795  		return "uint16(0)"
   796  	} else if baseType == "uint32" {
   797  		if isArray {
   798  			return "[]uint32{}"
   799  		}
   800  		return "uint32(0)"
   801  	} else if baseType == "uint64" {
   802  		if isArray {
   803  			return "[]uint64{}"
   804  		}
   805  		return "uint64(0)"
   806  	} else if baseType == "bool" {
   807  		if isArray {
   808  			return "[]bool{}"
   809  		}
   810  		return "false"
   811  	} else if baseType == "net.IPNet" {
   812  		if isArray {
   813  			return "&eval.CIDRValues{}"
   814  		}
   815  		return "net.IPNet{}"
   816  	} else if baseType == "time.Time" {
   817  		if isArray {
   818  			return "[]time.Time{}"
   819  		}
   820  		return "time.Time{}"
   821  	} else if isArray {
   822  		return "[]string{}"
   823  	}
   824  	return `""`
   825  }
   826  
   827  func needScrubbed(fieldName string) bool {
   828  	loweredFieldName := strings.ToLower(fieldName)
   829  	if (strings.Contains(loweredFieldName, "argv") && !strings.Contains(loweredFieldName, "argv0")) && !strings.Contains(loweredFieldName, "module") {
   830  		return true
   831  	}
   832  	return false
   833  }
   834  
   835  func addSuffixToFuncPrototype(suffix string, prototype string) string {
   836  	chunks := strings.SplitN(prototype, "(", 3)
   837  	chunks = append(chunks[:1], append([]string{suffix, "("}, chunks[1:]...)...)
   838  
   839  	return strings.Join(chunks, "")
   840  }
   841  
   842  func getFieldHandler(allFields map[string]*common.StructField, field *common.StructField) string {
   843  	if field.Handler == "" || field.Iterator != nil || field.Helper {
   844  		return ""
   845  	}
   846  
   847  	if field.Prefix == "" {
   848  		return fmt.Sprintf("ev.FieldHandlers.%s(ev)", field.Handler)
   849  	}
   850  
   851  	ptr := "&"
   852  	if allFields[field.Prefix].IsOrigTypePtr {
   853  		ptr = ""
   854  	}
   855  
   856  	return fmt.Sprintf("ev.FieldHandlers.%s(ev, %sev.%s)", field.Handler, ptr, field.Prefix)
   857  }
   858  
   859  func fieldADPrint(field *common.StructField, handler string) string {
   860  	if field.SkipADResolution {
   861  		return fmt.Sprintf("if !forADs { _ = %s }", handler)
   862  	}
   863  	return fmt.Sprintf("_ = %s", handler)
   864  }
   865  
   866  func getHolder(allFields map[string]*common.StructField, field *common.StructField) *common.StructField {
   867  	idx := strings.LastIndex(field.Name, ".")
   868  	if idx == -1 {
   869  		return nil
   870  	}
   871  	name := field.Name[:idx]
   872  	return allFields[name]
   873  }
   874  
   875  func getChecks(allFields map[string]*common.StructField, field *common.StructField) []string {
   876  	var checks []string
   877  
   878  	name := field.Name
   879  	for name != "" {
   880  		field := allFields[name]
   881  		if field == nil {
   882  			break
   883  		}
   884  
   885  		if field.Check != "" {
   886  			if holder := getHolder(allFields, field); holder != nil {
   887  				check := fmt.Sprintf(`%s.%s`, holder.Name, field.Check)
   888  				checks = append([]string{check}, checks...)
   889  			}
   890  		}
   891  
   892  		idx := strings.LastIndex(name, ".")
   893  		if idx == -1 {
   894  			break
   895  		}
   896  		name = name[:idx]
   897  	}
   898  
   899  	return checks
   900  }
   901  
   902  func getHandlers(allFields map[string]*common.StructField) map[string]string {
   903  	handlers := make(map[string]string)
   904  
   905  	for _, field := range allFields {
   906  		if field.Handler != "" && !field.IsLength {
   907  			returnType := field.ReturnType
   908  			if field.IsArray {
   909  				returnType = "[]" + returnType
   910  			}
   911  
   912  			var handler string
   913  			if field.Prefix == "" {
   914  				handler = fmt.Sprintf("%s(ev *Event) %s", field.Handler, returnType)
   915  			} else {
   916  				handler = fmt.Sprintf("%s(ev *Event, e *%s) %s", field.Handler, field.Struct, returnType)
   917  			}
   918  
   919  			if _, exists := handlers[handler]; exists {
   920  				continue
   921  			}
   922  
   923  			var name string
   924  			if field.Prefix == "" {
   925  				name = "ev." + field.Name
   926  			} else {
   927  				name = "e" + strings.TrimPrefix(field.Name, field.Prefix)
   928  			}
   929  
   930  			if field.ReturnType == "int" {
   931  				if field.IsArray {
   932  					handlers[handler] = fmt.Sprintf("{ var result []int; for _, value := range %s { result = append(result, int(value)) }; return result }", name)
   933  				} else {
   934  					handlers[handler] = fmt.Sprintf("{ return int(%s) }", name)
   935  				}
   936  			} else {
   937  				handlers[handler] = fmt.Sprintf("{ return %s }", name)
   938  			}
   939  		}
   940  	}
   941  
   942  	return handlers
   943  }
   944  
   945  var funcMap = map[string]interface{}{
   946  	"TrimPrefix":               strings.TrimPrefix,
   947  	"TrimSuffix":               strings.TrimSuffix,
   948  	"HasPrefix":                strings.HasPrefix,
   949  	"NewField":                 newField,
   950  	"GeneratePrefixNilChecks":  generatePrefixNilChecks,
   951  	"GetFieldHandler":          getFieldHandler,
   952  	"FieldADPrint":             fieldADPrint,
   953  	"GetChecks":                getChecks,
   954  	"GetHandlers":              getHandlers,
   955  	"PascalCaseFieldName":      pascalCaseFieldName,
   956  	"GetDefaultValueOfType":    getDefaultValueOfType,
   957  	"NeedScrubbed":             needScrubbed,
   958  	"AddSuffixToFuncPrototype": addSuffixToFuncPrototype,
   959  }
   960  
// accessorsTemplateCode is the template main renders into the -output file.
//
//go:embed accessors.tmpl
var accessorsTemplateCode string

// fieldHandlersTemplate is the template main renders into the -field-handlers
// file.
//
//go:embed field_handlers.tmpl
var fieldHandlersTemplate string

// perFieldAccessorsTemplate is the template main renders into the
// -field-accessors-output file.
//
//go:embed field_accessors.tmpl
var perFieldAccessorsTemplate string
   969  
   970  func main() {
   971  	module, err := parseFile(modelFile, typesFile, pkgname)
   972  	if err != nil {
   973  		panic(err)
   974  	}
   975  
   976  	if len(fieldHandlersOutput) > 0 {
   977  		if err = GenerateContent(fieldHandlersOutput, module, fieldHandlersTemplate); err != nil {
   978  			panic(err)
   979  		}
   980  	}
   981  
   982  	if docOutput != "" {
   983  		os.Remove(docOutput)
   984  		if err := doc.GenerateDocJSON(module, path.Dir(modelFile), docOutput); err != nil {
   985  			panic(err)
   986  		}
   987  	}
   988  
   989  	os.Remove(output)
   990  	if err := GenerateContent(output, module, accessorsTemplateCode); err != nil {
   991  		panic(err)
   992  	}
   993  
   994  	if err := GenerateContent(fieldAccessorsOutput, module, perFieldAccessorsTemplate); err != nil {
   995  		panic(err)
   996  	}
   997  }
   998  
   999  // GenerateContent generates with the given template
  1000  func GenerateContent(output string, module *common.Module, tmplCode string) error {
  1001  	tmpl := template.Must(template.New("header").Funcs(funcMap).Funcs(sprig.TxtFuncMap()).Parse(tmplCode))
  1002  
  1003  	buffer := bytes.Buffer{}
  1004  	if err := tmpl.Execute(&buffer, module); err != nil {
  1005  		return err
  1006  	}
  1007  
  1008  	cleaned := removeEmptyLines(&buffer)
  1009  
  1010  	tmpfile, err := os.CreateTemp(path.Dir(output), "secl-helpers")
  1011  	if err != nil {
  1012  		return err
  1013  	}
  1014  
  1015  	if _, err := tmpfile.WriteString(cleaned); err != nil {
  1016  		return err
  1017  	}
  1018  
  1019  	if err := tmpfile.Close(); err != nil {
  1020  		return err
  1021  	}
  1022  
  1023  	cmd := exec.Command("gofmt", "-s", "-w", tmpfile.Name())
  1024  	if output, err := cmd.CombinedOutput(); err != nil {
  1025  		log.Fatal(string(output))
  1026  		return err
  1027  	}
  1028  
  1029  	return os.Rename(tmpfile.Name(), output)
  1030  }
  1031  
  1032  func removeEmptyLines(input *bytes.Buffer) string {
  1033  	scanner := bufio.NewScanner(input)
  1034  	builder := strings.Builder{}
  1035  	inGoCode := false
  1036  
  1037  	for scanner.Scan() {
  1038  		trimmed := strings.TrimSpace(scanner.Text())
  1039  
  1040  		if strings.HasPrefix(trimmed, "package") {
  1041  			inGoCode = true
  1042  		}
  1043  
  1044  		if len(trimmed) != 0 || !inGoCode {
  1045  			builder.WriteString(trimmed)
  1046  			builder.WriteRune('\n')
  1047  		}
  1048  	}
  1049  
  1050  	return builder.String()
  1051  }
  1052  
  1053  func init() {
  1054  	flag.BoolVar(&verbose, "verbose", false, "Be verbose")
  1055  	flag.StringVar(&docOutput, "doc", "", "Generate documentation JSON")
  1056  	flag.StringVar(&fieldHandlersOutput, "field-handlers", "field_handlers_unix.go", "Field handlers output file")
  1057  	flag.StringVar(&modelFile, "input", os.Getenv("GOFILE"), "Go file to generate decoders from")
  1058  	flag.StringVar(&typesFile, "types-file", os.Getenv("TYPESFILE"), "Go type file to use with the model file")
  1059  	flag.StringVar(&pkgname, "package", pkgPrefix+"/"+os.Getenv("GOPACKAGE"), "Go package name")
  1060  	flag.StringVar(&buildTags, "tags", "unix", "build tags used for parsing")
  1061  	flag.StringVar(&fieldAccessorsOutput, "field-accessors-output", "field_accessors_unix.go", "Generated per-field accessors output file")
  1062  	flag.StringVar(&output, "output", "accessors_unix.go", "Go generated file")
  1063  	flag.Parse()
  1064  }