github.com/benhoyt/goawk@v1.8.1/parser/resolve.go (about)

     1  // Resolve function calls and variable types
     2  
     3  package parser
     4  
     5  import (
     6  	"fmt"
     7  	"reflect"
     8  	"sort"
     9  
    10  	. "github.com/benhoyt/goawk/internal/ast"
    11  	. "github.com/benhoyt/goawk/lexer"
    12  )
    13  
    14  type varType int
    15  
    16  const (
    17  	typeUnknown varType = iota
    18  	typeScalar
    19  	typeArray
    20  )
    21  
    22  func (t varType) String() string {
    23  	switch t {
    24  	case typeScalar:
    25  		return "Scalar"
    26  	case typeArray:
    27  		return "Array"
    28  	default:
    29  		return "Unknown"
    30  	}
    31  }
    32  
    33  // typeInfo records type information for a single variable
    34  type typeInfo struct {
    35  	typ      varType
    36  	ref      *VarExpr
    37  	scope    VarScope
    38  	index    int
    39  	callName string
    40  	argIndex int
    41  }
    42  
    43  // Used by printVarTypes when debugTypes is turned on
    44  func (t typeInfo) String() string {
    45  	var scope string
    46  	switch t.scope {
    47  	case ScopeGlobal:
    48  		scope = "Global"
    49  	case ScopeLocal:
    50  		scope = "Local"
    51  	default:
    52  		scope = "Special"
    53  	}
    54  	return fmt.Sprintf("typ=%s ref=%p scope=%s index=%d callName=%q argIndex=%d",
    55  		t.typ, t.ref, scope, t.index, t.callName, t.argIndex)
    56  }
    57  
    58  // A single variable reference (normally scalar)
    59  type varRef struct {
    60  	funcName string
    61  	ref      *VarExpr
    62  	isArg    bool
    63  	pos      Position
    64  }
    65  
    66  // A single array reference
    67  type arrayRef struct {
    68  	funcName string
    69  	ref      *ArrayExpr
    70  	pos      Position
    71  }
    72  
    73  // Initialize the resolver
    74  func (p *parser) initResolve() {
    75  	p.varTypes = make(map[string]map[string]typeInfo)
    76  	p.varTypes[""] = make(map[string]typeInfo) // globals
    77  	p.functions = make(map[string]int)
    78  	p.arrayRef("ARGV", Position{1, 1}) // interpreter relies on ARGV being present
    79  	p.multiExprs = make(map[*MultiExpr]Position, 3)
    80  }
    81  
    82  // Signal the start of a function
    83  func (p *parser) startFunction(name string, params []string) {
    84  	p.funcName = name
    85  	p.varTypes[name] = make(map[string]typeInfo)
    86  }
    87  
    88  // Signal the end of a function
    89  func (p *parser) stopFunction() {
    90  	p.funcName = ""
    91  }
    92  
    93  // Add function by name with given index
    94  func (p *parser) addFunction(name string, index int) {
    95  	p.functions[name] = index
    96  }
    97  
    98  // Records a call to a user function (for resolving indexes later)
    99  type userCall struct {
   100  	call   *UserCallExpr
   101  	pos    Position
   102  	inFunc string
   103  }
   104  
   105  // Record a user call site
   106  func (p *parser) recordUserCall(call *UserCallExpr, pos Position) {
   107  	p.userCalls = append(p.userCalls, userCall{call, pos, p.funcName})
   108  }
   109  
   110  // After parsing, resolve all user calls to their indexes. Also
   111  // ensures functions called have actually been defined, and that
   112  // they're not being called with too many arguments.
   113  func (p *parser) resolveUserCalls(prog *Program) {
   114  	// Number the native funcs (order by name to get consistent order)
   115  	nativeNames := make([]string, 0, len(p.nativeFuncs))
   116  	for name := range p.nativeFuncs {
   117  		nativeNames = append(nativeNames, name)
   118  	}
   119  	sort.Strings(nativeNames)
   120  	nativeIndexes := make(map[string]int, len(nativeNames))
   121  	for i, name := range nativeNames {
   122  		nativeIndexes[name] = i
   123  	}
   124  
   125  	for _, c := range p.userCalls {
   126  		// AWK-defined functions take precedence over native Go funcs
   127  		index, ok := p.functions[c.call.Name]
   128  		if !ok {
   129  			f, haveNative := p.nativeFuncs[c.call.Name]
   130  			if !haveNative {
   131  				panic(&ParseError{c.pos, fmt.Sprintf("undefined function %q", c.call.Name)})
   132  			}
   133  			typ := reflect.TypeOf(f)
   134  			if !typ.IsVariadic() && len(c.call.Args) > typ.NumIn() {
   135  				panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)})
   136  			}
   137  			c.call.Native = true
   138  			c.call.Index = nativeIndexes[c.call.Name]
   139  			continue
   140  		}
   141  		function := prog.Functions[index]
   142  		if len(c.call.Args) > len(function.Params) {
   143  			panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)})
   144  		}
   145  		c.call.Index = index
   146  	}
   147  }
   148  
   149  // For arguments that are variable references, we don't know the
   150  // type based on context, so mark the types for these as unknown.
   151  func (p *parser) processUserCallArg(funcName string, arg Expr, index int) {
   152  	if varExpr, ok := arg.(*VarExpr); ok {
   153  		scope, varFuncName := p.getScope(varExpr.Name)
   154  		ref := p.varTypes[varFuncName][varExpr.Name].ref
   155  		if ref == varExpr {
   156  			// Only applies if this is the first reference to this
   157  			// variable (otherwise we know the type already)
   158  			p.varTypes[varFuncName][varExpr.Name] = typeInfo{typeUnknown, ref, scope, 0, funcName, index}
   159  		}
   160  		// Mark the last related varRef (the most recent one) as a
   161  		// call argument for later error handling
   162  		p.varRefs[len(p.varRefs)-1].isArg = true
   163  	}
   164  }
   165  
   166  // Determine scope of given variable reference (and funcName if it's
   167  // a local, otherwise empty string)
   168  func (p *parser) getScope(name string) (VarScope, string) {
   169  	switch {
   170  	case p.locals[name]:
   171  		return ScopeLocal, p.funcName
   172  	case SpecialVarIndex(name) > 0:
   173  		return ScopeSpecial, ""
   174  	default:
   175  		return ScopeGlobal, ""
   176  	}
   177  }
   178  
   179  // Record a variable (scalar) reference and return the *VarExpr (but
   180  // VarExpr.Index won't be set till later)
   181  func (p *parser) varRef(name string, pos Position) *VarExpr {
   182  	scope, funcName := p.getScope(name)
   183  	expr := &VarExpr{scope, 0, name}
   184  	p.varRefs = append(p.varRefs, varRef{funcName, expr, false, pos})
   185  	info := p.varTypes[funcName][name]
   186  	if info.typ == typeUnknown {
   187  		p.varTypes[funcName][name] = typeInfo{typeScalar, expr, scope, 0, info.callName, 0}
   188  	}
   189  	return expr
   190  }
   191  
   192  // Record an array reference and return the *ArrayExpr (but
   193  // ArrayExpr.Index won't be set till later)
   194  func (p *parser) arrayRef(name string, pos Position) *ArrayExpr {
   195  	scope, funcName := p.getScope(name)
   196  	if scope == ScopeSpecial {
   197  		panic(p.error("can't use scalar %q as array", name))
   198  	}
   199  	expr := &ArrayExpr{scope, 0, name}
   200  	p.arrayRefs = append(p.arrayRefs, arrayRef{funcName, expr, pos})
   201  	info := p.varTypes[funcName][name]
   202  	if info.typ == typeUnknown {
   203  		p.varTypes[funcName][name] = typeInfo{typeArray, nil, scope, 0, info.callName, 0}
   204  	}
   205  	return expr
   206  }
   207  
   208  // Print variable type information (for debugging) on p.debugWriter
   209  func (p *parser) printVarTypes(prog *Program) {
   210  	fmt.Fprintf(p.debugWriter, "scalars: %v\n", prog.Scalars)
   211  	fmt.Fprintf(p.debugWriter, "arrays: %v\n", prog.Arrays)
   212  	funcNames := []string{}
   213  	for funcName := range p.varTypes {
   214  		funcNames = append(funcNames, funcName)
   215  	}
   216  	sort.Strings(funcNames)
   217  	for _, funcName := range funcNames {
   218  		if funcName != "" {
   219  			fmt.Fprintf(p.debugWriter, "function %s\n", funcName)
   220  		} else {
   221  			fmt.Fprintf(p.debugWriter, "globals\n")
   222  		}
   223  		varNames := []string{}
   224  		for name := range p.varTypes[funcName] {
   225  			varNames = append(varNames, name)
   226  		}
   227  		sort.Strings(varNames)
   228  		for _, name := range varNames {
   229  			info := p.varTypes[funcName][name]
   230  			fmt.Fprintf(p.debugWriter, "  %s: %s\n", name, info)
   231  		}
   232  	}
   233  }
   234  
   235  // If we can't finish resolving after this many iterations, give up
   236  const maxResolveIterations = 10000
   237  
   238  // Resolve unknown variables types and generate variable indexes and
   239  // name-to-index mappings for interpreter
   240  func (p *parser) resolveVars(prog *Program) {
   241  	// First go through all unknown types and try to determine the
   242  	// type from the parameter type in that function definition. May
   243  	// need multiple passes depending on the order of functions. This
   244  	// is not particularly efficient, but on realistic programs it's
   245  	// not an issue.
   246  	for i := 0; ; i++ {
   247  		progressed := false
   248  		for funcName, infos := range p.varTypes {
   249  			for name, info := range infos {
   250  				if info.scope == ScopeSpecial || info.typ != typeUnknown {
   251  					// It's a special var or type is already known
   252  					continue
   253  				}
   254  				funcIndex, ok := p.functions[info.callName]
   255  				if !ok {
   256  					// Function being called is a native function
   257  					continue
   258  				}
   259  				// Determine var type based on type of this parameter
   260  				// in the called function (if we know that)
   261  				paramName := prog.Functions[funcIndex].Params[info.argIndex]
   262  				typ := p.varTypes[info.callName][paramName].typ
   263  				if typ != typeUnknown {
   264  					if p.debugTypes {
   265  						fmt.Fprintf(p.debugWriter, "resolving %s:%s to %s\n",
   266  							funcName, name, typ)
   267  					}
   268  					info.typ = typ
   269  					p.varTypes[funcName][name] = info
   270  					progressed = true
   271  				}
   272  			}
   273  		}
   274  		if !progressed {
   275  			// If we didn't progress we're done (or trying again is
   276  			// not going to help)
   277  			break
   278  		}
   279  		if i >= maxResolveIterations {
   280  			panic(p.error("too many iterations trying to resolve variable types"))
   281  		}
   282  	}
   283  
   284  	// Resolve global variables (iteration order is undefined, so
   285  	// assign indexes basically randomly)
   286  	prog.Scalars = make(map[string]int)
   287  	prog.Arrays = make(map[string]int)
   288  	for name, info := range p.varTypes[""] {
   289  		_, isFunc := p.functions[name]
   290  		if isFunc {
   291  			// Global var can't also be the name of a function
   292  			panic(p.error("global var %q can't also be a function", name))
   293  		}
   294  		var index int
   295  		if info.scope == ScopeSpecial {
   296  			index = SpecialVarIndex(name)
   297  		} else if info.typ == typeArray {
   298  			index = len(prog.Arrays)
   299  			prog.Arrays[name] = index
   300  		} else {
   301  			index = len(prog.Scalars)
   302  			prog.Scalars[name] = index
   303  		}
   304  		info.index = index
   305  		p.varTypes[""][name] = info
   306  	}
   307  
   308  	// Resolve local variables (assign indexes in order of params).
   309  	// Also patch up Function.Arrays (tells interpreter which args
   310  	// are arrays).
   311  	for funcName, infos := range p.varTypes {
   312  		if funcName == "" {
   313  			continue
   314  		}
   315  		scalarIndex := 0
   316  		arrayIndex := 0
   317  		functionIndex := p.functions[funcName]
   318  		function := prog.Functions[functionIndex]
   319  		arrays := make([]bool, len(function.Params))
   320  		for i, name := range function.Params {
   321  			info := infos[name]
   322  			var index int
   323  			if info.typ == typeArray {
   324  				index = arrayIndex
   325  				arrayIndex++
   326  				arrays[i] = true
   327  			} else {
   328  				// typeScalar or typeUnknown: variables may still be
   329  				// of unknown type if they've never been referenced --
   330  				// default to scalar in that case
   331  				index = scalarIndex
   332  				scalarIndex++
   333  			}
   334  			info.index = index
   335  			p.varTypes[funcName][name] = info
   336  		}
   337  		prog.Functions[functionIndex].Arrays = arrays
   338  	}
   339  
   340  	// Check that variables passed to functions are the correct type
   341  	for _, c := range p.userCalls {
   342  		// Check native function calls
   343  		if c.call.Native {
   344  			for _, arg := range c.call.Args {
   345  				varExpr, ok := arg.(*VarExpr)
   346  				if !ok {
   347  					// Non-variable expression, must be scalar
   348  					continue
   349  				}
   350  				funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
   351  				info := p.varTypes[funcName][varExpr.Name]
   352  				if info.typ == typeArray {
   353  					message := fmt.Sprintf("can't pass array %q to native function", varExpr.Name)
   354  					panic(&ParseError{c.pos, message})
   355  				}
   356  			}
   357  			continue
   358  		}
   359  
   360  		// Check AWK function calls
   361  		function := prog.Functions[c.call.Index]
   362  		for i, arg := range c.call.Args {
   363  			varExpr, ok := arg.(*VarExpr)
   364  			if !ok {
   365  				if function.Arrays[i] {
   366  					message := fmt.Sprintf("can't pass scalar %s as array param", arg)
   367  					panic(&ParseError{c.pos, message})
   368  				}
   369  				continue
   370  			}
   371  			funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
   372  			info := p.varTypes[funcName][varExpr.Name]
   373  			if info.typ == typeArray && !function.Arrays[i] {
   374  				message := fmt.Sprintf("can't pass array %q as scalar param", varExpr.Name)
   375  				panic(&ParseError{c.pos, message})
   376  			}
   377  			if info.typ != typeArray && function.Arrays[i] {
   378  				message := fmt.Sprintf("can't pass scalar %q as array param", varExpr.Name)
   379  				panic(&ParseError{c.pos, message})
   380  			}
   381  		}
   382  	}
   383  
   384  	if p.debugTypes {
   385  		p.printVarTypes(prog)
   386  	}
   387  
   388  	// Patch up variable indexes (interpreter uses an index instead
   389  	// of name for more efficient lookups)
   390  	for _, varRef := range p.varRefs {
   391  		info := p.varTypes[varRef.funcName][varRef.ref.Name]
   392  		if info.typ == typeArray && !varRef.isArg {
   393  			message := fmt.Sprintf("can't use array %q as scalar", varRef.ref.Name)
   394  			panic(&ParseError{varRef.pos, message})
   395  		}
   396  		varRef.ref.Index = info.index
   397  	}
   398  	for _, arrayRef := range p.arrayRefs {
   399  		info := p.varTypes[arrayRef.funcName][arrayRef.ref.Name]
   400  		if info.typ == typeScalar {
   401  			message := fmt.Sprintf("can't use scalar %q as array", arrayRef.ref.Name)
   402  			panic(&ParseError{arrayRef.pos, message})
   403  		}
   404  		arrayRef.ref.Index = info.index
   405  	}
   406  }
   407  
   408  // If name refers to a local (in function inFunc), return that
   409  // function's name, otherwise return "" (meaning global).
   410  func (p *parser) getVarFuncName(prog *Program, name, inFunc string) string {
   411  	if inFunc == "" {
   412  		return ""
   413  	}
   414  	for _, param := range prog.Functions[p.functions[inFunc]].Params {
   415  		if name == param {
   416  			return inFunc
   417  		}
   418  	}
   419  	return ""
   420  }
   421  
   422  // Record a "multi expression" (comma-separated pseudo-expression
   423  // used to allow commas around print/printf arguments).
   424  func (p *parser) multiExpr(exprs []Expr, pos Position) Expr {
   425  	expr := &MultiExpr{exprs}
   426  	p.multiExprs[expr] = pos
   427  	return expr
   428  }
   429  
   430  // Mark the multi expression as used (by a print/printf statement).
   431  func (p *parser) useMultiExpr(expr *MultiExpr) {
   432  	delete(p.multiExprs, expr)
   433  }
   434  
   435  // Check that there are no unused multi expressions (syntax error).
   436  func (p *parser) checkMultiExprs() {
   437  	if len(p.multiExprs) == 0 {
   438  		return
   439  	}
   440  	// Show error on first comma-separated expression
   441  	min := Position{1000000000, 1000000000}
   442  	for _, pos := range p.multiExprs {
   443  		if pos.Line < min.Line || (pos.Line == min.Line && pos.Column < min.Column) {
   444  			min = pos
   445  		}
   446  	}
   447  	message := fmt.Sprintf("unexpected comma-separated expression")
   448  	panic(&ParseError{min, message})
   449  }