
     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     5  package cfg
     7  // This file implements the CFG construction pass.
     9  import (
    10  	"fmt"
    11  	"go/ast"
    12  	"go/token"
    13  )
// builder holds the state of an in-progress CFG construction.
//
// Fields:
//   - cfg: the graph being built; newBlock appends every new block to
//     cfg.Blocks.
//   - mayReturn: reports whether a call may return. When it reports false
//     for the call of an expression statement, stmt starts a fresh
//     unreachable block after it.
//   - current: the block that add appends nodes to; jump and ifelse reset
//     it to nil after emitting the block's outgoing edges.
//   - lblocks: per-label branch destinations, created lazily by
//     labeledBlock.
//   - targets: stack of break/continue/fallthrough destinations for the
//     enclosing unlabeled for/range/switch/typeswitch/select statements.
type builder struct {
	cfg       *CFG
	mayReturn func(*ast.CallExpr) bool
	current   *Block
	lblocks   map[*ast.Object]*lblock // labeled blocks
	targets   *targets                // linked stack of branch targets
}
    23  func (b *builder) stmt(_s ast.Stmt) {
    24  	// The label of the current statement.  If non-nil, its _goto
    25  	// target is always set; its _break and _continue are set only
    26  	// within the body of switch/typeswitch/select/for/range.
    27  	// It is effectively an additional default-nil parameter of stmt().
    28  	var label *lblock
    29  start:
    30  	switch s := _s.(type) {
    31  	case *ast.BadStmt,
    32  		*ast.SendStmt,
    33  		*ast.IncDecStmt,
    34  		*ast.GoStmt,
    35  		*ast.DeferStmt,
    36  		*ast.EmptyStmt,
    37  		*ast.AssignStmt:
    38  		// No effect on control flow.
    39  		b.add(s)
    41  	case *ast.ExprStmt:
    42  		b.add(s)
    43  		if call, ok := s.X.(*ast.CallExpr); ok && !b.mayReturn(call) {
    44  			// Calls to panic, os.Exit, etc, never return.
    45  			b.current = b.newUnreachableBlock("")
    46  		}
    48  	case *ast.DeclStmt:
    49  		// Treat each var ValueSpec as a separate statement.
    50  		d := s.Decl.(*ast.GenDecl)
    51  		if d.Tok == token.VAR {
    52  			for _, spec := range d.Specs {
    53  				if spec, ok := spec.(*ast.ValueSpec); ok {
    54  					b.add(spec)
    55  				}
    56  			}
    57  		}
    59  	case *ast.LabeledStmt:
    60  		label = b.labeledBlock(s.Label)
    61  		b.jump(label._goto)
    62  		b.current = label._goto
    63  		_s = s.Stmt
    64  		goto start // effectively: tailcall stmt(g, s.Stmt, label)
    66  	case *ast.ReturnStmt:
    67  		b.add(s)
    68  		b.current = b.newUnreachableBlock("unreachable.return")
    70  	case *ast.BranchStmt:
    71  		var block *Block
    72  		switch s.Tok {
    73  		case token.BREAK:
    74  			if s.Label != nil {
    75  				if lb := b.labeledBlock(s.Label); lb != nil {
    76  					block = lb._break
    77  				}
    78  			} else {
    79  				for t := b.targets; t != nil && block == nil; t = t.tail {
    80  					block = t._break
    81  				}
    82  			}
    84  		case token.CONTINUE:
    85  			if s.Label != nil {
    86  				if lb := b.labeledBlock(s.Label); lb != nil {
    87  					block = lb._continue
    88  				}
    89  			} else {
    90  				for t := b.targets; t != nil && block == nil; t = t.tail {
    91  					block = t._continue
    92  				}
    93  			}
    95  		case token.FALLTHROUGH:
    96  			for t := b.targets; t != nil; t = t.tail {
    97  				block = t._fallthrough
    98  			}
   100  		case token.GOTO:
   101  			if s.Label != nil {
   102  				block = b.labeledBlock(s.Label)._goto
   103  			}
   104  		}
   105  		if block == nil {
   106  			block = b.newBlock("undefined.branch")
   107  		}
   108  		b.jump(block)
   109  		b.current = b.newUnreachableBlock("unreachable.branch")
   111  	case *ast.BlockStmt:
   112  		b.stmtList(s.List)
   114  	case *ast.IfStmt:
   115  		if s.Init != nil {
   116  			b.stmt(s.Init)
   117  		}
   118  		then := b.newBlock("if.then")
   119  		done := b.newBlock("if.done")
   120  		_else := done
   121  		if s.Else != nil {
   122  			_else = b.newBlock("if.else")
   123  		}
   124  		b.add(s.Cond)
   125  		b.ifelse(then, _else)
   126  		b.current = then
   127  		b.stmt(s.Body)
   128  		b.jump(done)
   130  		if s.Else != nil {
   131  			b.current = _else
   132  			b.stmt(s.Else)
   133  			b.jump(done)
   134  		}
   136  		b.current = done
   138  	case *ast.SwitchStmt:
   139  		b.switchStmt(s, label)
   141  	case *ast.TypeSwitchStmt:
   142  		b.typeSwitchStmt(s, label)
   144  	case *ast.SelectStmt:
   145  		b.selectStmt(s, label)
   147  	case *ast.ForStmt:
   148  		b.forStmt(s, label)
   150  	case *ast.RangeStmt:
   151  		b.rangeStmt(s, label)
   153  	default:
   154  		panic(fmt.Sprintf("unexpected statement kind: %T", s))
   155  	}
   156  }
   158  func (b *builder) stmtList(list []ast.Stmt) {
   159  	for _, s := range list {
   160  		b.stmt(s)
   161  	}
   162  }
   164  func (b *builder) switchStmt(s *ast.SwitchStmt, label *lblock) {
   165  	if s.Init != nil {
   166  		b.stmt(s.Init)
   167  	}
   168  	if s.Tag != nil {
   169  		b.add(s.Tag)
   170  	}
   171  	done := b.newBlock("switch.done")
   172  	if label != nil {
   173  		label._break = done
   174  	}
   175  	// We pull the default case (if present) down to the end.
   176  	// But each fallthrough label must point to the next
   177  	// body block in source order, so we preallocate a
   178  	// body block (fallthru) for the next case.
   179  	// Unfortunately this makes for a confusing block order.
   180  	var defaultBody *[]ast.Stmt
   181  	var defaultFallthrough *Block
   182  	var fallthru, defaultBlock *Block
   183  	ncases := len(s.Body.List)
   184  	for i, clause := range s.Body.List {
   185  		body := fallthru
   186  		if body == nil {
   187  			body = b.newBlock("switch.body") // first case only
   188  		}
   190  		// Preallocate body block for the next case.
   191  		fallthru = done
   192  		if i+1 < ncases {
   193  			fallthru = b.newBlock("switch.body")
   194  		}
   196  		cc := clause.(*ast.CaseClause)
   197  		if cc.List == nil {
   198  			// Default case.
   199  			defaultBody = &cc.Body
   200  			defaultFallthrough = fallthru
   201  			defaultBlock = body
   202  			continue
   203  		}
   205  		var nextCond *Block
   206  		for _, cond := range cc.List {
   207  			nextCond = b.newBlock("")
   208  			b.add(cond) // one half of the tag==cond condition
   209  			b.ifelse(body, nextCond)
   210  			b.current = nextCond
   211  		}
   212  		b.current = body
   213  		b.targets = &targets{
   214  			tail:         b.targets,
   215  			_break:       done,
   216  			_fallthrough: fallthru,
   217  		}
   218  		b.stmtList(cc.Body)
   219  		b.targets = b.targets.tail
   220  		b.jump(done)
   221  		b.current = nextCond
   222  	}
   223  	if defaultBlock != nil {
   224  		b.jump(defaultBlock)
   225  		b.current = defaultBlock
   226  		b.targets = &targets{
   227  			tail:         b.targets,
   228  			_break:       done,
   229  			_fallthrough: defaultFallthrough,
   230  		}
   231  		b.stmtList(*defaultBody)
   232  		b.targets = b.targets.tail
   233  	}
   234  	b.jump(done)
   235  	b.current = done
   236  }
   238  func (b *builder) typeSwitchStmt(s *ast.TypeSwitchStmt, label *lblock) {
   239  	if s.Init != nil {
   240  		b.stmt(s.Init)
   241  	}
   242  	if s.Assign != nil {
   243  		b.add(s.Assign)
   244  	}
   246  	done := b.newBlock("typeswitch.done")
   247  	if label != nil {
   248  		label._break = done
   249  	}
   250  	var default_ *ast.CaseClause
   251  	for _, clause := range s.Body.List {
   252  		cc := clause.(*ast.CaseClause)
   253  		if cc.List == nil {
   254  			default_ = cc
   255  			continue
   256  		}
   257  		body := b.newBlock("typeswitch.body")
   258  		var next *Block
   259  		for _, casetype := range cc.List {
   260  			next = b.newBlock("")
   261  			// casetype is a type, so don't call b.add(casetype).
   262  			// This block logically contains a type assertion,
   263  			// x.(casetype), but it's unclear how to represent x.
   264  			_ = casetype
   265  			b.ifelse(body, next)
   266  			b.current = next
   267  		}
   268  		b.current = body
   269  		b.typeCaseBody(cc, done)
   270  		b.current = next
   271  	}
   272  	if default_ != nil {
   273  		b.typeCaseBody(default_, done)
   274  	} else {
   275  		b.jump(done)
   276  	}
   277  	b.current = done
   278  }
   280  func (b *builder) typeCaseBody(cc *ast.CaseClause, done *Block) {
   281  	b.targets = &targets{
   282  		tail:   b.targets,
   283  		_break: done,
   284  	}
   285  	b.stmtList(cc.Body)
   286  	b.targets = b.targets.tail
   287  	b.jump(done)
   288  }
   290  func (b *builder) selectStmt(s *ast.SelectStmt, label *lblock) {
   291  	// First evaluate channel expressions.
   292  	// TODO(adonovan): fix: evaluate only channel exprs here.
   293  	for _, clause := range s.Body.List {
   294  		if comm := clause.(*ast.CommClause).Comm; comm != nil {
   295  			b.stmt(comm)
   296  		}
   297  	}
   299  	done := b.newBlock("select.done")
   300  	if label != nil {
   301  		label._break = done
   302  	}
   304  	var defaultBody *[]ast.Stmt
   305  	for _, cc := range s.Body.List {
   306  		clause := cc.(*ast.CommClause)
   307  		if clause.Comm == nil {
   308  			defaultBody = &clause.Body
   309  			continue
   310  		}
   311  		body := b.newBlock("select.body")
   312  		next := b.newBlock("")
   313  		b.ifelse(body, next)
   314  		b.current = body
   315  		b.targets = &targets{
   316  			tail:   b.targets,
   317  			_break: done,
   318  		}
   319  		switch comm := clause.Comm.(type) {
   320  		case *ast.ExprStmt: // <-ch
   321  			// nop
   322  		case *ast.AssignStmt: // x := <-states[state].Chan
   323  			b.add(comm.Lhs[0])
   324  		}
   325  		b.stmtList(clause.Body)
   326  		b.targets = b.targets.tail
   327  		b.jump(done)
   328  		b.current = next
   329  	}
   330  	if defaultBody != nil {
   331  		b.targets = &targets{
   332  			tail:   b.targets,
   333  			_break: done,
   334  		}
   335  		b.stmtList(*defaultBody)
   336  		b.targets = b.targets.tail
   337  		b.jump(done)
   338  	}
   339  	b.current = done
   340  }
   342  func (b *builder) forStmt(s *ast.ForStmt, label *lblock) {
   343  	//	...init...
   344  	//      jump loop
   345  	// loop:
   346  	//      if cond goto body else done
   347  	// body:
   348  	//      ...body...
   349  	//      jump post
   350  	// post:				 (target of continue)
   351  	//
   352  	//      jump loop
   353  	// done:                                 (target of break)
   354  	if s.Init != nil {
   355  		b.stmt(s.Init)
   356  	}
   357  	body := b.newBlock("for.body")
   358  	done := b.newBlock("for.done") // target of 'break'
   359  	loop := body                   // target of back-edge
   360  	if s.Cond != nil {
   361  		loop = b.newBlock("for.loop")
   362  	}
   363  	cont := loop // target of 'continue'
   364  	if s.Post != nil {
   365  		cont = b.newBlock("")
   366  	}
   367  	if label != nil {
   368  		label._break = done
   369  		label._continue = cont
   370  	}
   371  	b.jump(loop)
   372  	b.current = loop
   373  	if loop != body {
   374  		b.add(s.Cond)
   375  		b.ifelse(body, done)
   376  		b.current = body
   377  	}
   378  	b.targets = &targets{
   379  		tail:      b.targets,
   380  		_break:    done,
   381  		_continue: cont,
   382  	}
   383  	b.stmt(s.Body)
   384  	b.targets = b.targets.tail
   385  	b.jump(cont)
   387  	if s.Post != nil {
   388  		b.current = cont
   389  		b.stmt(s.Post)
   390  		b.jump(loop) // back-edge
   391  	}
   392  	b.current = done
   393  }
   395  func (b *builder) rangeStmt(s *ast.RangeStmt, label *lblock) {
   396  	b.add(s.X)
   398  	if s.Key != nil {
   399  		b.add(s.Key)
   400  	}
   401  	if s.Value != nil {
   402  		b.add(s.Value)
   403  	}
   405  	//      ...
   406  	// loop:                                   (target of continue)
   407  	// 	if ... goto body else done
   408  	// body:
   409  	//      ...
   410  	// 	jump loop
   411  	// done:                                   (target of break)
   413  	loop := b.newBlock("range.loop")
   414  	b.jump(loop)
   415  	b.current = loop
   417  	body := b.newBlock("range.body")
   418  	done := b.newBlock("range.done")
   419  	b.ifelse(body, done)
   420  	b.current = body
   422  	if label != nil {
   423  		label._break = done
   424  		label._continue = loop
   425  	}
   426  	b.targets = &targets{
   427  		tail:      b.targets,
   428  		_break:    done,
   429  		_continue: loop,
   430  	}
   431  	b.stmt(s.Body)
   432  	b.targets = b.targets.tail
   433  	b.jump(loop) // back-edge
   434  	b.current = done
   435  }
   437  // -------- helpers --------
// targets holds the branch destinations of one enclosing unlabeled
// for/range/switch/typeswitch/select statement. Entries form a linked
// stack through tail.
//
// The builder pushes an entry on entering such a construct's body and
// pops it on leaving. For each unlabeled break or continue, stmt scans
// from the innermost entry outward and uses the first non-nil
// destination of the matching kind. _fallthrough is set only by
// switchStmt; _continue only by forStmt and rangeStmt.
type targets struct {
	tail         *targets // rest of stack
	_break       *Block
	_continue    *Block
	_fallthrough *Block
}
// lblock holds the branch destinations associated with one label.
//
// _goto is created by labeledBlock the first time the label is seen. That
// may be at the labeled statement itself or at an earlier (forward)
// goto, break or continue that names it.
//
// _break is filled in only when the labeled statement is a switch,
// typeswitch, select, for or range. _continue is filled in only for for
// and range loops. Otherwise they remain nil, and a labeled break or
// continue falls back to an "undefined.branch" block in stmt.
type lblock struct {
	_goto     *Block
	_break    *Block
	_continue *Block
}
   460  // labeledBlock returns the branch target associated with the
   461  // specified label, creating it if needed.
   462  //
   463  func (b *builder) labeledBlock(label *ast.Ident) *lblock {
   464  	lb := b.lblocks[label.Obj]
   465  	if lb == nil {
   466  		lb = &lblock{_goto: b.newBlock(label.Name)}
   467  		if b.lblocks == nil {
   468  			b.lblocks = make(map[*ast.Object]*lblock)
   469  		}
   470  		b.lblocks[label.Obj] = lb
   471  	}
   472  	return lb
   473  }
   475  // newBlock appends a new unconnected basic block to b.cfg's block
   476  // slice and returns it.
   477  // It does not automatically become the current block.
   478  // comment is an optional string for more readable debugging output.
   479  func (b *builder) newBlock(comment string) *Block {
   480  	g := b.cfg
   481  	block := &Block{
   482  		index:   int32(len(g.Blocks)),
   483  		comment: comment,
   484  	}
   485  	block.Succs = block.succs2[:0]
   486  	g.Blocks = append(g.Blocks, block)
   487  	return block
   488  }
   490  func (b *builder) newUnreachableBlock(comment string) *Block {
   491  	block := b.newBlock(comment)
   492  	block.unreachable = true
   493  	return block
   494  }
   496  func (b *builder) add(n ast.Node) {
   497  	b.current.Nodes = append(b.current.Nodes, n)
   498  }
   500  // jump adds an edge from the current block to the target block,
   501  // and sets b.current to nil.
   502  func (b *builder) jump(target *Block) {
   503  	b.current.Succs = append(b.current.Succs, target)
   504  	b.current = nil
   505  }
   507  // ifelse emits edges from the current block to the t and f blocks,
   508  // and sets b.current to nil.
   509  func (b *builder) ifelse(t, f *Block) {
   510  	b.current.Succs = append(b.current.Succs, t, f)
   511  	b.current = nil
   512  }