github.com/Konstantin8105/c4go@v0.0.0-20240505174241-768bb1c65a51/transpiler/switch.go (about) 1 // This file contains functions for transpiling a "switch" statement. 2 3 package transpiler 4 5 import ( 6 goast "go/ast" 7 "go/token" 8 "runtime/debug" 9 10 "fmt" 11 12 "github.com/Konstantin8105/c4go/ast" 13 "github.com/Konstantin8105/c4go/program" 14 "github.com/Konstantin8105/c4go/types" 15 ) 16 17 // From : 18 // CaseStmt 19 // |-UnaryOperator 'int' prefix '-' 20 // | `-IntegerLiteral 'int' 1 21 // |-<<<NULL>>> 22 // `-CaseStmt 23 // 24 // |-BinaryOperator 'int' '-' 25 // | |- ... 26 // |-<<<NULL>>> 27 // `-CaseStmt 28 // |- ... 29 // |-<<<NULL>>> 30 // `-DefaultStmt 31 // `- ... 32 // 33 // To: 34 // CaseStmt 35 // |-UnaryOperator 'int' prefix '-' 36 // | `-IntegerLiteral 'int' 1 37 // `-<<<NULL>>> 38 // <<<NULL>>> 39 // CaseStmt 40 // |-BinaryOperator 'int' '-' 41 // | |- ... 42 // `-<<<NULL>>> 43 // <<<NULL>>> 44 // CaseStmt 45 // |- ... 46 // |-<<<NULL>>> 47 // `-DefaultStmt 48 // 49 // `- ... 50 // 51 // <<<NULL>>> 52 // 53 // From: 54 // |-CaseStmt 55 // | `- ... 56 // |-NullStmt 57 // |-BreakStmt 0 58 // |-CaseStmt 59 // | `-... 60 // To: 61 // CaseStmt 62 // `- ... 63 // <<<NULL>>> 64 // NullStmt 65 // <<<NULL>>> 66 // BreakStmt 0 67 // <<<NULL>>> 68 // CaseStmt 69 // `-... 70 // <<<NULL>>> 71 // 72 // From: 73 // |-CaseStmt 74 // | `- ... 75 // |-CompoundAssignOperator 'int' '+=' ComputeLHSTy='int' ComputeResultTy='int' 76 // | `-... 77 // |-BreakStmt 78 // |-CaseStmt 79 // | |- ... 80 // To: 81 // `-CaseStmt 82 // 83 // `- ... 84 // 85 // <<<NULL>>> 86 // CompoundAssignOperator 'int' '+=' ComputeLHSTy='int' ComputeResultTy='int' 87 // `-... 88 // <<<NULL>>> 89 // BreakStmt 90 // <<<NULL>>> 91 // CaseStmt 92 // `- ... 93 // <<<NULL>>> 94 func caseSplitter(nodes ...ast.Node) (cs []ast.Node) { 95 96 if nodes == nil { 97 return 98 } 99 if len(nodes) == 0 { 100 return 101 } 102 if len(nodes) > 1 { 103 for i := range nodes { 104 cs = append(cs, caseSplitter(nodes[i])...) 105 } 106 return 107 } 108 109 node := nodes[0] 110 if node == nil { 111 return 112 } 113 114 var compountWithCase func(ast.Node) bool 115 compountWithCase = func(node ast.Node) bool { 116 if node == nil { 117 return false 118 } 119 120 if node == (*ast.CompoundStmt)(nil) { 121 return false 122 } 123 124 if len(node.Children()) == 0 { 125 return false 126 } 127 128 switch node.(type) { 129 case *ast.CaseStmt, *ast.DefaultStmt: 130 return true 131 } 132 133 for i, n := range node.Children() { 134 if _, ok := n.(*ast.SwitchStmt); ok { 135 continue 136 } 137 if compountWithCase(node.Children()[i]) { 138 return true 139 } 140 } 141 142 return false 143 } 144 145 var ns []ast.Node 146 switch node.(type) { 147 case *ast.CompoundStmt: 148 if compountWithCase(node) { 149 ns = node.(*ast.CompoundStmt).ChildNodes 150 node.(*ast.CompoundStmt).ChildNodes = nil 151 } 152 153 case *ast.CaseStmt: 154 ns = node.(*ast.CaseStmt).ChildNodes 155 node.(*ast.CaseStmt).ChildNodes = nil 156 157 case *ast.DefaultStmt: 158 ns = node.(*ast.DefaultStmt).ChildNodes 159 node.(*ast.DefaultStmt).ChildNodes = nil 160 } 161 162 cs = append(cs, node) 163 cs = append(cs, caseSplitter(ns...)...) 164 165 return 166 } 167 168 func caseMerge(nodes []ast.Node) (pre, cs []ast.Node) { 169 for i := range nodes { 170 // ignore empty *ast.CompountStmt 171 if comp, ok := nodes[i].(*ast.CompoundStmt); ok && 172 (comp == (*ast.CompoundStmt)(nil) || len(comp.Children()) == 0) { 173 continue 174 } 175 176 var isCaseType bool 177 178 if _, ok := nodes[i].(*ast.CaseStmt); ok { 179 isCaseType = true 180 } 181 if _, ok := nodes[i].(*ast.DefaultStmt); ok { 182 isCaseType = true 183 } 184 185 if isCaseType { 186 cs = append(cs, nodes[i]) 187 continue 188 } 189 190 if len(cs) == 0 { 191 pre = append(pre, nodes[i]) 192 continue 193 } 194 195 cs[len(cs)-1].AddChild(nodes[i]) 196 } 197 198 return 199 } 200 201 func transpileSwitchStmt(n *ast.SwitchStmt, p *program.Program) ( 202 _ *goast.SwitchStmt, preStmts []goast.Stmt, postStmts []goast.Stmt, err error) { 203 defer func() { 204 if err != nil { 205 err = fmt.Errorf("cannot transpileSwitchStmt : err = %v", err) 206 } 207 }() 208 209 defer func() { 210 if r := recover(); r != nil { 211 err = fmt.Errorf("transpileSwitchStmt: error - panic: %s", string(debug.Stack())) 212 } 213 }() 214 215 // The first two children are nil. I don't know what they are supposed to be 216 // for. It looks like the number of children is also not reliable, but we 217 // know that we need the last two which represent the condition and body 218 // respectively. 219 220 if len(n.Children()) < 2 { 221 // I don't know what causes this condition. Need to investigate. 222 panic(fmt.Sprintf("Less than two children for switch: %#v", n)) 223 } 224 225 // The condition is the expression to be evaluated against each of the 226 // cases. 227 condition, conditionType, newPre, newPost, err := atomicOperation( 228 n.Children()[len(n.Children())-2], p) 229 if err != nil { 230 return nil, nil, nil, err 231 } 232 if conditionType == "bool" { 233 condition, err = types.CastExpr(p, condition, conditionType, "int") 234 p.AddMessage(p.GenerateWarningMessage(err, n)) 235 } 236 237 preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost) 238 239 // separation body of switch on cases 240 var body *ast.CompoundStmt 241 var ok bool 242 body, ok = n.Children()[len(n.Children())-1].(*ast.CompoundStmt) 243 if !ok { 244 body = &ast.CompoundStmt{} 245 body.AddChild(n.Children()[len(n.Children())-1]) 246 } 247 248 // CompoundStmt 249 // `-CaseStmt 250 // |-UnaryOperator 'int' prefix '-' 251 // | `-IntegerLiteral 'int' 1 252 // |-<<<NULL>>> 253 // `-CaseStmt 254 // |-BinaryOperator 'int' '-' 255 // | |- ... 256 // |-<<<NULL>>> 257 // `-CaseStmt 258 // |- ... 259 // |-<<<NULL>>> 260 // `-DefaultStmt 261 // `- ... 262 if body == (*ast.CompoundStmt)(nil) { 263 body = &ast.CompoundStmt{} 264 } 265 parts := caseSplitter(body.Children()...) 266 pre, parts := caseMerge(parts) 267 body.ChildNodes = parts 268 269 if len(pre) > 0 { 270 stmt, newPre, newPost, err := transpileCompoundStmt(&ast.CompoundStmt{ 271 ChildNodes: pre, 272 }, p) 273 p.AddMessage(p.GenerateWarningMessage(err, n)) 274 preStmts = append(preStmts, newPre...) 275 preStmts = append(preStmts, stmt.List...) 276 preStmts = append(preStmts, newPost...) 277 } 278 279 // The body will always be a CompoundStmt because a switch statement is not 280 // valid without curly brackets. 281 cases, newPre, newPost, err := normalizeSwitchCases(body, p) 282 if err != nil { 283 return nil, nil, nil, err 284 } 285 286 preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost) 287 288 // TODO 289 // 290 // from: 291 // case 2: 292 // { 293 // } 294 // fallthrough 295 // case 3: 296 // fallthrough 297 // case 4: 298 // break 299 // to: 300 // case 2,3: 301 // fallthrough 302 // case 4: 303 // break 304 // 305 // from: 306 // fallthrough 307 // break 308 // to: 309 // --- 310 311 // cases with 2 nodes 312 // from : 313 // { 314 // ... 315 // } 316 // break or fallthrough 317 // to: 318 // ... 319 // break or fallthrough 320 for i := range cases { 321 body := cases[i].Body 322 if len(body) != 2 { 323 continue 324 } 325 var ( 326 last = body[len(body)-1] 327 prelast = body[len(body)-2] 328 ) 329 br, ok := last.(*goast.BranchStmt) 330 if !ok || !(br.Tok == token.FALLTHROUGH || br.Tok == token.BREAK) { 331 continue 332 } 333 bl, ok := prelast.(*goast.BlockStmt) 334 if !ok { 335 continue 336 } 337 338 cases[i].Body = append(bl.List, br) 339 } 340 341 // from: 342 // return 343 // fallthrough 344 // to: 345 // return 346 for i := range cases { 347 body := cases[i].Body 348 if len(body) < 2 { 349 continue 350 } 351 var ( 352 last = body[len(body)-1] 353 prelast = body[len(body)-2] 354 ) 355 if br, ok := last.(*goast.BranchStmt); !ok || br.Tok != token.FALLTHROUGH { 356 continue 357 } 358 if _, ok := prelast.(*goast.ReturnStmt); !ok { 359 continue 360 } 361 cases[i].Body = cases[i].Body[:len(cases[i].Body)-1] 362 } 363 364 // from: 365 // break 366 // fallthrough 367 // to: 368 // --- 369 for i := range cases { 370 body := cases[i].Body 371 if len(body) < 2 { 372 continue 373 } 374 var ( 375 last = body[len(body)-1] 376 prelast = body[len(body)-2] 377 ) 378 if br, ok := last.(*goast.BranchStmt); !ok || br.Tok != token.FALLTHROUGH { 379 continue 380 } 381 if br, ok := prelast.(*goast.BranchStmt); !ok || br.Tok != token.BREAK { 382 continue 383 } 384 cases[i].Body = cases[i].Body[:len(cases[i].Body)-2] 385 } 386 387 // cases with 1 node 388 // from : 389 // { 390 // ... 391 // } 392 // to: 393 // ... 394 for i := range cases { 395 body := cases[i].Body 396 if len(body) != 1 { 397 continue 398 } 399 var ( 400 last = body[len(body)-1] 401 ) 402 bl, ok := last.(*goast.BlockStmt) 403 if !ok { 404 continue 405 } 406 407 cases[i].Body = bl.List 408 } 409 410 // Convert the normalized cases back into statements so they can be children 411 // of goast.SwitchStmt. 412 stmts := []goast.Stmt{} 413 for _, singleCase := range cases { 414 if singleCase == nil { 415 panic("nil single case") 416 } 417 418 stmts = append(stmts, singleCase) 419 } 420 421 return &goast.SwitchStmt{ 422 Tag: condition, 423 Body: &goast.BlockStmt{ 424 List: stmts, 425 }, 426 }, preStmts, postStmts, nil 427 } 428 429 func normalizeSwitchCases(body *ast.CompoundStmt, p *program.Program) ( 430 _ []*goast.CaseClause, preStmts []goast.Stmt, postStmts []goast.Stmt, err error) { 431 // The body of a switch has a non uniform structure. For example: 432 // 433 // switch a { 434 // case 1: 435 // foo(); 436 // bar(); 437 // break; 438 // default: 439 // baz(); 440 // qux(); 441 // } 442 // 443 // Looks like: 444 // 445 // *ast.CompountStmt 446 // *ast.CaseStmt // case 1: 447 // *ast.CallExpr // foo() 448 // *ast.CallExpr // bar() 449 // *ast.BreakStmt // break 450 // *ast.DefaultStmt // default: 451 // *ast.CallExpr // baz() 452 // *ast.CallExpr // qux() 453 // 454 // Each of the cases contains one child that is the first statement, but all 455 // the rest are children of the parent CompountStmt. This makes it 456 // especially tricky when we want to remove the 'break' or add a 457 // 'fallthrough'. 458 // 459 // To make it easier we normalise the cases. This means that we iterate 460 // through all of the statements of the CompountStmt and merge any children 461 // that are not 'case' or 'break' with the previous node to give us a 462 // structure like: 463 // 464 // []*goast.CaseClause 465 // *goast.CaseClause // case 1: 466 // *goast.CallExpr // foo() 467 // *goast.CallExpr // bar() 468 // // *ast.BreakStmt // break (was removed) 469 // *goast.CaseClause // default: 470 // *goast.CallExpr // baz() 471 // *goast.CallExpr // qux() 472 // 473 // During this translation we also remove 'break' or append a 'fallthrough'. 474 475 cases := []*goast.CaseClause{} 476 477 for _, x := range body.Children() { 478 switch c := x.(type) { 479 case *ast.CaseStmt, *ast.DefaultStmt: 480 var newPre, newPost []goast.Stmt 481 cases, newPre, newPost, err = appendCaseOrDefaultToNormalizedCases(cases, c, p) 482 if err != nil { 483 return []*goast.CaseClause{}, nil, nil, err 484 } 485 486 preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost) 487 case *ast.BreakStmt: 488 489 default: 490 var stmt goast.Stmt 491 var newPre, newPost []goast.Stmt 492 stmt, newPre, newPost, err = transpileToStmt(x, p) 493 if err != nil { 494 return []*goast.CaseClause{}, nil, nil, err 495 } 496 preStmts = append(preStmts, newPre...) 497 preStmts = append(preStmts, stmt) 498 preStmts = append(preStmts, newPost...) 499 } 500 } 501 502 return cases, preStmts, postStmts, nil 503 } 504 505 func appendCaseOrDefaultToNormalizedCases(cases []*goast.CaseClause, 506 stmt ast.Node, p *program.Program) ( 507 []*goast.CaseClause, []goast.Stmt, []goast.Stmt, error) { 508 preStmts := []goast.Stmt{} 509 postStmts := []goast.Stmt{} 510 511 if len(cases) > 0 { 512 cases[len(cases)-1].Body = append(cases[len(cases)-1].Body, &goast.BranchStmt{ 513 Tok: token.FALLTHROUGH, 514 }) 515 } 516 517 var singleCase *goast.CaseClause 518 var err error 519 var newPre []goast.Stmt 520 var newPost []goast.Stmt 521 522 switch c := stmt.(type) { 523 case *ast.CaseStmt: 524 singleCase, newPre, newPost, err = transpileCaseStmt(c, p) 525 526 case *ast.DefaultStmt: 527 singleCase, err = transpileDefaultStmt(c, p) 528 } 529 530 if singleCase != nil { 531 cases = append(cases, singleCase) 532 } 533 534 preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost) 535 536 if err != nil { 537 return []*goast.CaseClause{}, nil, nil, err 538 } 539 540 return cases, preStmts, postStmts, nil 541 } 542 543 func transpileCaseStmt(n *ast.CaseStmt, p *program.Program) ( 544 _ *goast.CaseClause, _ []goast.Stmt, _ []goast.Stmt, err error) { 545 defer func() { 546 if err != nil { 547 err = fmt.Errorf("cannot transpileCaseStmt: %v", err) 548 } 549 }() 550 preStmts := []goast.Stmt{} 551 postStmts := []goast.Stmt{} 552 553 c, cType, newPre, newPost, err := transpileToExpr(n.Children()[0], p, false) 554 if err != nil { 555 return nil, nil, nil, err 556 } 557 if cType == "bool" { 558 c, err = types.CastExpr(p, c, cType, "int") 559 p.AddMessage(p.GenerateWarningMessage(err, n)) 560 } 561 preStmts, postStmts = combinePreAndPostStmts(preStmts, postStmts, newPre, newPost) 562 563 stmts, err := transpileStmts(n.Children()[1:], p) 564 if err != nil { 565 return nil, nil, nil, err 566 } 567 568 return &goast.CaseClause{ 569 List: []goast.Expr{c}, 570 Body: stmts, 571 }, preStmts, postStmts, nil 572 } 573 574 func transpileDefaultStmt(n *ast.DefaultStmt, p *program.Program) ( 575 *goast.CaseClause, error) { 576 577 stmts, err := transpileStmts(n.Children()[0:], p) 578 if err != nil { 579 return nil, err 580 } 581 582 return &goast.CaseClause{ 583 List: nil, 584 Body: stmts, 585 }, nil 586 }