github.com/gotranspile/cxgo@v0.3.7/flow.go (about)

     1  package cxgo
     2  
     3  import (
     4  	"fmt"
     5  )
     6  
     7  type Block interface {
     8  	AddPrevBlock(b2 Block)
     9  	PrevBlocks() []Block
    10  	NextBlocks() []Block
    11  	ReplaceNext(old, rep Block)
    12  }
    13  
    14  func replacePrev(v Block, p, to Block) {
    15  	if v == nil {
    16  		return
    17  	}
    18  	prev := v.PrevBlocks()
    19  	found := false
    20  	for i, p2 := range prev {
    21  		if p == p2 {
    22  			prev[i] = to
    23  			found = true
    24  		}
    25  	}
    26  	if !found {
    27  		panic("not found")
    28  	}
    29  }
    30  
    31  func replaceBlock(a, b Block) {
    32  	if a == b {
    33  		return
    34  	}
    35  	for _, p := range a.PrevBlocks() {
    36  		if p == a {
    37  			b.AddPrevBlock(b)
    38  		} else {
    39  			b.AddPrevBlock(p)
    40  			p.ReplaceNext(a, b)
    41  		}
    42  	}
    43  }
    44  
    45  func (g *translator) NewControlFlow(stmts []CStmt) *ControlFlow {
    46  	cf := &ControlFlow{
    47  		g:      g,
    48  		labels: make(map[string]Block),
    49  	}
    50  	cf.Start, _ = cf.process(stmts, nil)
    51  	cf.buildDoms()
    52  	return cf
    53  }
    54  
// ControlFlow is the control flow graph of a statement list. It is built by
// translator.NewControlFlow and lowered back to statements by Flatten.
type ControlFlow struct {
	// g is the translator that created this graph.
	g      *translator
	// Start is the entry block of the graph.
	Start  Block
	// labels maps C label names to their target blocks while the graph is
	// being built; an entry may be an empty placeholder created by a goto
	// whose label has not been processed yet.
	labels map[string]Block
	// breaks is a stack of break targets of the enclosing loops/switches.
	breaks []Block
	// conts is a stack of continue targets of the enclosing loops.
	conts  []Block
	// doms maps every block reachable from Start to the set of blocks that
	// dominate it, the block itself included (filled by buildDoms).
	doms   map[Block]BlockSet
}
    63  
    64  func (cf *ControlFlow) eachBlock(fnc func(b Block)) {
    65  	cf.eachBlockSub(cf.Start, fnc, make(BlockSet))
    66  }
    67  
    68  func (cf *ControlFlow) allBlocks() BlockSet {
    69  	m := make(BlockSet)
    70  	cf.eachBlockSub(cf.Start, nil, m)
    71  	return m
    72  }
    73  
    74  func (cf *ControlFlow) eachBlockSub(b Block, fnc func(b Block), seen BlockSet) {
    75  	if b == nil {
    76  		return
    77  	}
    78  	if _, ok := seen[b]; ok {
    79  		return
    80  	}
    81  	seen[b] = struct{}{}
    82  	if fnc != nil {
    83  		fnc(b)
    84  	}
    85  	for _, b2 := range b.NextBlocks() {
    86  		cf.eachBlockSub(b2, fnc, seen)
    87  	}
    88  }
    89  
// process builds the control flow graph for stmts and returns its entry block.
//
// after is the block control continues to once stmts are exhausted. It is nil
// at the top level (NewControlFlow passes nil); in that case a plain statement
// at the end falls into a freshly created empty ReturnBlock. An empty stmts
// returns after itself.
//
// Statements are handled back to front: stmts[1:] is processed before the
// block for stmts[0] is linked, so each statement already knows its successor.
//
// The bool result is true only when the returned block is a *CodeBlock made of
// plain statements (possibly reached through a nested *BlockStmt). A plain
// statement preceding it merges such a block into its own CodeBlock instead of
// linking to it. All other statement kinds return false.
//
// NOTE(review): with after == nil and the statement last in the list, the
// CIfStmt case without an else (and the CForStmt case with a condition) pass a
// nil target to NewCondBlock, which panics with "both branches must be set" —
// presumably callers guarantee a trailing statement; confirm.
func (cf *ControlFlow) process(stmts []CStmt, after Block) (Block, bool) {
	if len(stmts) == 0 {
		return after, false
	}
	switch s := stmts[0].(type) {
	case *BlockStmt:
		// a nested block is inlined: its statements continue into the rest
		// of this list, or directly into after when it is the last statement
		if len(stmts) == 1 {
			return cf.process(s.Stmts, after)
		}
		next, _ := cf.process(stmts[1:], after)
		return cf.process(s.Stmts, next)
	case *CReturnStmt:
		// statements after a return are unreachable; they are still processed
		// (their blocks may register themselves as predecessors of other
		// blocks), but their entry block is discarded
		_, _ = cf.process(stmts[1:], after)
		return &ReturnBlock{
			CReturnStmt: s,
		}, false
	case *CLabelStmt:
		// a label names the block of the statements that follow it
		name := s.Label
		next, _ := cf.process(stmts[1:], after)
		if next == nil {
			panic("must not be nil")
		}
		if tmp := cf.labels[name]; tmp != nil {
			// a goto to this label was processed before the label itself
			// (e.g. a backward goto, since later statements are processed
			// first) and left a placeholder; move its incoming edges here
			replaceBlock(tmp, next)
		}
		cf.labels[name] = next
		return next, false
	case *CGotoStmt:
		// the rest of the list is unreachable but still processed
		_, _ = cf.process(stmts[1:], after)
		name := s.Label
		b, ok := cf.labels[name]
		if !ok {
			// label not processed yet: jump to an empty placeholder that the
			// CLabelStmt case replaces once the label is reached
			b = &CodeBlock{}
			cf.labels[name] = b
		}
		return b, false
	case *CContinueStmt:
		// jump to the innermost continue target; the rest is unreachable
		_, _ = cf.process(stmts[1:], after)
		return cf.conts[len(cf.conts)-1], false
	case *CBreakStmt:
		// jump to the innermost break target; the rest is unreachable
		_, _ = cf.process(stmts[1:], after)
		return cf.breaks[len(cf.breaks)-1], false
	//case *CFallthroughStmt:
	//	_, _ = cf.process(stmts[1:], after)
	//	return cf.falls[len(cf.falls)-1], false
	case *CIfStmt:
		// both branches continue with the statements after the if;
		// a missing else branch goes straight there
		next, _ := cf.process(stmts[1:], after)

		then, _ := cf.process(s.Then.Stmts, next)
		els := next
		if s.Else != nil {
			els, _ = cf.process([]CStmt{s.Else}, next)
		}
		return NewCondBlock(s.Cond, then, els), false
	case *CForStmt:
		// resulting layout: [init] -> loop -> body ... -> cont -> loop, where
		// loop is a CondBlock on the negated condition (exiting to brk) or the
		// body itself when there is no condition
		// continue target, temporary
		cont := &CodeBlock{}
		// break target
		brk, _ := cf.process(stmts[1:], after)

		// put break/continue blocks to the stack
		bi := len(cf.breaks)
		ci := len(cf.conts)
		cf.breaks = append(cf.breaks, brk)
		cf.conts = append(cf.conts, cont)
		// process the body assuming those break/continue blocks
		body, _ := cf.process(s.Body.Stmts, cont)
		// restore the break/continue stack
		cf.breaks = cf.breaks[:bi]
		cf.conts = cf.conts[:ci]
		// if loop is empty - set an empty body
		if body == nil {
			body = &CodeBlock{Next: cont}
			cont.AddPrevBlock(body)
		}

		// loop with no condition should return to the beginning of the body
		loop := body
		if s.Cond != nil {
			// loop with the condition returns to it instead of the body
			// we invert the expression and break in the positive if branch
			loop = NewCondBlock(cf.g.cNot(s.Cond), brk, body)
		}

		// loop without an iter should continue to the beginning of the loop body
		cont.Next = loop
		if s.Iter != nil {
			// inline iter into the temporary continue block
			// do not remove the temporary block - it's now permanent
			cont.Stmts = append(cont.Stmts, s.Iter)
			loop.AddPrevBlock(cont)
		} else {
			if loop == cont {
				// fix for an empty loop body
				loop.AddPrevBlock(loop)
			}
			// no iter: the temporary continue block is bypassed, its
			// incoming edges go straight to the loop head
			replaceBlock(cont, loop)
		}

		if s.Init == nil {
			// no init - start from the loop cond/body
			return loop, false
		}
		// start from the init, then continue to the cond/body
		in := &CodeBlock{Stmts: []CStmt{s.Init}}
		in.Next = loop
		loop.AddPrevBlock(in)
		return in, false
	case *CCaseStmt:
		panic("must not contain cases")
	case *CSwitchStmt:
		// TODO: add fallthrough statement and rewrite this
		next, _ := cf.process(stmts[1:], after)

		b := &SwitchBlock{
			Expr:   s.Cond,
			Cases:  make([]Expr, len(s.Cases)),
			Blocks: make([]Block, len(s.Cases)),
		}
		bi := len(cf.breaks)
		cf.breaks = append(cf.breaks, next)
		// process backward because we need to handle fallthrough: each case
		// continues into the block of the case below it, the last one into next
		hasDef := false
		fall := next
		for i := len(s.Cases) - 1; i >= 0; i-- {
			c := s.Cases[i]
			if c.Expr == nil {
				hasDef = true
			}
			b.Cases[i] = c.Expr
			cb, _ := cf.process(c.Stmts, fall)
			b.Blocks[i] = cb
			cb.AddPrevBlock(b)
			fall = cb
		}
		cf.breaks = cf.breaks[:bi]
		if !hasDef {
			// no default case: add an implicit one (nil expression) that
			// jumps to the statements after the switch
			b.Cases = append(b.Cases, nil)
			b.Blocks = append(b.Blocks, next)
			next.AddPrevBlock(b)
		}
		return b, false
	}
	// any other statement is a plain statement placed in a CodeBlock; at the
	// end of a top-level list control falls into an empty return
	if after == nil {
		after = &ReturnBlock{CReturnStmt: &CReturnStmt{}}
	}
	b := &CodeBlock{}
	b.Stmts = append(b.Stmts, stmts[0])
	b2, merge := cf.process(stmts[1:], after)
	if c2, ok := b2.(*CodeBlock); ok && merge {
		// the following plain statements produced a fresh CodeBlock: absorb
		// its statements and successor and take its place in the successor's
		// predecessor list (c2 itself is dropped)
		b.Stmts = append(b.Stmts, c2.Stmts...)
		b.Next = c2.Next
		if b.Next != nil {
			replacePrev(b.Next, c2, b)
		} else {
			b.Next = after
			after.AddPrevBlock(b)
		}
		return b, true
	}
	if b2 == nil {
		b2 = after
	}
	b.Next = b2
	b.Next.AddPrevBlock(b)
	return b, true
}
   257  
   258  type BaseBlock struct {
   259  	prev []Block
   260  }
   261  
   262  func (b *BaseBlock) AddPrevBlock(b2 Block) {
   263  	for _, b := range b.prev {
   264  		if b == b2 {
   265  			return
   266  		}
   267  	}
   268  	b.prev = append(b.prev, b2)
   269  }
   270  func (b *BaseBlock) PrevBlocks() []Block {
   271  	return b.prev
   272  }
   273  
   274  type CodeBlock struct {
   275  	BaseBlock
   276  	Stmts []CStmt
   277  	Next  Block
   278  }
   279  
   280  func (b *CodeBlock) NextBlocks() []Block {
   281  	if b.Next == nil {
   282  		// FIXME
   283  		//panic("no following block")
   284  		return nil
   285  	}
   286  	return []Block{b.Next}
   287  }
   288  func (b *CodeBlock) ReplaceNext(old, rep Block) {
   289  	if b.Next == old {
   290  		b.Next = rep
   291  	}
   292  }
   293  
   294  func NewCondBlock(expr Expr, then, els Block) *CondBlock {
   295  	if then == nil || els == nil {
   296  		panic("both branches must be set")
   297  	}
   298  	b := &CondBlock{
   299  		Expr: expr,
   300  		Then: then,
   301  		Else: els,
   302  	}
   303  	b.Then.AddPrevBlock(b)
   304  	b.Else.AddPrevBlock(b)
   305  	return b
   306  }
   307  
   308  type CondBlock struct {
   309  	BaseBlock
   310  	Expr Expr
   311  	Then Block
   312  	Else Block
   313  }
   314  
   315  func (b *CondBlock) NextBlocks() []Block {
   316  	if b.Then == nil || b.Else == nil {
   317  		panic("no following block")
   318  	}
   319  	return []Block{b.Then, b.Else}
   320  }
   321  func (b *CondBlock) ReplaceNext(old, rep Block) {
   322  	if b.Then == old {
   323  		b.Then = rep
   324  	}
   325  	if b.Else == old {
   326  		b.Else = rep
   327  	}
   328  }
   329  
   330  type SwitchBlock struct {
   331  	BaseBlock
   332  	Expr   Expr
   333  	Cases  []Expr
   334  	Blocks []Block
   335  }
   336  
   337  func (b *SwitchBlock) NextBlocks() []Block {
   338  	return append([]Block{}, b.Blocks...)
   339  }
   340  func (b *SwitchBlock) ReplaceNext(old, rep Block) {
   341  	for i, p := range b.Blocks {
   342  		if p == old {
   343  			b.Blocks[i] = rep
   344  		}
   345  	}
   346  }
   347  
   348  type ReturnBlock struct {
   349  	BaseBlock
   350  	*CReturnStmt
   351  }
   352  
   353  func (b *ReturnBlock) NextBlocks() []Block {
   354  	return nil
   355  }
   356  func (b *ReturnBlock) ReplaceNext(old, rep Block) {}
   357  
   358  func (cf *ControlFlow) Dom(a, b Block) bool {
   359  	if a == b {
   360  		return true
   361  	}
   362  	_, ok := cf.doms[b][a]
   363  	return ok
   364  }
   365  
   366  func (cf *ControlFlow) SDom(a, b Block) bool {
   367  	if a == b {
   368  		// strict dominance - nodes shouldn't be the same
   369  		return false
   370  	}
   371  	return cf.Dom(a, b)
   372  }
   373  
   374  func (cf *ControlFlow) IDom(a Block) Block {
   375  	var doms []Block
   376  	for d := range cf.doms[a] {
   377  		if d != a {
   378  			doms = append(doms, d)
   379  		}
   380  	}
   381  	for len(doms) > 1 {
   382  	loop:
   383  		for i := 0; i < len(doms); i++ {
   384  			d1 := doms[i]
   385  			for j := 0; j < len(doms); j++ {
   386  				if i == j {
   387  					continue
   388  				}
   389  				d2 := doms[j]
   390  				if _, ok := cf.doms[d2][d1]; ok {
   391  					doms = append(doms[:i], doms[i+1:]...)
   392  					i--
   393  					continue loop
   394  				}
   395  			}
   396  		}
   397  	}
   398  	if len(doms) == 1 {
   399  		return doms[0]
   400  	}
   401  	return nil
   402  }
   403  
   404  type BlockSet map[Block]struct{}
   405  
   406  func (b BlockSet) Clone() BlockSet {
   407  	b2 := make(BlockSet, len(b))
   408  	for k := range b {
   409  		b2[k] = struct{}{}
   410  	}
   411  	return b2
   412  }
   413  
   414  func (b BlockSet) Union(b2 BlockSet) BlockSet {
   415  	if len(b) < len(b2) {
   416  		b, b2 = b2, b
   417  	}
   418  	all := true
   419  	for k := range b2 {
   420  		if _, ok := b[k]; !ok {
   421  			all = false
   422  			break
   423  		}
   424  	}
   425  	if all {
   426  		return b
   427  	}
   428  	m := make(BlockSet, len(b))
   429  	for k := range b {
   430  		m[k] = struct{}{}
   431  	}
   432  	for k := range b2 {
   433  		m[k] = struct{}{}
   434  	}
   435  	return m
   436  }
   437  
   438  func (b BlockSet) Intersect(b2 BlockSet) BlockSet {
   439  	if len(b) > len(b2) {
   440  		b, b2 = b2, b
   441  	}
   442  	m := make(BlockSet, len(b))
   443  	for k := range b {
   444  		if _, ok := b2[k]; ok {
   445  			m[k] = struct{}{}
   446  		}
   447  	}
   448  	return m
   449  }
   450  
   451  func (b BlockSet) Contains(b2 BlockSet) bool {
   452  	if len(b) < len(b2) {
   453  		return false
   454  	}
   455  	for k := range b2 {
   456  		if _, ok := b[k]; !ok {
   457  			return false
   458  		}
   459  	}
   460  	return true
   461  }
   462  
   463  func (cf *ControlFlow) buildDoms() {
   464  	// TODO: quadratic complexity! use a different algorithm
   465  	blocks := cf.allBlocks()
   466  	cf.doms = make(map[Block]BlockSet, len(blocks))
   467  	for b := range blocks {
   468  		if b == cf.Start {
   469  			cf.doms[b] = BlockSet{b: {}}
   470  		} else {
   471  			cf.doms[b] = blocks
   472  		}
   473  	}
   474  
   475  	changes := true
   476  	for changes {
   477  		changes = false
   478  		for b := range blocks {
   479  			d := cf.doms[b]
   480  			if len(d) == 1 {
   481  				continue
   482  			}
   483  			var m BlockSet
   484  			for _, b2 := range b.PrevBlocks() {
   485  				if b == b2 {
   486  					continue
   487  				}
   488  				d2 := cf.doms[b2]
   489  				if m == nil {
   490  					m = d2
   491  					continue
   492  				}
   493  				m = m.Intersect(d2)
   494  				if len(m) == 0 {
   495  					break
   496  				}
   497  			}
   498  			m = m.Union(BlockSet{b: {}})
   499  			if len(d) != len(m) {
   500  				changes = true
   501  				cf.doms[b] = m
   502  			}
   503  			if len(d) < len(m) {
   504  				panic(fmt.Errorf("set size increased: %p: %v vs %v", b, d, m))
   505  			}
   506  		}
   507  	}
   508  }
   509  
// varDecls accumulates the variable declarations that slitDecls splits off
// from their initializers; Flatten emits them ahead of the flattened body.
type varDecls struct {
	Decls []*CVarDecl
}
   513  
   514  func (cf *ControlFlow) Flatten() []CStmt {
   515  	labels := make(map[Block]int)
   516  	cf.eachBlock(func(b Block) {
   517  		prev := b.PrevBlocks()
   518  		if len(prev) == 0 {
   519  			return
   520  		}
   521  		labels[b] = len(labels) + 1
   522  	})
   523  	var decls varDecls
   524  	stmts := cf.flatten(cf.Start, &decls, nil, labels, make(map[Block]struct{}))
   525  	var out []CStmt
   526  	if len(decls.Decls) != 0 {
   527  		for _, d := range decls.Decls {
   528  			out = append(out, &CDeclStmt{Decl: d})
   529  		}
   530  	}
   531  	out = append(out, stmts...)
   532  	return out
   533  }
   534  
   535  func numLabelName(n int) string {
   536  	return fmt.Sprintf("L_%d", n)
   537  }
   538  
   539  func numLabel(n int) *CLabelStmt {
   540  	return &CLabelStmt{Label: numLabelName(n)}
   541  }
   542  
   543  func numGoto(n int) *CGotoStmt {
   544  	return &CGotoStmt{Label: numLabelName(n)}
   545  }
   546  
   547  func (cf *ControlFlow) slitDecls(decl *varDecls, stmts []CStmt) []CStmt {
   548  	out := make([]CStmt, 0, len(stmts))
   549  	for _, st := range stmts {
   550  		ds, ok := st.(*CDeclStmt)
   551  		if !ok {
   552  			out = append(out, st)
   553  			continue
   554  		}
   555  		d, ok := ds.Decl.(*CVarDecl)
   556  		if !ok {
   557  			out = append(out, st)
   558  			continue
   559  		} else if len(d.Inits) == 0 || d.Const {
   560  			out = append(out, st)
   561  			continue
   562  		} else if d.Names[0].Name == "__func__" {
   563  			out = append(out, st)
   564  			continue
   565  		}
   566  		decl.Decls = append(decl.Decls, &CVarDecl{
   567  			Const:  d.Const,
   568  			Single: d.Single,
   569  			CVarSpec: CVarSpec{
   570  				g:     d.g,
   571  				Type:  d.Type,
   572  				Names: d.Names,
   573  			},
   574  		})
   575  		for j, val := range d.Inits {
   576  			if val != nil {
   577  				out = append(out, d.g.NewCAssignStmt(IdentExpr{d.Names[j]}, "", val)...)
   578  			}
   579  		}
   580  	}
   581  	return out
   582  }
   583  
   584  func (cf *ControlFlow) flatten(b Block, decl *varDecls, stmts []CStmt, labels map[Block]int, seen map[Block]struct{}) []CStmt {
   585  	if b == nil {
   586  		return stmts
   587  	}
   588  	if _, ok := seen[b]; ok {
   589  		l, ok := labels[b]
   590  		if !ok {
   591  			panic(fmt.Errorf("must have a label: %T, %v", b, len(b.PrevBlocks())))
   592  		}
   593  		stmts = append(stmts, numGoto(l))
   594  		return stmts
   595  	}
   596  	seen[b] = struct{}{}
   597  	if id, ok := labels[b]; ok {
   598  		stmts = append(stmts, numLabel(id))
   599  	}
   600  	switch b := b.(type) {
   601  	case *CodeBlock:
   602  		cur := cf.slitDecls(decl, b.Stmts)
   603  		stmts = append(stmts, cur...)
   604  		if l, ok := labels[b.Next]; ok {
   605  			stmts = append(stmts, numGoto(l))
   606  			if _, ok := seen[b.Next]; ok {
   607  				return stmts
   608  			}
   609  		}
   610  		stmts = cf.flatten(b.Next, decl, stmts, labels, seen)
   611  	case *CondBlock:
   612  		then, ok := labels[b.Then]
   613  		if !ok {
   614  			panic("must have a label")
   615  		}
   616  		els, ok := labels[b.Else]
   617  		if !ok {
   618  			panic("must have a label")
   619  		}
   620  		stmts = append(stmts, &CIfStmt{
   621  			Cond: cf.g.ToBool(b.Expr),
   622  			Then: cf.g.NewCBlock(numGoto(then)),
   623  			Else: cf.g.NewCBlock(numGoto(els)),
   624  		})
   625  		if _, ok := seen[b.Then]; !ok {
   626  			stmts = cf.flatten(b.Then, decl, stmts, labels, seen)
   627  		}
   628  		if _, ok := seen[b.Else]; !ok {
   629  			stmts = cf.flatten(b.Else, decl, stmts, labels, seen)
   630  		}
   631  	case *ReturnBlock:
   632  		stmts = append(stmts, b.CReturnStmt)
   633  	case *SwitchBlock:
   634  		s := &CSwitchStmt{
   635  			Cond: b.Expr,
   636  		}
   637  		for i, e := range b.Cases {
   638  			l, ok := labels[b.Blocks[i]]
   639  			if !ok {
   640  				panic("must have a label")
   641  			}
   642  			s.Cases = append(s.Cases, cf.g.NewCaseStmt(
   643  				e, numGoto(l),
   644  			))
   645  		}
   646  		stmts = append(stmts, s)
   647  		for _, b := range b.Blocks {
   648  			if _, ok := seen[b]; !ok {
   649  				stmts = cf.flatten(b, decl, stmts, labels, seen)
   650  			}
   651  		}
   652  	default:
   653  		panic(b)
   654  	}
   655  	return stmts
   656  }