github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/transform/staticlark/functions.go (about)

     1  package staticlark
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"go.starlark.net/syntax"
     8  )
     9  
    10  // build a list of functions
    11  func collectFuncDefsTopLevelCalls(stmts []syntax.Stmt) ([]*funcNode, []string, error) {
    12  	functions := []*funcNode{}
    13  	topLevel := []string{}
    14  	for _, stmt := range stmts {
    15  		switch item := stmt.(type) {
    16  		case *syntax.DefStmt:
    17  			res, err := analyzeFunction(item)
    18  			if err != nil {
    19  				return nil, nil, err
    20  			}
    21  			functions = append(functions, res)
    22  		default:
    23  			calls := getFuncCallsInStmtList([]syntax.Stmt{stmt})
    24  			topLevel = append(topLevel, calls...)
    25  		}
    26  	}
    27  	return functions, topLevel, nil
    28  }
    29  
    30  // build a function object, contains calls to other functions
    31  func analyzeFunction(def *syntax.DefStmt) (*funcNode, error) {
    32  	params := make([]string, len(def.Params))
    33  	for k, param := range def.Params {
    34  		p := parameterName(param)
    35  		params[k] = p
    36  	}
    37  
    38  	res, err := buildFromFuncBody(def.Body)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	res.name = def.Name.Name
    43  	res.params = params
    44  	res.body = def.Body
    45  
    46  	return res, nil
    47  }
    48  
    49  func parameterName(e syntax.Expr) string {
    50  	id, ok := e.(*syntax.Ident)
    51  	if !ok {
    52  		return fmt.Sprintf("<UNKNOWN: %s>", reflect.TypeOf(e))
    53  	}
    54  	return id.Name
    55  }
    56  
// funcNode is a single function definition and body, along with
// additional information derived from static analysis. analyzeFunction
// fills in name, params, body and callNames; the remaining fields are
// presumably populated by later passes elsewhere in the package.
type funcNode struct {
	// name is the identifier from the function's `def` statement.
	name string
	// params holds one entry per declared parameter, in declaration order,
	// as rendered by parameterName.
	params []string
	// body is the function's statement list.
	body []syntax.Stmt
	// calls links to the funcNodes this function calls once the call graph
	// is built; newFuncNode initializes it to an empty, non-nil slice.
	calls []*funcNode
	// reach and height are not set by the code in this file; presumably
	// they record call-graph reachability and depth — TODO confirm.
	reach  bool
	height int
	// the string names of functions that are called, only needed
	// until call graph is built, and the `calls` field is set
	callNames []string
	// used by dataflow analysis to track sensitive data and
	// dangerous parameters, to ensure safe data usage; the per-parameter
	// slices presumably line up index-for-index with params — TODO confirm
	dangerousParams []bool
	sensitiveReturn bool
	reasonParams    []reason
}
    75  
    76  // newFuncNode constructs a new funcNode
    77  func newFuncNode() *funcNode {
    78  	return &funcNode{calls: []*funcNode{}}
    79  }
    80  
    81  func buildFromFuncBody(body []syntax.Stmt) (*funcNode, error) {
    82  	node := newFuncNode()
    83  	node.callNames = getFuncCallsInStmtList(body)
    84  	return node, nil
    85  }
    86  
    87  func getFuncCallsInStmtList(listStmt []syntax.Stmt) []string {
    88  	result := make([]string, 0)
    89  
    90  	for _, stmt := range listStmt {
    91  		switch item := stmt.(type) {
    92  		case *syntax.AssignStmt:
    93  			calls := getFuncCallsInExpr(item.LHS)
    94  			calls = append(calls, getFuncCallsInExpr(item.RHS)...)
    95  			result = append(result, calls...)
    96  
    97  		case *syntax.BranchStmt:
    98  			// pass
    99  
   100  		case *syntax.DefStmt:
   101  			// TODO(dustmop): Add this definition to the lexical scope
   102  
   103  		case *syntax.ExprStmt:
   104  			calls := getFuncCallsInExpr(item.X)
   105  			result = append(result, calls...)
   106  
   107  		case *syntax.ForStmt:
   108  			calls := getFuncCallsInExpr(item.X)
   109  			calls = append(calls, getFuncCallsInStmtList(item.Body)...)
   110  			result = append(result, calls...)
   111  
   112  		case *syntax.WhileStmt:
   113  			calls := getFuncCallsInExpr(item.Cond)
   114  			calls = append(calls, getFuncCallsInStmtList(item.Body)...)
   115  			result = append(result, calls...)
   116  
   117  		case *syntax.IfStmt:
   118  			calls := getFuncCallsInExpr(item.Cond)
   119  			calls = append(calls, getFuncCallsInStmtList(item.True)...)
   120  			calls = append(calls, getFuncCallsInStmtList(item.False)...)
   121  			result = append(result, calls...)
   122  
   123  		case *syntax.LoadStmt:
   124  			// pass
   125  
   126  		case *syntax.ReturnStmt:
   127  			calls := getFuncCallsInExpr(item.Result)
   128  			result = append(result, calls...)
   129  
   130  		}
   131  	}
   132  
   133  	return result
   134  }
   135  
   136  func getFuncCallsInExpr(expr syntax.Expr) []string {
   137  	if expr == nil {
   138  		return []string{}
   139  	}
   140  	switch item := expr.(type) {
   141  	case *syntax.BinaryExpr:
   142  		return append(getFuncCallsInExpr(item.X), getFuncCallsInExpr(item.Y)...)
   143  
   144  	case *syntax.CallExpr:
   145  		// TODO(dustmop): Add lexical scoping so that inner functions are
   146  		// correctly associated with their call sites
   147  		funcName := simpleExprToFuncName(item.Fn)
   148  		result := make([]string, 0, 1+len(item.Args))
   149  		result = append(result, funcName)
   150  		for _, arg := range item.Args {
   151  			result = append(result, getFuncCallsInExpr(arg)...)
   152  		}
   153  		return result
   154  
   155  	case *syntax.Comprehension:
   156  		result := getFuncCallsInExpr(item.Body)
   157  		return result
   158  
   159  	case *syntax.CondExpr:
   160  		result := getFuncCallsInExpr(item.Cond)
   161  		result = append(result, getFuncCallsInExpr(item.True)...)
   162  		result = append(result, getFuncCallsInExpr(item.False)...)
   163  		return result
   164  
   165  	case *syntax.DictEntry:
   166  		return append(getFuncCallsInExpr(item.Key), getFuncCallsInExpr(item.Value)...)
   167  
   168  	case *syntax.DictExpr:
   169  		result := make([]string, 0, len(item.List))
   170  		for _, elem := range item.List {
   171  			result = append(result, getFuncCallsInExpr(elem)...)
   172  		}
   173  		return result
   174  
   175  	case *syntax.DotExpr:
   176  		return []string{}
   177  
   178  	case *syntax.Ident:
   179  		return []string{}
   180  
   181  	case *syntax.IndexExpr:
   182  		return append(getFuncCallsInExpr(item.X), getFuncCallsInExpr(item.Y)...)
   183  
   184  	case *syntax.LambdaExpr:
   185  		result := make([]string, 0, 1+len(item.Params))
   186  		result = append(result, getFuncCallsInExpr(item.Body)...)
   187  		for _, elem := range item.Params {
   188  			result = append(result, getFuncCallsInExpr(elem)...)
   189  		}
   190  		return result
   191  
   192  	case *syntax.ListExpr:
   193  		result := make([]string, 0, len(item.List))
   194  		for _, elem := range item.List {
   195  			result = append(result, getFuncCallsInExpr(elem)...)
   196  		}
   197  		return result
   198  
   199  	case *syntax.Literal:
   200  		return []string{}
   201  
   202  	case *syntax.ParenExpr:
   203  		return getFuncCallsInExpr(item.X)
   204  
   205  	case *syntax.SliceExpr:
   206  		result := getFuncCallsInExpr(item.X)
   207  		result = append(result, getFuncCallsInExpr(item.Lo)...)
   208  		result = append(result, getFuncCallsInExpr(item.Hi)...)
   209  		result = append(result, getFuncCallsInExpr(item.Step)...)
   210  		return result
   211  
   212  	case *syntax.TupleExpr:
   213  		result := make([]string, 0, len(item.List))
   214  		for _, elem := range item.List {
   215  			result = append(result, getFuncCallsInExpr(elem)...)
   216  		}
   217  		return result
   218  
   219  	case *syntax.UnaryExpr:
   220  		return getFuncCallsInExpr(item.X)
   221  
   222  	}
   223  	return nil
   224  }
   225  
   226  func simpleExprToFuncName(expr syntax.Expr) string {
   227  	if item, ok := expr.(*syntax.Ident); ok {
   228  		return item.Name
   229  	}
   230  	if item, ok := expr.(*syntax.DotExpr); ok {
   231  		lhs := simpleExprToFuncName(item.X)
   232  		return fmt.Sprintf("%s.%s", lhs, item.Name.Name)
   233  	}
   234  	return fmt.Sprintf("<Unknown Name, Type: %q>", reflect.TypeOf(expr))
   235  }