github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/transform/staticlark/control_flow.go (about)

     1  package staticlark
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"strings"
     7  
     8  	"go.starlark.net/syntax"
     9  )
    10  
    11  // controlFlow represents the control flow within a single function
    12  // It is a list of blocks, where each block is one or more code units
    13  // that have linear control flow. In addition, each block has a list
    14  // of outgoing edges that reference which block control flow may move
    15  // to next. These edges are represented by indexes into the control
    16  // flow graph's block list.
    17  //
    18  // Example:
    19  //
    20  // The code:
    21  //
    22  //   a = 1
    23  //   if a > b:
    24  //     print('big')
    25  //   print('done')
    26  //
    27  // The graph represented abstractly:
    28  //
    29  // +-------+
    30  // | a = 1 |--+
    31  // +-------+  |
    32  //            |
    33  //   +--------+
    34  //   |
    35  //   v
    36  // +----------+
    37  // | if a > b |------+
    38  // +----------+      |
    39  //                   |
    40  //   +---------------+
    41  //   |               |
    42  //   v               |
    43  // +--------------+  |
    44  // | print('big') |--+
    45  // +--------------+  |
    46  //                   |
    47  //   +---------------+
    48  //   |
    49  //   v
    50  // +---------------+
    51  // | print('done') |
    52  // +---------------+
    53  //
    54  // Control flow, stringified:
    55  //
// 0: [set! a 1]
//   out: 1
// 1: [if [> a b]]
//   out: 2,3, join: 3
// 2: [print 'big']
//   out: 3
// 3: [print 'done']
//   out: -
    64  type controlFlow struct {
    65  	blocks []*codeBlock
    66  }
    67  
    68  func newControlFlow() *controlFlow {
    69  	return &controlFlow{}
    70  }
    71  
    72  func newControlFlowFromFunc(fn *funcNode) (*controlFlow, error) {
    73  	builder := newControlFlowBuilder()
    74  	cf := builder.build(fn.body)
    75  	return cf, nil
    76  }
    77  
    78  func (c *controlFlow) stringify() string {
    79  	result := ""
    80  	for i, block := range c.blocks {
    81  		result += fmt.Sprintf("%d: %s\n", i, block.stringify())
    82  	}
    83  	return result
    84  }
    85  
    86  // codeBlock is a single block of linear control flow, along with
    87  // a list of outgoing edges, represented as indexes into other
    88  // blocks in the control flow
    89  type codeBlock struct {
    90  	units []*unit
    91  	edges []int
    92  	// if the block has multiple edges that branch forward, join is
    93  	// the index that those branches join at. Only used by `if`
    94  	join int
    95  }
    96  
    97  func newCodeBlock() *codeBlock {
    98  	return &codeBlock{
    99  		units: []*unit{},
   100  	}
   101  }
   102  
   103  func (b *codeBlock) stringify() string {
   104  	result := ""
   105  	for j, unit := range b.units {
   106  		if j == 0 {
   107  			result += fmt.Sprintf("%s\n", unit)
   108  		} else {
   109  			padding := strings.Repeat(" ", 3)
   110  			result += fmt.Sprintf("%s%s\n", padding, unit)
   111  		}
   112  	}
   113  	if len(b.units) == 0 {
   114  		result += "-\n"
   115  	}
   116  	if len(b.edges) == 0 {
   117  		result += "  out: -"
   118  	} else if len(b.edges) == 1 && b.edges[0] == -1 {
   119  		result += "  out: return"
   120  	} else if len(b.edges) == 1 && b.edges[0] == -2 {
   121  		result += "  out: break"
   122  	} else {
   123  		outs := make([]string, len(b.edges))
   124  		for j, edge := range b.edges {
   125  			outs[j] = strconv.Itoa(edge)
   126  		}
   127  		if b.join > 0 {
   128  			result += fmt.Sprintf("  out: %s, join: %d", strings.Join(outs, ","), b.join)
   129  		} else {
   130  			result += fmt.Sprintf("  out: %s", strings.Join(outs, ","))
   131  		}
   132  	}
   133  	return result
   134  }
   135  
   136  func (b *codeBlock) isLinear() bool {
   137  	return len(b.edges) <= 1
   138  }
   139  
   140  func (b *codeBlock) isIfCondition() bool {
   141  	if len(b.units) == 1 {
   142  		u := b.units[0]
   143  		if u.atom == "if" {
   144  			return true
   145  		}
   146  	}
   147  	return false
   148  }
   149  
   150  func (b *codeBlock) String() string {
   151  	return b.stringify()
   152  }
   153  
   154  // cfBuilder is a builder that creates a control flow
   155  type cfBuilder struct {
   156  	// the entire control flow being built
   157  	flow *controlFlow
   158  	// current block being added to
   159  	curr *codeBlock
   160  	// references to blocks that do not have outgoing edges
   161  	dangling []int
   162  	// references to blocks that want to join at the next block
   163  	danglingJoin []int
   164  }
   165  
   166  func newControlFlowBuilder() *cfBuilder {
   167  	builder := &cfBuilder{}
   168  	return builder
   169  }
   170  
   171  // create a new block if there is no current block being built
   172  func (builder *cfBuilder) ensureBlock() {
   173  	if builder.flow.blocks == nil || builder.curr == nil {
   174  		builder.makeBlock()
   175  	}
   176  }
   177  
   178  // put a unit at the end of the current block
   179  func (builder *cfBuilder) put(unit *unit) {
   180  	builder.curr.units = append(builder.curr.units, unit)
   181  }
   182  
   183  // for each referenced block, add `dest` to the outgoing edges
   184  func (builder *cfBuilder) addEdges(blockRefs []int, dest []int) {
   185  	for _, ref := range blockRefs {
   186  		builder.flow.blocks[ref].edges = append(builder.flow.blocks[ref].edges, dest...)
   187  	}
   188  }
   189  
   190  // mark the referenced blocks as needing to join to the next block
   191  func (builder *cfBuilder) wantJoin(blockRefs []int) {
   192  	for _, ref := range blockRefs {
   193  		builder.flow.blocks[ref].join = -1
   194  	}
   195  	builder.danglingJoin = append(builder.danglingJoin, blockRefs...)
   196  }
   197  
   198  // add a new block to the control flow, and point any dangling edges to it
   199  func (builder *cfBuilder) makeBlock() {
   200  	if len(builder.dangling) != 0 {
   201  		index := builder.refNext()
   202  		if len(builder.dangling) > 1 && builder.danglingJoin != nil {
   203  			for _, idx := range builder.danglingJoin {
   204  				builder.flow.blocks[idx].join = index
   205  			}
   206  			builder.danglingJoin = nil
   207  		}
   208  		builder.addEdges(builder.dangling, []int{index})
   209  	}
   210  
   211  	index := builder.refNext()
   212  	builder.flow.blocks = append(builder.flow.blocks, newCodeBlock())
   213  	builder.curr = builder.flow.blocks[index]
   214  	builder.dangling = []int{index}
   215  }
   216  
   217  // mark the current block as done, return refs to blocks with dangling edges
   218  func (builder *cfBuilder) finish() []int {
   219  	refs := builder.dangling
   220  	builder.curr = nil
   221  	builder.dangling = nil
   222  	return refs
   223  }
   224  
   225  // get reference to the next block, as an index
   226  func (builder *cfBuilder) refNext() int {
   227  	return len(builder.flow.blocks)
   228  }
   229  
   230  // build is the main entry point for the builder. Takes a body of a
   231  // function and creates a control flow for that body
   232  func (builder *cfBuilder) build(stmtList []syntax.Stmt) *controlFlow {
   233  	builder.flow = newControlFlow()
   234  	builder.buildSubGraph(stmtList)
   235  	return builder.flow
   236  }
   237  
   238  // build a sub graph of an already in-use builder, return the dangling
   239  // edges at the end of that sub-graph
   240  func (builder *cfBuilder) buildSubGraph(stmtList []syntax.Stmt) []int {
   241  	for _, line := range stmtList {
   242  		builder.buildSingleNode(line)
   243  	}
   244  	return builder.finish()
   245  }
   246  
   247  func (builder *cfBuilder) buildSingleNode(stmt syntax.Stmt) {
   248  	switch item := stmt.(type) {
   249  	case *syntax.AssignStmt:
   250  		assignUnit := assignmentToUnit(item)
   251  		builder.ensureBlock()
   252  		builder.put(assignUnit)
   253  
   254  	case *syntax.BranchStmt:
   255  		// TODO(dustmop): support other operations, like `continue`
   256  		builder.ensureBlock()
   257  		builder.put(&unit{atom: "[break]"})
   258  		// TODO(dustmop): have the builder track the inner-most loop
   259  		// being built, have this edge point to the end of that loop
   260  		builder.addEdges(builder.dangling, []int{-2})
   261  		builder.dangling = nil
   262  
   263  	case *syntax.DefStmt:
   264  		// TODO(dustmop): inner functions need to be supported
   265  
   266  	case *syntax.ExprStmt:
   267  		builder.ensureBlock()
   268  		builder.put(exprStatementToUnit(item))
   269  
   270  	case *syntax.ForStmt:
   271  		startPos := builder.refNext()
   272  		builder.makeBlock()
   273  
   274  		// condition for loop
   275  		checkUnit := &unit{atom: "for"}
   276  		checkUnit.tail = []*unit{exprToUnit(item.Vars), exprToUnit(item.X)}
   277  		builder.put(checkUnit)
   278  		loopEntry := builder.finish()
   279  
   280  		bodyPos := builder.refNext()
   281  		loopLeave := builder.buildSubGraph(item.Body)
   282  
   283  		// add edges to create the loop flow
   284  		builder.addEdges(loopEntry, []int{bodyPos})
   285  		builder.addEdges(loopLeave, []int{startPos})
   286  		builder.dangling = []int{startPos}
   287  
   288  	case *syntax.WhileStmt:
   289  		// NOTE: analyzer does not support while loops
   290  
   291  	case *syntax.IfStmt:
   292  		builder.makeBlock()
   293  
   294  		// condition of the if
   295  		condLine := &unit{atom: "if"}
   296  		condLine.tail = append(condLine.tail, exprToUnit(item.Cond))
   297  		builder.put(condLine)
   298  		branchEntry := builder.finish()
   299  
   300  		truePos := builder.refNext()
   301  		exitTrue := builder.buildSubGraph(item.True)
   302  
   303  		falsePos := builder.refNext()
   304  		exitFalse := builder.buildSubGraph(item.False)
   305  
   306  		if len(exitFalse) > 0 {
   307  			builder.addEdges(branchEntry, []int{truePos, falsePos})
   308  			builder.dangling = append(exitTrue, exitFalse...)
   309  		} else {
   310  			builder.addEdges(branchEntry, []int{truePos})
   311  			builder.dangling = append(exitTrue, branchEntry...)
   312  		}
   313  		builder.wantJoin(branchEntry)
   314  
   315  	case *syntax.LoadStmt:
   316  		// nothing to do
   317  
   318  	case *syntax.ReturnStmt:
   319  		builder.ensureBlock()
   320  		retLine := &unit{atom: "return"}
   321  		retLine.tail = []*unit{exprToUnit(item.Result)}
   322  		builder.put(retLine)
   323  		builder.addEdges(builder.dangling, []int{-1})
   324  		builder.dangling = nil
   325  
   326  	}
   327  }
   328  
   329  func assignmentToUnit(assign *syntax.AssignStmt) *unit {
   330  	// left hand side of assignment
   331  	lhs := ""
   332  	if ident, ok := assign.LHS.(*syntax.Ident); ok {
   333  		lhs = ident.Name
   334  	} else {
   335  		lhs = fmt.Sprintf("TODO:%T", assign.LHS)
   336  	}
   337  
   338  	// right hand side of assignment
   339  	rhs := &unit{}
   340  	if ident, ok := assign.RHS.(*syntax.Ident); ok {
   341  		rhs.Push(ident.Name)
   342  	} else if val, ok := assign.RHS.(*syntax.Literal); ok {
   343  		rhs.Push(val.Raw)
   344  	} else if binExp, ok := assign.RHS.(*syntax.BinaryExpr); ok {
   345  		tree := binaryOpToUnit(binExp)
   346  		rhs.tail = append(rhs.tail, tree)
   347  	} else if _, ok := assign.RHS.(*syntax.CallExpr); ok {
   348  		unit := exprToUnit(assign.RHS)
   349  		rhs.tail = append(rhs.tail, unit)
   350  	} else {
   351  		rhs.Push(fmt.Sprintf("TODO:%T", assign.RHS))
   352  	}
   353  
   354  	result := &unit{atom: "set!", where: getWhere(assign)}
   355  	result.Push(lhs)
   356  	if assign.Op == syntax.EQ {
   357  		result.tail = append(result.tail, rhs.tail...)
   358  	} else {
   359  		result.tail = buildAssignOp(syntaxOpToString(assign.Op), lhs, rhs.tail)
   360  	}
   361  	return result
   362  }
   363  
   364  func syntaxOpToString(op syntax.Token) string {
   365  	switch op {
   366  	case syntax.EQ:
   367  		return "="
   368  	case syntax.PLUS_EQ:
   369  		return "+="
   370  	case syntax.MINUS_EQ:
   371  		return "-="
   372  	case syntax.STAR_EQ:
   373  		return "*="
   374  	case syntax.SLASH_EQ:
   375  		return "/="
   376  	}
   377  	return "?"
   378  }
   379  
   380  func buildAssignOp(opText, ident string, expr []*unit) []*unit {
   381  	prev := []*unit{&unit{atom: ident}}
   382  	return append(prev, &unit{atom: opText, tail: append([]*unit{&unit{atom: ident}}, expr...)})
   383  }
   384  
   385  func exprStatementToUnit(expr *syntax.ExprStmt) *unit {
   386  	e := expr.X
   387  	switch item := e.(type) {
   388  	case *syntax.BinaryExpr:
   389  		return toUnitTODO("binary()")
   390  
   391  	case *syntax.CallExpr:
   392  		fn := item.Fn
   393  		callName := simpleExprToFuncName(fn)
   394  		tail := []*unit{}
   395  		for _, e := range item.Args {
   396  			// TODO: exprToUnit(e).String() shouldn't collapse to string
   397  			tail = append(tail, &unit{atom: exprToUnit(e).String()})
   398  		}
   399  		return &unit{atom: callName, tail: tail, where: getWhere(expr)}
   400  
   401  	case *syntax.Comprehension:
   402  		return toUnitTODO("comp()")
   403  
   404  	case *syntax.CondExpr:
   405  		return toUnitTODO("cond()")
   406  
   407  	case *syntax.DictEntry:
   408  		return toUnitTODO("dictEntry()")
   409  
   410  	case *syntax.DictExpr:
   411  		return toUnitTODO("dict()")
   412  
   413  	case *syntax.DotExpr:
   414  		return toUnitTODO("dot()")
   415  
   416  	case *syntax.Ident:
   417  		return toUnitTODO("%s()")
   418  
   419  	case *syntax.IndexExpr:
   420  		return toUnitTODO("index()")
   421  
   422  	case *syntax.LambdaExpr:
   423  		return toUnitTODO("lambda()")
   424  
   425  	case *syntax.ListExpr:
   426  		return toUnitTODO("list()")
   427  
   428  	case *syntax.Literal:
   429  		return toUnitTODO("literal()")
   430  
   431  	case *syntax.ParenExpr:
   432  		return toUnitTODO("paren()")
   433  
   434  	case *syntax.SliceExpr:
   435  		return toUnitTODO("slice()")
   436  
   437  	case *syntax.TupleExpr:
   438  		return toUnitTODO("tuple()")
   439  
   440  	case *syntax.UnaryExpr:
   441  		return toUnitTODO("unary()")
   442  
   443  	}
   444  	return toUnitTODO("????()")
   445  }
   446  
   447  func binaryOpToUnit(binExp *syntax.BinaryExpr) *unit {
   448  	res := &unit{}
   449  	res.atom = binExp.Op.String()
   450  	res.tail = []*unit{exprToUnit(binExp.X), exprToUnit(binExp.Y)}
   451  	return res
   452  }
   453  
   454  func exprToUnit(expr syntax.Expr) *unit {
   455  	switch item := expr.(type) {
   456  	case *syntax.BinaryExpr:
   457  		return binaryOpToUnit(item)
   458  
   459  	case *syntax.CallExpr:
   460  		fn := item.Fn
   461  		callName := simpleExprToFuncName(fn)
   462  		tail := []*unit{}
   463  		for _, e := range item.Args {
   464  			// TODO: exprToUnit(e).String() shouldn't collapse to string
   465  			tail = append(tail, &unit{atom: exprToUnit(e).String()})
   466  		}
   467  		return &unit{atom: callName, tail: tail, where: getWhere(expr)}
   468  
   469  	case *syntax.Comprehension:
   470  		return toUnitTODO("{comprehension}")
   471  
   472  	case *syntax.CondExpr:
   473  		return toUnitTODO("{condExpr}")
   474  
   475  	case *syntax.DictEntry:
   476  		return toUnitTODO("{dictEntry}")
   477  
   478  	case *syntax.DictExpr:
   479  		return toUnitTODO("{dictExpr}")
   480  
   481  	case *syntax.DotExpr:
   482  		return toUnitTODO("{dotExpr}")
   483  
   484  	case *syntax.Ident:
   485  		return &unit{atom: item.Name}
   486  
   487  	case *syntax.IndexExpr:
   488  		return toUnitTODO("{indexExpr}")
   489  
   490  	case *syntax.LambdaExpr:
   491  		return toUnitTODO("{lambdaExpr}")
   492  
   493  	case *syntax.ListExpr:
   494  		return toUnitTODO("{listExpr}")
   495  
   496  	case *syntax.Literal:
   497  		return &unit{atom: item.Raw}
   498  
   499  	case *syntax.ParenExpr:
   500  		return toUnitTODO("{parenExpr}")
   501  
   502  	case *syntax.SliceExpr:
   503  		return toUnitTODO("{sliceExpr}")
   504  
   505  	case *syntax.TupleExpr:
   506  		return toUnitTODO("{tupleExpr}")
   507  
   508  	case *syntax.UnaryExpr:
   509  		return toUnitTODO("{unaryExpr}")
   510  
   511  	default:
   512  		return toUnitTODO("{unknown}")
   513  	}
   514  }
   515  
   516  func getWhere(n syntax.Node) syntax.Position {
   517  	start, _ := n.Span()
   518  	return start
   519  }