github.com/Konstantin8105/c4go@v0.0.0-20240505174241-768bb1c65a51/transpiler/switch.go (about)

     1  // This file contains functions for transpiling a "switch" statement.
     2  
     3  package transpiler
     4  
     5  import (
     6  	goast "go/ast"
     7  	"go/token"
     8  	"runtime/debug"
     9  
    10  	"fmt"
    11  
    12  	"github.com/Konstantin8105/c4go/ast"
    13  	"github.com/Konstantin8105/c4go/program"
    14  	"github.com/Konstantin8105/c4go/types"
    15  )
    16  
    17  // From :
    18  // CaseStmt
    19  // |-UnaryOperator 'int' prefix '-'
    20  // | `-IntegerLiteral 'int' 1
    21  // |-<<<NULL>>>
    22  // `-CaseStmt
    23  //
    24  //	|-BinaryOperator 'int' '-'
    25  //	| |- ...
    26  //	|-<<<NULL>>>
    27  //	`-CaseStmt
    28  //	  |- ...
    29  //	  |-<<<NULL>>>
    30  //	  `-DefaultStmt
    31  //	    `- ...
    32  //
    33  // To:
    34  // CaseStmt
    35  // |-UnaryOperator 'int' prefix '-'
    36  // | `-IntegerLiteral 'int' 1
    37  // `-<<<NULL>>>
    38  // <<<NULL>>>
    39  // CaseStmt
    40  // |-BinaryOperator 'int' '-'
    41  // | |- ...
    42  // `-<<<NULL>>>
    43  // <<<NULL>>>
    44  // CaseStmt
    45  // |- ...
    46  // |-<<<NULL>>>
    47  // `-DefaultStmt
    48  //
    49  //	`- ...
    50  //
    51  // <<<NULL>>>
    52  //
    53  // From:
    54  // |-CaseStmt
    55  // | `- ...
    56  // |-NullStmt
    57  // |-BreakStmt 0
    58  // |-CaseStmt
    59  // | `-...
    60  // To:
    61  // CaseStmt
    62  // `- ...
    63  // <<<NULL>>>
    64  // NullStmt
    65  // <<<NULL>>>
    66  // BreakStmt 0
    67  // <<<NULL>>>
    68  // CaseStmt
    69  // `-...
    70  // <<<NULL>>>
    71  //
    72  // From:
    73  // |-CaseStmt
    74  // | `- ...
    75  // |-CompoundAssignOperator  'int' '+=' ComputeLHSTy='int' ComputeResultTy='int'
    76  // | `-...
    77  // |-BreakStmt
    78  // |-CaseStmt
    79  // | |- ...
    80  // To:
    81  // `-CaseStmt
    82  //
    83  //	`- ...
    84  //
    85  // <<<NULL>>>
    86  // CompoundAssignOperator  'int' '+=' ComputeLHSTy='int' ComputeResultTy='int'
    87  // `-...
    88  // <<<NULL>>>
    89  // BreakStmt
    90  // <<<NULL>>>
    91  // CaseStmt
    92  // `- ...
    93  // <<<NULL>>>
    94  func caseSplitter(nodes ...ast.Node) (cs []ast.Node) {
    95  
    96  	if nodes == nil {
    97  		return
    98  	}
    99  	if len(nodes) == 0 {
   100  		return
   101  	}
   102  	if len(nodes) > 1 {
   103  		for i := range nodes {
   104  			cs = append(cs, caseSplitter(nodes[i])...)
   105  		}
   106  		return
   107  	}
   108  
   109  	node := nodes[0]
   110  	if node == nil {
   111  		return
   112  	}
   113  
   114  	var compountWithCase func(ast.Node) bool
   115  	compountWithCase = func(node ast.Node) bool {
   116  		if node == nil {
   117  			return false
   118  		}
   119  
   120  		if node == (*ast.CompoundStmt)(nil) {
   121  			return false
   122  		}
   123  
   124  		if len(node.Children()) == 0 {
   125  			return false
   126  		}
   127  
   128  		switch node.(type) {
   129  		case *ast.CaseStmt, *ast.DefaultStmt:
   130  			return true
   131  		}
   132  
   133  		for i, n := range node.Children() {
   134  			if _, ok := n.(*ast.SwitchStmt); ok {
   135  				continue
   136  			}
   137  			if compountWithCase(node.Children()[i]) {
   138  				return true
   139  			}
   140  		}
   141  
   142  		return false
   143  	}
   144  
   145  	var ns []ast.Node
   146  	switch node.(type) {
   147  	case *ast.CompoundStmt:
   148  		if compountWithCase(node) {
   149  			ns = node.(*ast.CompoundStmt).ChildNodes
   150  			node.(*ast.CompoundStmt).ChildNodes = nil
   151  		}
   152  
   153  	case *ast.CaseStmt:
   154  		ns = node.(*ast.CaseStmt).ChildNodes
   155  		node.(*ast.CaseStmt).ChildNodes = nil
   156  
   157  	case *ast.DefaultStmt:
   158  		ns = node.(*ast.DefaultStmt).ChildNodes
   159  		node.(*ast.DefaultStmt).ChildNodes = nil
   160  	}
   161  
   162  	cs = append(cs, node)
   163  	cs = append(cs, caseSplitter(ns...)...)
   164  
   165  	return
   166  }
   167  
   168  func caseMerge(nodes []ast.Node) (pre, cs []ast.Node) {
   169  	for i := range nodes {
   170  		// ignore empty *ast.CompountStmt
   171  		if comp, ok := nodes[i].(*ast.CompoundStmt); ok &&
   172  			(comp == (*ast.CompoundStmt)(nil) || len(comp.Children()) == 0) {
   173  			continue
   174  		}
   175  
   176  		var isCaseType bool
   177  
   178  		if _, ok := nodes[i].(*ast.CaseStmt); ok {
   179  			isCaseType = true
   180  		}
   181  		if _, ok := nodes[i].(*ast.DefaultStmt); ok {
   182  			isCaseType = true
   183  		}
   184  
   185  		if isCaseType {
   186  			cs = append(cs, nodes[i])
   187  			continue
   188  		}
   189  
   190  		if len(cs) == 0 {
   191  			pre = append(pre, nodes[i])
   192  			continue
   193  		}
   194  
   195  		cs[len(cs)-1].AddChild(nodes[i])
   196  	}
   197  
   198  	return
   199  }
   200  
   201  func transpileSwitchStmt(n *ast.SwitchStmt, p *program.Program) (
   202  	_ *goast.SwitchStmt, preStmts []goast.Stmt, postStmts []goast.Stmt, err error) {
   203  	defer func() {
   204  		if err != nil {
   205  			err = fmt.Errorf("cannot transpileSwitchStmt : err = %v", err)
   206  		}
   207  	}()
   208  
   209  	defer func() {
   210  		if r := recover(); r != nil {
   211  			err = fmt.Errorf("transpileSwitchStmt: error - panic: %s", string(debug.Stack()))
   212  		}
   213  	}()
   214  
   215  	// The first two children are nil. I don't know what they are supposed to be
   216  	// for. It looks like the number of children is also not reliable, but we
   217  	// know that we need the last two which represent the condition and body
   218  	// respectively.
   219  
   220  	if len(n.Children()) < 2 {
   221  		// I don't know what causes this condition. Need to investigate.
   222  		panic(fmt.Sprintf("Less than two children for switch: %#v", n))
   223  	}
   224  
   225  	// The condition is the expression to be evaluated against each of the
   226  	// cases.
   227  	condition, conditionType, newPre, newPost, err := atomicOperation(
   228  		n.Children()[len(n.Children())-2], p)
   229  	if err != nil {
   230  		return nil, nil, nil, err
   231  	}
   232  	if conditionType == "bool" {
   233  		condition, err = types.CastExpr(p, condition, conditionType, "int")
   234  		p.AddMessage(p.GenerateWarningMessage(err, n))
   235  	}
   236  
   237  	preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost)
   238  
   239  	// separation body of switch on cases
   240  	var body *ast.CompoundStmt
   241  	var ok bool
   242  	body, ok = n.Children()[len(n.Children())-1].(*ast.CompoundStmt)
   243  	if !ok {
   244  		body = &ast.CompoundStmt{}
   245  		body.AddChild(n.Children()[len(n.Children())-1])
   246  	}
   247  
   248  	// CompoundStmt
   249  	// `-CaseStmt
   250  	//   |-UnaryOperator 'int' prefix '-'
   251  	//   | `-IntegerLiteral 'int' 1
   252  	//   |-<<<NULL>>>
   253  	//   `-CaseStmt
   254  	//     |-BinaryOperator 'int' '-'
   255  	//     | |- ...
   256  	//     |-<<<NULL>>>
   257  	//     `-CaseStmt
   258  	//       |- ...
   259  	//       |-<<<NULL>>>
   260  	//       `-DefaultStmt
   261  	//         `- ...
   262  	if body == (*ast.CompoundStmt)(nil) {
   263  		body = &ast.CompoundStmt{}
   264  	}
   265  	parts := caseSplitter(body.Children()...)
   266  	pre, parts := caseMerge(parts)
   267  	body.ChildNodes = parts
   268  
   269  	if len(pre) > 0 {
   270  		stmt, newPre, newPost, err := transpileCompoundStmt(&ast.CompoundStmt{
   271  			ChildNodes: pre,
   272  		}, p)
   273  		p.AddMessage(p.GenerateWarningMessage(err, n))
   274  		preStmts = append(preStmts, newPre...)
   275  		preStmts = append(preStmts, stmt.List...)
   276  		preStmts = append(preStmts, newPost...)
   277  	}
   278  
   279  	// The body will always be a CompoundStmt because a switch statement is not
   280  	// valid without curly brackets.
   281  	cases, newPre, newPost, err := normalizeSwitchCases(body, p)
   282  	if err != nil {
   283  		return nil, nil, nil, err
   284  	}
   285  
   286  	preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost)
   287  
   288  	// TODO
   289  	//
   290  	// from:
   291  	// case 2:
   292  	//  {
   293  	//  }
   294  	//  fallthrough
   295  	// case 3:
   296  	//  fallthrough
   297  	// case 4:
   298  	//  break
   299  	// to:
   300  	// case 2,3:
   301  	//  fallthrough
   302  	// case 4:
   303  	//  break
   304  	//
   305  	// from:
   306  	// 		fallthrough
   307  	//		break
   308  	// to:
   309  	//		---
   310  
   311  	// cases with 2 nodes
   312  	// from :
   313  	//		{
   314  	//			...
   315  	//		}
   316  	//		break or fallthrough
   317  	// to:
   318  	//		...
   319  	//		break or fallthrough
   320  	for i := range cases {
   321  		body := cases[i].Body
   322  		if len(body) != 2 {
   323  			continue
   324  		}
   325  		var (
   326  			last    = body[len(body)-1]
   327  			prelast = body[len(body)-2]
   328  		)
   329  		br, ok := last.(*goast.BranchStmt)
   330  		if !ok || !(br.Tok == token.FALLTHROUGH || br.Tok == token.BREAK) {
   331  			continue
   332  		}
   333  		bl, ok := prelast.(*goast.BlockStmt)
   334  		if !ok {
   335  			continue
   336  		}
   337  
   338  		cases[i].Body = append(bl.List, br)
   339  	}
   340  
   341  	// from:
   342  	//		return
   343  	//		fallthrough
   344  	// to:
   345  	//		return
   346  	for i := range cases {
   347  		body := cases[i].Body
   348  		if len(body) < 2 {
   349  			continue
   350  		}
   351  		var (
   352  			last    = body[len(body)-1]
   353  			prelast = body[len(body)-2]
   354  		)
   355  		if br, ok := last.(*goast.BranchStmt); !ok || br.Tok != token.FALLTHROUGH {
   356  			continue
   357  		}
   358  		if _, ok := prelast.(*goast.ReturnStmt); !ok {
   359  			continue
   360  		}
   361  		cases[i].Body = cases[i].Body[:len(cases[i].Body)-1]
   362  	}
   363  
   364  	// from:
   365  	//		break
   366  	// 		fallthrough
   367  	// to:
   368  	//		---
   369  	for i := range cases {
   370  		body := cases[i].Body
   371  		if len(body) < 2 {
   372  			continue
   373  		}
   374  		var (
   375  			last    = body[len(body)-1]
   376  			prelast = body[len(body)-2]
   377  		)
   378  		if br, ok := last.(*goast.BranchStmt); !ok || br.Tok != token.FALLTHROUGH {
   379  			continue
   380  		}
   381  		if br, ok := prelast.(*goast.BranchStmt); !ok || br.Tok != token.BREAK {
   382  			continue
   383  		}
   384  		cases[i].Body = cases[i].Body[:len(cases[i].Body)-2]
   385  	}
   386  
   387  	// cases with 1 node
   388  	// from :
   389  	//		{
   390  	//			...
   391  	//		}
   392  	// to:
   393  	//		...
   394  	for i := range cases {
   395  		body := cases[i].Body
   396  		if len(body) != 1 {
   397  			continue
   398  		}
   399  		var (
   400  			last = body[len(body)-1]
   401  		)
   402  		bl, ok := last.(*goast.BlockStmt)
   403  		if !ok {
   404  			continue
   405  		}
   406  
   407  		cases[i].Body = bl.List
   408  	}
   409  
   410  	// Convert the normalized cases back into statements so they can be children
   411  	// of goast.SwitchStmt.
   412  	stmts := []goast.Stmt{}
   413  	for _, singleCase := range cases {
   414  		if singleCase == nil {
   415  			panic("nil single case")
   416  		}
   417  
   418  		stmts = append(stmts, singleCase)
   419  	}
   420  
   421  	return &goast.SwitchStmt{
   422  		Tag: condition,
   423  		Body: &goast.BlockStmt{
   424  			List: stmts,
   425  		},
   426  	}, preStmts, postStmts, nil
   427  }
   428  
   429  func normalizeSwitchCases(body *ast.CompoundStmt, p *program.Program) (
   430  	_ []*goast.CaseClause, preStmts []goast.Stmt, postStmts []goast.Stmt, err error) {
   431  	// The body of a switch has a non uniform structure. For example:
   432  	//
   433  	//     switch a {
   434  	//     case 1:
   435  	//         foo();
   436  	//         bar();
   437  	//         break;
   438  	//     default:
   439  	//         baz();
   440  	//         qux();
   441  	//     }
   442  	//
   443  	// Looks like:
   444  	//
   445  	//     *ast.CompountStmt
   446  	//         *ast.CaseStmt     // case 1:
   447  	//             *ast.CallExpr //     foo()
   448  	//         *ast.CallExpr     //     bar()
   449  	//         *ast.BreakStmt    //     break
   450  	//         *ast.DefaultStmt  // default:
   451  	//             *ast.CallExpr //     baz()
   452  	//         *ast.CallExpr     //     qux()
   453  	//
   454  	// Each of the cases contains one child that is the first statement, but all
   455  	// the rest are children of the parent CompountStmt. This makes it
   456  	// especially tricky when we want to remove the 'break' or add a
   457  	// 'fallthrough'.
   458  	//
   459  	// To make it easier we normalise the cases. This means that we iterate
   460  	// through all of the statements of the CompountStmt and merge any children
   461  	// that are not 'case' or 'break' with the previous node to give us a
   462  	// structure like:
   463  	//
   464  	//     []*goast.CaseClause
   465  	//         *goast.CaseClause      // case 1:
   466  	//             *goast.CallExpr    //     foo()
   467  	//             *goast.CallExpr    //     bar()
   468  	//             // *ast.BreakStmt  //     break (was removed)
   469  	//         *goast.CaseClause      // default:
   470  	//             *goast.CallExpr    //     baz()
   471  	//             *goast.CallExpr    //     qux()
   472  	//
   473  	// During this translation we also remove 'break' or append a 'fallthrough'.
   474  
   475  	cases := []*goast.CaseClause{}
   476  
   477  	for _, x := range body.Children() {
   478  		switch c := x.(type) {
   479  		case *ast.CaseStmt, *ast.DefaultStmt:
   480  			var newPre, newPost []goast.Stmt
   481  			cases, newPre, newPost, err = appendCaseOrDefaultToNormalizedCases(cases, c, p)
   482  			if err != nil {
   483  				return []*goast.CaseClause{}, nil, nil, err
   484  			}
   485  
   486  			preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost)
   487  		case *ast.BreakStmt:
   488  
   489  		default:
   490  			var stmt goast.Stmt
   491  			var newPre, newPost []goast.Stmt
   492  			stmt, newPre, newPost, err = transpileToStmt(x, p)
   493  			if err != nil {
   494  				return []*goast.CaseClause{}, nil, nil, err
   495  			}
   496  			preStmts = append(preStmts, newPre...)
   497  			preStmts = append(preStmts, stmt)
   498  			preStmts = append(preStmts, newPost...)
   499  		}
   500  	}
   501  
   502  	return cases, preStmts, postStmts, nil
   503  }
   504  
   505  func appendCaseOrDefaultToNormalizedCases(cases []*goast.CaseClause,
   506  	stmt ast.Node, p *program.Program) (
   507  	[]*goast.CaseClause, []goast.Stmt, []goast.Stmt, error) {
   508  	preStmts := []goast.Stmt{}
   509  	postStmts := []goast.Stmt{}
   510  
   511  	if len(cases) > 0 {
   512  		cases[len(cases)-1].Body = append(cases[len(cases)-1].Body, &goast.BranchStmt{
   513  			Tok: token.FALLTHROUGH,
   514  		})
   515  	}
   516  
   517  	var singleCase *goast.CaseClause
   518  	var err error
   519  	var newPre []goast.Stmt
   520  	var newPost []goast.Stmt
   521  
   522  	switch c := stmt.(type) {
   523  	case *ast.CaseStmt:
   524  		singleCase, newPre, newPost, err = transpileCaseStmt(c, p)
   525  
   526  	case *ast.DefaultStmt:
   527  		singleCase, err = transpileDefaultStmt(c, p)
   528  	}
   529  
   530  	if singleCase != nil {
   531  		cases = append(cases, singleCase)
   532  	}
   533  
   534  	preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost)
   535  
   536  	if err != nil {
   537  		return []*goast.CaseClause{}, nil, nil, err
   538  	}
   539  
   540  	return cases, preStmts, postStmts, nil
   541  }
   542  
   543  func transpileCaseStmt(n *ast.CaseStmt, p *program.Program) (
   544  	_ *goast.CaseClause, _ []goast.Stmt, _ []goast.Stmt, err error) {
   545  	defer func() {
   546  		if err != nil {
   547  			err = fmt.Errorf("cannot transpileCaseStmt: %v", err)
   548  		}
   549  	}()
   550  	preStmts := []goast.Stmt{}
   551  	postStmts := []goast.Stmt{}
   552  
   553  	c, cType, newPre, newPost, err := transpileToExpr(n.Children()[0], p, false)
   554  	if err != nil {
   555  		return nil, nil, nil, err
   556  	}
   557  	if cType == "bool" {
   558  		c, err = types.CastExpr(p, c, cType, "int")
   559  		p.AddMessage(p.GenerateWarningMessage(err, n))
   560  	}
   561  	preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost)
   562  
   563  	stmts, err := transpileStmts(n.Children()[1:], p)
   564  	if err != nil {
   565  		return nil, nil, nil, err
   566  	}
   567  
   568  	return &goast.CaseClause{
   569  		List: []goast.Expr{c},
   570  		Body: stmts,
   571  	}, preStmts, postStmts, nil
   572  }
   573  
   574  func transpileDefaultStmt(n *ast.DefaultStmt, p *program.Program) (
   575  	*goast.CaseClause, error) {
   576  
   577  	stmts, err := transpileStmts(n.Children()[0:], p)
   578  	if err != nil {
   579  		return nil, err
   580  	}
   581  
   582  	return &goast.CaseClause{
   583  		List: nil,
   584  		Body: stmts,
   585  	}, nil
   586  }