// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

/*
Package rangefunc rewrites range-over-func to code that doesn't use range-over-funcs.
Rewriting the construct in the front end, before noder, means the functions generated during
the rewrite are available in a noder-generated representation for inlining by the back end.

# Theory of Operation

The basic idea is to rewrite

	for x := range f {
		...
	}

into

	f(func(x T) bool {
		...
	})

But it's not usually that easy.

# Range variables

For a range not using :=, the assigned variables cannot be function parameters
in the generated body function. Instead, we allocate fake parameters and
start the body with an assignment. For example:

	for expr1, expr2 = range f {
		...
	}

becomes

	f(func(#p1 T1, #p2 T2) bool {
		expr1, expr2 = #p1, #p2
		...
	})

(All the generated variables have a # at the start to signal that they
are internal variables when looking at the generated code in a
debugger. Because variables have all been resolved to the specific
objects they represent, there is no danger of using plain "p1" and
colliding with a Go variable named "p1"; the # is just nice to have,
not for correctness.)

It can also happen that there are fewer range variables than function
arguments, in which case we end up with something like

	f(func(x T1, _ T2) bool {
		...
	})

or

	f(func(#p1 T1, #p2 T2, _ T3) bool {
		expr1, expr2 = #p1, #p2
		...
	})

# Return

If the body contains a "break", that break turns into "return false",
to tell f to stop. And if the body contains a "continue", that turns
into "return true", to tell f to proceed with the next value.
Those are the easy cases.

If the body contains a return or a break/continue/goto L, then we need
to rewrite that into code that breaks out of the loop and then
triggers that control flow. In general we rewrite

	for x := range f {
		...
	}

into

	{
		var #next int
		f(func(x T1) bool {
			...
			return true
		})
		... check #next ...
	}

The variable #next is an integer code that says what to do when f
returns. Each difficult statement sets #next and then returns false to
stop f.

A plain "return" rewrites to {#next = -1; return false}.
The return false breaks the loop. Then when f returns, the "check
#next" section includes

	if #next == -1 { return }

which causes the return we want.

Return with arguments is more involved. We need somewhere to store the
arguments while we break out of f, so we add them to the var
declaration, like:

	{
		var (
			#next int
			#r1 type1
			#r2 type2
		)
		f(func(x T1) bool {
			...
			{
				// return a, b
				#r1, #r2 = a, b
				#next = -2
				return false
			}
			...
			return true
		})
		if #next == -2 { return #r1, #r2 }
	}

TODO: What about:

	func f() (x bool) {
		for range g(&x) {
			return true
		}
	}

	func g(p *bool) func(func() bool) {
		return func(yield func() bool) {
			yield()
			// Is *p true or false here?
		}
	}

With this rewrite the "return true" is not visible after yield returns,
but maybe it should be?

# Checking

To permit checking that an iterator is well-behaved -- that is, that
it does not call the loop body again after it has returned false or
after the entire loop has exited (it might retain a copy of the body
function, or pass it to another goroutine) -- each generated loop has
its own #exitK flag that is checked before each iteration, and set both
at any early exit and after the iteration completes.

For example:

	for x := range f {
		...
		if ... { break }
		...
	}

becomes

	{
		var #exit1 bool
		f(func(x T1) bool {
			if #exit1 { runtime.panicrangeexit() }
			...
			if ... { #exit1 = true ; return false }
			...
			return true
		})
		#exit1 = true
	}

# Nested Loops

So far we've only considered a single loop. If a function contains a
sequence of loops, each can be translated individually. But loops can
be nested. It would work to translate the innermost loop and then
translate the loop around it, and so on, except that there'd be a lot
of rewriting of rewritten code and the overall traversals could end up
taking time quadratic in the depth of the nesting. To avoid all that,
we use a single rewriting pass that handles a top-most range-over-func
loop and all the range-over-func loops it contains at the same time.

If we need to return from inside a doubly-nested loop, the rewrites
above stay the same, but the check after the inner loop only says

	if #next < 0 { return false }

to stop the outer loop so it can do the actual return. That is,

	for range f {
		for range g {
			...
			return a, b
			...
		}
	}

becomes

	{
		var (
			#next int
			#r1 type1
			#r2 type2
		)
		var #exit1 bool
		f(func() {
			if #exit1 { runtime.panicrangeexit() }
			var #exit2 bool
			g(func() {
				if #exit2 { runtime.panicrangeexit() }
				...
				{
					// return a, b
					#r1, #r2 = a, b
					#next = -2
					#exit1, #exit2 = true, true
					return false
				}
				...
				return true
			})
			#exit2 = true
			if #next < 0 {
				return false
			}
			return true
		})
		#exit1 = true
		if #next == -2 {
			return #r1, #r2
		}
	}

Note that the #next < 0 after the inner loop handles both kinds of
return with a single check.

# Labeled break/continue of range-over-func loops

For a labeled break or continue of an outer range-over-func, we
use positive #next values. Any such labeled break or continue
really means "do N breaks" or "do N breaks and 1 continue".
We encode that as perLoopStep*N or perLoopStep*N+1 respectively.

Loops that might need to propagate a labeled break or continue
add one or both of these to the #next checks:

	if #next >= 2 {
		#next -= 2
		return false
	}

	if #next == 1 {
		#next = 0
		return true
	}

For example

	F: for range f {
		for range g {
			for range h {
				...
				break F
				...
				...
				continue F
				...
			}
		}
		...
	}

becomes

	{
		var #next int
		var #exit1 bool
		f(func() {
			if #exit1 { runtime.panicrangeexit() }
			var #exit2 bool
			g(func() {
				if #exit2 { runtime.panicrangeexit() }
				var #exit3 bool
				h(func() {
					if #exit3 { runtime.panicrangeexit() }
					...
					{
						// break F
						#next = 4
						#exit1, #exit2, #exit3 = true, true, true
						return false
					}
					...
					{
						// continue F
						#next = 3
						#exit2, #exit3 = true, true
						return false
					}
					...
					return true
				})
				#exit3 = true
				if #next >= 2 {
					#next -= 2
					return false
				}
				return true
			})
			#exit2 = true
			if #next >= 2 {
				#next -= 2
				return false
			}
			if #next == 1 {
				#next = 0
				return true
			}
			...
			return true
		})
		#exit1 = true
	}

Note that the post-h checks only consider a break,
since no generated code tries to continue g.

# Gotos and other labeled break/continue

The final control flow translations are goto and break/continue of a
non-range-over-func statement. In both cases, we may need to break out
of one or more range-over-func loops before we can do the actual
control flow statement. Each such break/continue/goto L statement is
assigned a unique negative #next value (below -2, since -1 and -2 are
for the two kinds of return). Then the post-checks for a given loop
test for the specific codes that refer to labels directly targetable
from that block. Otherwise, the generic

	if #next < 0 { return false }

check handles stopping the next loop to get one step closer to the label.

For example

	Top: print("start\n")
		for range f {
			for range g {
				...
				for range h {
					...
					goto Top
					...
				}
			}
		}

becomes

	Top: print("start\n")
		{
			var #next int
			var #exit1 bool
			f(func() {
				if #exit1 { runtime.panicrangeexit() }
				var #exit2 bool
				g(func() {
					if #exit2 { runtime.panicrangeexit() }
					...
					var #exit3 bool
					h(func() {
					if #exit3 { runtime.panicrangeexit() }
						...
						{
							// goto Top
							#next = -3
							#exit1, #exit2, #exit3 = true, true, true
							return false
						}
						...
						return true
					})
					#exit3 = true
					if #next < 0 {
						return false
					}
					return true
				})
				#exit2 = true
				if #next < 0 {
					return false
				}
				return true
			})
			#exit1 = true
			if #next == -3 {
				#next = 0
				goto Top
			}
		}

Labeled break/continue to non-range-over-funcs are handled the same
way as goto.

# Defers

The last wrinkle is handling defer statements. If we have

	for range f {
		defer print("A")
	}

we cannot rewrite that into

	f(func() {
		defer print("A")
	})

because the deferred code will run at the end of the iteration, not
the end of the containing function. To fix that, the runtime provides
a special hook that lets us obtain a defer "token" representing the
outer function and then use it in a later defer to attach the deferred
code to that outer function.

Normally,

	defer print("A")

compiles to

	runtime.deferproc(func() { print("A") })

This changes in a range-over-func. For example:

	for range f {
		defer print("A")
	}

compiles to

	var #defers = runtime.deferrangefunc()
	f(func() {
		runtime.deferprocat(func() { print("A") }, #defers)
	})

For this rewriting phase, we insert the explicit initialization of
#defers and then attach the #defers variable to the CallStmt
representing the defer. That variable will be propagated to the
backend and will cause the backend to compile the defer using
deferprocat instead of an ordinary deferproc.

TODO: Could call runtime.deferrangefuncend after f.
*/
package rangefunc

import (
	"github.com/bir3/gocompiler/src/cmd/compile/internal/base"
	"github.com/bir3/gocompiler/src/cmd/compile/internal/syntax"
	"github.com/bir3/gocompiler/src/cmd/compile/internal/types2"
	"fmt"
	"github.com/bir3/gocompiler/src/go/constant"
	"os"
)

// nopos is the zero syntax.Pos.
var nopos syntax.Pos

// A rewriter implements rewriting the range-over-funcs in a given function.
// One rewriter is created per function body by rewriteFunc.
type rewriter struct {
	// pkg is the package handed to types2.NewVar for generated variables.
	pkg *types2.Package
	// info receives the Defs and Uses entries for generated names.
	info *types2.Info
	// outer is the type of the function being rewritten; its ResultList
	// supplies the types of the #rN variables created by editReturn.
	outer *syntax.FuncType
	// body is the body of the function being rewritten.
	body *syntax.BlockStmt

	// References to important types and values.
	// They are looked up in types2.Universe by startLoop on first use.
	any   types2.Object
	bool  types2.Object
	int   types2.Object
	true  types2.Object
	false types2.Object

	// Branch numbering, computed as needed.
	branchNext map[branch]int            // branch -> #next value
	labelLoop  map[string]*syntax.ForStmt // label -> innermost rangefunc loop it is declared inside (nil for no loop)

	// Stack of nodes being visited.
	stack    []syntax.Node // all nodes
	forStack []*forLoop    // range-over-func loops

	// rewritten maps each converted for statement to its replacement block.
	// It is filled in by endLoop and consulted by editStmt.
	rewritten map[*syntax.ForStmt]syntax.Stmt

	// Declared variables in generated code for outermost loop.
	// endLoop clears declStmt, nextVar, retVars and defers when an
	// outermost loop ends.
	declStmt     *syntax.DeclStmt
	nextVar      types2.Object
	retVars      []types2.Object
	defers       types2.Object
	exitVarCount int // exitvars are referenced from their respective loops
}

// A branch is a single labeled branch.
// It is used as the key of rewriter.branchNext.
type branch struct {
	tok   syntax.Token
	label string
}

// A forLoop describes a single range-over-func loop being processed.
type forLoop struct {
	nfor         *syntax.ForStmt // actual syntax
	exitFlag     *types2.Var     // #exit variable for this loop
	exitFlagDecl *syntax.VarDecl // declaration of exitFlag, emitted by endLoop

	checkRet      bool     // add check for "return" after loop
	checkRetArgs  bool     // add check for "return args" after loop
	checkBreak    bool     // add check for "break" after loop
	checkContinue bool     // add check for "continue" after loop
	checkBranch   []branch // add check for labeled branch after loop
}

// Rewrite rewrites all the range-over-funcs in the files.
//
// It hands every function declaration and function literal it meets to
// rewriteFunc and does not descend into them itself; nested function
// literals are picked up by the rewriter's own traversal.
func Rewrite(pkg *types2.Package, info *types2.Info, files []*syntax.File) {
	for _, file := range files {
		syntax.Inspect(file, func(n syntax.Node) bool {
			switch n := n.(type) {
			case *syntax.FuncDecl:
				rewriteFunc(pkg, info, n.Type, n.Body)
				return false
			case *syntax.FuncLit:
				rewriteFunc(pkg, info, n.Type, n.Body)
				return false
			}
			return true
		})
	}
}

// rewriteFunc rewrites all the range-over-funcs in a single function (a top-level func or a func literal).
// The typ and body are the function's type and body.
// A nil body is left alone.
func rewriteFunc(pkg *types2.Package, info *types2.Info, typ *syntax.FuncType, body *syntax.BlockStmt) {
	if body == nil {
		return
	}
	r := &rewriter{
		pkg:   pkg,
		info:  info,
		outer: typ,
		body:  body,
	}
	syntax.Inspect(body, r.inspect)
	// With the W flag set, dump the body to stderr if the traversal
	// pushed at least one range-over-func loop (forStack is then non-nil,
	// even though it is empty again by now).
	if (base.Flag.W != 0) && r.forStack != nil {
		syntax.Fdump(os.Stderr, body)
	}
}

// checkFuncMisuse reports whether to check for misuse of iterator callback functions.
func (r *rewriter) checkFuncMisuse() bool {
	// The check is enabled by a non-zero RangeFuncCheck debug setting.
	return base.Debug.RangeFuncCheck != 0
}

// inspect is a callback for syntax.Inspect that drives the actual rewriting.
// If it sees a func literal, it kicks off a separate rewrite for that literal.
// Otherwise, it maintains a stack of range-over-func loops and
// converts each in turn.
//
// It is called with a non-nil node on the way down and with nil once
// that node's children have been visited; r.stack pairs the two calls.
func (r *rewriter) inspect(n syntax.Node) bool {
	switch n := n.(type) {
	case *syntax.FuncLit:
		// A literal gets its own rewriter; returning false keeps this
		// traversal out of it, so no nil callback follows and nothing is pushed.
		rewriteFunc(r.pkg, r.info, n.Type, n.Body)
		return false

	default:
		// Push n onto stack.
		r.stack = append(r.stack, n)
		if nfor, ok := forRangeFunc(n); ok {
			loop := &forLoop{nfor: nfor}
			r.forStack = append(r.forStack, loop)
			r.startLoop(loop)
		}

	case nil:
		// n == nil signals that we are done visiting
		// the top-of-stack node's children. Find it.
		n = r.stack[len(r.stack)-1]

		// If we are inside a range-over-func,
		// take this moment to replace any break/continue/goto/return
		// statements directly contained in this node.
		// Also replace any converted for statements
		// with the rewritten block.
		switch n := n.(type) {
		case *syntax.BlockStmt:
			for i, s := range n.List {
				n.List[i] = r.editStmt(s)
			}
		case *syntax.CaseClause:
			for i, s := range n.Body {
				n.Body[i] = r.editStmt(s)
			}
		case *syntax.CommClause:
			for i, s := range n.Body {
				n.Body[i] = r.editStmt(s)
			}
		case *syntax.LabeledStmt:
			n.Stmt = r.editStmt(n.Stmt)
		}

		// Pop n.
		// If n is the innermost range-over-func loop, finish converting it
		// first; endLoop still expects that loop on top of r.forStack.
		if len(r.forStack) > 0 && r.stack[len(r.stack)-1] == r.forStack[len(r.forStack)-1].nfor {
			r.endLoop(r.forStack[len(r.forStack)-1])
			r.forStack = r.forStack[:len(r.forStack)-1]
		}
		r.stack = r.stack[:len(r.stack)-1]
	}
	return true
}

// startLoop sets up for converting a range-over-func loop.
func (r *rewriter) startLoop(loop *forLoop) {
	// For first loop in function, allocate syntax for any, bool, int, true, and false.
	// The rewritten map is created at the same time.
	if r.any == nil {
		r.any = types2.Universe.Lookup("any")
		r.bool = types2.Universe.Lookup("bool")
		r.int = types2.Universe.Lookup("int")
		r.true = types2.Universe.Lookup("true")
		r.false = types2.Universe.Lookup("false")
		r.rewritten = make(map[*syntax.ForStmt]syntax.Stmt)
	}
	if r.checkFuncMisuse() {
		// declare the exit flag for this loop's body
		loop.exitFlag, loop.exitFlagDecl = r.exitVar(loop.nfor.Pos())
	}
}

// editStmt returns the replacement for the statement x,
// or x itself if it should be left alone.
// This includes the for loops we are converting,
// as left in r.rewritten by r.endLoop.
func (r *rewriter) editStmt(x syntax.Stmt) syntax.Stmt {
	if x, ok := x.(*syntax.ForStmt); ok {
		if s := r.rewritten[x]; s != nil {
			return s
		}
	}

	// Branches, defers and returns only need editing while at least one
	// range-over-func loop is open.
	if len(r.forStack) > 0 {
		switch x := x.(type) {
		case *syntax.BranchStmt:
			return r.editBranch(x)
		case *syntax.CallStmt:
			if x.Tok == syntax.Defer {
				return r.editDefer(x)
			}
		case *syntax.ReturnStmt:
			return r.editReturn(x)
		}
	}

	return x
}

// editDefer returns the replacement for the defer statement x.
// See the "Defers" section in the package doc comment above for more context.
// The statement is modified in place and returned.
func (r *rewriter) editDefer(x *syntax.CallStmt) syntax.Stmt {
	if r.defers == nil {
		// Declare and initialize the #defers token.
		init := &syntax.CallExpr{
			Fun: runtimeSym(r.info, "deferrangefunc"),
		}
		tv := syntax.TypeAndValue{Type: r.any.Type()}
		tv.SetIsValue()
		init.SetTypeInfo(tv)
		r.defers = r.declVar("#defers", r.any.Type(), init)
	}

	// Attach the token as an "extra" argument to the defer.
	x.DeferAt = r.useVar(r.defers)
	setPos(x.DeferAt, x.Pos())
	return x
}

// exitVar creates the next #exitN variable, a bool positioned at pos,
// where N is a counter kept on the rewriter. It records the new name in
// r.info.Defs and returns the variable together with a VarDecl naming it.
// The declaration itself is emitted later, by endLoop.
func (r *rewriter) exitVar(pos syntax.Pos) (*types2.Var, *syntax.VarDecl) {
	r.exitVarCount++

	name := fmt.Sprintf("#exit%d", r.exitVarCount)
	typ := r.bool.Type()
	obj := types2.NewVar(pos, r.pkg, name, typ)
	n := syntax.NewName(pos, name)
	setValueType(n, typ)
	r.info.Defs[n] = obj

	return obj, &syntax.VarDecl{NameList: []*syntax.Name{n}}
}

// editReturn returns the replacement for the return statement x.
// See the "Return" section in the package doc comment above for more context.
func (r *rewriter) editReturn(x *syntax.ReturnStmt) syntax.Stmt {
	// #next = -1 is return with no arguments; -2 is return with arguments.
	// Only the outermost loop (forStack[0]) performs the actual return.
	var next int
	if x.Results == nil {
		next = -1
		r.forStack[0].checkRet = true
	} else {
		next = -2
		r.forStack[0].checkRetArgs = true
	}

	// Tell the loops along the way to check for a return.
	for _, loop := range r.forStack[1:] {
		loop.checkRet = true
	}

	// Assign results, set #next, and return false.
	bl := &syntax.BlockStmt{}
	if x.Results != nil {
		if r.retVars == nil {
			// First return with arguments under this outermost loop:
			// declare one #rN variable per result of the enclosing function.
			for i, a := range r.outer.ResultList {
				obj := r.declVar(fmt.Sprintf("#r%d", i+1), a.Type.GetTypeInfo().Type, nil)
				r.retVars = append(r.retVars, obj)
			}
		}
		bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.useList(r.retVars), Rhs: x.Results})
	}
	bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)})
	if r.checkFuncMisuse() {
		// mark all enclosing loop bodies as exited
		for i := 0; i < len(r.forStack); i++ {
			bl.List = append(bl.List, r.setExitedAt(i))
		}
	}
	bl.List = append(bl.List, &syntax.ReturnStmt{Results: r.useVar(r.false)})
	setPos(bl, x.Pos())
	return bl
}

// perLoopStep is part of the encoding of loop-spanning control flow
// for function range iterators.
// Each multiple of two encodes a "return false"
// passing control to an enclosing iterator; a terminal value of 1 encodes
// "return true" (i.e., local continue) from the body function, and a terminal
// value of 0 encodes executing the remainder of the body function.
const perLoopStep = 2

// editBranch returns the replacement for the branch statement x,
// or x itself if it should be left alone.
// See the package doc comment above for more context.
func (r *rewriter) editBranch(x *syntax.BranchStmt) syntax.Stmt {
	if x.Tok == syntax.Fallthrough {
		// Fallthrough is unaffected by the rewrite.
		return x
	}

	// Find target of break/continue/goto in r.forStack.
	// (The target may not be in r.forStack at all.)
	targ := x.Target
	i := len(r.forStack) - 1
	if x.Label == nil && r.forStack[i].nfor != targ {
		// Unlabeled break or continue that's not nfor must be inside nfor. Leave alone.
		return x
	}
	for i >= 0 && r.forStack[i].nfor != targ {
		i--
	}
	// exitFrom is the index of the loop interior to the target of the control flow,
	// if such a loop exists (it does not if i == len(r.forStack) - 1)
	exitFrom := i + 1

	// Compute the value to assign to #next and the specific return to use.
	var next int
	var ret *syntax.ReturnStmt
	if x.Tok == syntax.Goto || i < 0 {
		// goto Label
		// or break/continue of labeled non-range-over-func loop.
		// We may be able to leave it alone, or we may have to break
		// out of one or more nested loops and then use #next to signal
		// to complete the break/continue/goto.
		// Figure out which range-over-func loop contains the label.
		r.computeBranchNext()
		nfor := r.forStack[len(r.forStack)-1].nfor
		label := x.Label.Value
		targ := r.labelLoop[label]
		if nfor == targ {
			// Label is in the innermost range-over-func loop; use it directly.
			return x
		}

		// Set #next to the code meaning break/continue/goto label.
		next = r.branchNext[branch{x.Tok, label}]

		// Break out of nested loops up to targ.
		// (This i shadows the outer i for the rest of this branch.)
		i := len(r.forStack) - 1
		for i >= 0 && r.forStack[i].nfor != targ {
			i--
		}
		exitFrom = i + 1

		// Mark loop we exit to get to targ to check for that branch.
		// When i==-1 that's the outermost func body
		top := r.forStack[i+1]
		top.checkBranch = append(top.checkBranch, branch{x.Tok, label})

		// Mark loops along the way to check for a plain return, so they break.
		for j := i + 2; j < len(r.forStack); j++ {
			r.forStack[j].checkRet = true
		}

		// In the innermost loop, use a plain "return false".
		ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
	} else {
		// break/continue of labeled range-over-func loop.
		// depth is the number of loops between the innermost loop and the
		// target; 0 means the target is the innermost loop itself.
		depth := len(r.forStack) - 1 - i

		// For continue of innermost loop, use "return true".
		// Otherwise we are breaking the innermost loop, so "return false".

		if depth == 0 && x.Tok == syntax.Continue {
			ret = &syntax.ReturnStmt{Results: r.useVar(r.true)}
			setPos(ret, x.Pos())
			return ret
		}
		ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}

		// If this is a simple break, mark this loop as exited and return false.
		// No adjustments to #next.
		if depth == 0 {
			var stmts []syntax.Stmt
			if r.checkFuncMisuse() {
				stmts = []syntax.Stmt{r.setExited(), ret}
			} else {
				stmts = []syntax.Stmt{ret}
			}
			bl := &syntax.BlockStmt{
				List: stmts,
			}
			setPos(bl, x.Pos())
			return bl
		}

		// The loop inside the one we are break/continue-ing
		// needs to make that happen when we break out of it.
		if x.Tok == syntax.Continue {
			r.forStack[exitFrom].checkContinue = true
		} else {
			// A break also leaves the target loop itself.
			exitFrom = i
			r.forStack[exitFrom].checkBreak = true
		}

		// The loops along the way just need to break.
		for j := exitFrom + 1; j < len(r.forStack); j++ {
			r.forStack[j].checkBreak = true
		}

		// Set next to break the appropriate number of times;
		// the final time may be a continue, not a break.
		next = perLoopStep * depth
		if x.Tok == syntax.Continue {
			next--
		}
	}

	// Assign #next = next and do the return.
	as := &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)}
	bl := &syntax.BlockStmt{
		List: []syntax.Stmt{as},
	}

	if r.checkFuncMisuse() {
		// Set #exitK for this loop and those exited by the control flow.
		for i := exitFrom; i < len(r.forStack); i++ {
			bl.List = append(bl.List, r.setExitedAt(i))
		}
	}

	bl.List = append(bl.List, ret)
	setPos(bl, x.Pos())
	return bl
}

// computeBranchNext computes the branchNext numbering
// and determines which labels end up inside which range-over-func loop bodies.
// The work is done once per rewriter; later calls return immediately.
func (r *rewriter) computeBranchNext() {
	if r.labelLoop != nil {
		return
	}

	r.labelLoop = make(map[string]*syntax.ForStmt)
	r.branchNext = make(map[branch]int)

	var labels []string
	var stack []syntax.Node
	var forStack []*syntax.ForStmt
	// The nil entry at the bottom stands for "not inside any range-over-func loop".
	forStack = append(forStack, nil)
	syntax.Inspect(r.body, func(n syntax.Node) bool {
		if n != nil {
			stack = append(stack, n)
			if nfor, ok := forRangeFunc(n); ok {
				forStack = append(forStack, nfor)
			}
			if n, ok := n.(*syntax.LabeledStmt); ok {
				l := n.Label.Value
				labels = append(labels, l)
				f := forStack[len(forStack)-1]
				r.labelLoop[l] = f
			}
		} else {
			// Done with the children of the top-of-stack node: pop it,
			// and pop the loop stack too if it was the innermost loop.
			n := stack[len(stack)-1]
			stack = stack[:len(stack)-1]
			if n == forStack[len(forStack)-1] {
				forStack = forStack[:len(forStack)-1]
			}
		}
		return true
	})

	// Assign numbers to all the labels we observed.
	// Each label gets three consecutive codes (break, continue, goto),
	// all below -2; -1 and -2 are the codes editReturn uses.
	used := -2
	for _, l := range labels {
		used -= 3
		r.branchNext[branch{syntax.Break, l}] = used
		r.branchNext[branch{syntax.Continue, l}] = used + 1
		r.branchNext[branch{syntax.Goto, l}] = used + 2
	}
}

// endLoop finishes the conversion of a range-over-func loop.
// We have inspected and rewritten the body of the loop and can now
// construct the body function and rewrite the for loop into a call
// bracketed by any declarations and checks it requires.
// The resulting block is stored in r.rewritten, keyed by the for statement.
func (r *rewriter) endLoop(loop *forLoop) {
	// Pick apart for range X { ... }
	nfor := loop.nfor
	start, end := nfor.Pos(), nfor.Body.Rbrace // start, end position of for loop
	rclause := nfor.Init.(*syntax.RangeClause)
	rfunc := types2.CoreType(rclause.X.GetTypeInfo().Type).(*types2.Signature) // type of X - func(func(...)bool)
	if rfunc.Params().Len() != 1 {
		base.Fatalf("invalid typecheck of range func")
	}
	ftyp := types2.CoreType(rfunc.Params().At(0).Type()).(*types2.Signature) // func(...) bool
	if ftyp.Results().Len() != 1 {
		base.Fatalf("invalid typecheck of range func")
	}

	// Build X(bodyFunc)
	call := &syntax.ExprStmt{
		X: &syntax.CallExpr{
			Fun: rclause.X,
			ArgList: []syntax.Expr{
				r.bodyFunc(nfor.Body.List, syntax.UnpackListExpr(rclause.Lhs), rclause.Def, ftyp, start, end),
			},
		},
	}
	setPos(call, start)

	// Build checks based on #next after X(bodyFunc)
	checks := r.checks(loop, end)

	// Rewrite for vars := range X { ... } to
	//
	//	{
	//		r.declStmt
	//		call
	//		checks
	//	}
	//
	// The r.declStmt can be added to by this loop or any inner loop
	// during the creation of r.bodyFunc; it is only emitted in the outermost
	// converted range loop.
	block := &syntax.BlockStmt{Rbrace: end}
	setPos(block, start)
	if len(r.forStack) == 1 && r.declStmt != nil {
		setPos(r.declStmt, start)
		block.List = append(block.List, r.declStmt)
	}

	// declare the exitFlag here so it has proper scope and zeroing
	if r.checkFuncMisuse() {
		exitFlagDecl := &syntax.DeclStmt{DeclList: []syntax.Decl{loop.exitFlagDecl}}
		block.List = append(block.List, exitFlagDecl)
	}

	// iteratorFunc(bodyFunc)
	block.List = append(block.List, call)

	if r.checkFuncMisuse() {
		// iteratorFunc has exited, mark the exit flag for the body
		block.List = append(block.List, r.setExited())
	}
	block.List = append(block.List, checks...)

	if len(r.forStack) == 1 { // ending an outermost loop
		r.declStmt = nil
		r.nextVar = nil
		r.retVars = nil
		r.defers = nil
	}

	r.rewritten[nfor] = block
}

// setExited returns the statement #exitK = true
// for the innermost loop on r.forStack.
func (r *rewriter) setExited() *syntax.AssignStmt {
	return r.setExitedAt(len(r.forStack) - 1)
}

// setExitedAt returns the statement #exitK = true
// for the loop at r.forStack[index].
func (r *rewriter) setExitedAt(index int) *syntax.AssignStmt {
	loop := r.forStack[index]
	return &syntax.AssignStmt{
		Lhs: r.useVar(loop.exitFlag),
		Rhs: r.useVar(r.true),
	}
}

// bodyFunc converts the loop body (control flow has already been updated)
// to a func literal that can be passed to the range function.
//
// lhs is the range variables from the range statement.
// def indicates whether this is a := range statement.
// ftyp is the type of the function we are creating
// start and end are the syntax positions to use for new nodes
// that should be at the start or end of the loop.
func (r *rewriter) bodyFunc(body []syntax.Stmt, lhs []syntax.Expr, def bool, ftyp *types2.Signature, start, end syntax.Pos) *syntax.FuncLit {
	// Starting X(bodyFunc); build up bodyFunc first.
	var params, results []*types2.Var
	results = append(results, types2.NewVar(start, nil, "", r.bool.Type()))
	bodyFunc := &syntax.FuncLit{
		// Note: Type is ignored but needs to be non-nil to avoid panic in syntax.Inspect.
		Type: &syntax.FuncType{},
		Body: &syntax.BlockStmt{
			List:   []syntax.Stmt{},
			Rbrace: end,
		},
	}
	setPos(bodyFunc, start)

	for i := 0; i < ftyp.Params().Len(); i++ {
		typ := ftyp.Params().At(i).Type()
		var paramVar *types2.Var
		if i < len(lhs) && def {
			// Reuse range variable as parameter.
			x := lhs[i]
			paramVar = r.info.Defs[x.(*syntax.Name)].(*types2.Var)
		} else {
			// Declare new parameter and assign it to range expression.
			// With no range expression at this index the parameter is
			// declared but never assigned from.
			paramVar = types2.NewVar(start, r.pkg, fmt.Sprintf("#p%d", 1+i), typ)
			if i < len(lhs) {
				x := lhs[i]
				as := &syntax.AssignStmt{Lhs: x, Rhs: r.useVar(paramVar)}
				as.SetPos(x.Pos())
				setPos(as.Rhs, x.Pos())
				bodyFunc.Body.List = append(bodyFunc.Body.List, as)
			}
		}
		params = append(params, paramVar)
	}

	tv := syntax.TypeAndValue{
		Type: types2.NewSignatureType(nil, nil, nil,
			types2.NewTuple(params...),
			types2.NewTuple(results...),
			false),
	}
	tv.SetIsValue()
	bodyFunc.SetTypeInfo(tv)

	loop := r.forStack[len(r.forStack)-1]

	if r.checkFuncMisuse() {
		// Guard against the iterator calling the body after the loop has exited.
		bodyFunc.Body.List = append(bodyFunc.Body.List, r.assertNotExited(start, loop))
	}

	// Original loop body (already rewritten by editStmt during inspect).
	bodyFunc.Body.List = append(bodyFunc.Body.List, body...)

	// return true to continue at end of loop body
	ret := &syntax.ReturnStmt{Results: r.useVar(r.true)}
	ret.SetPos(end)
	bodyFunc.Body.List = append(bodyFunc.Body.List, ret)

	return bodyFunc
}

// checks returns the post-call checks that need to be done for the given loop.
1081 func (r *rewriter) checks(loop *forLoop, pos syntax.Pos) []syntax.Stmt { 1082 var list []syntax.Stmt 1083 if len(loop.checkBranch) > 0 { 1084 did := make(map[branch]bool) 1085 for _, br := range loop.checkBranch { 1086 if did[br] { 1087 continue 1088 } 1089 did[br] = true 1090 doBranch := &syntax.BranchStmt{Tok: br.tok, Label: &syntax.Name{Value: br.label}} 1091 list = append(list, r.ifNext(syntax.Eql, r.branchNext[br], doBranch)) 1092 } 1093 } 1094 if len(r.forStack) == 1 { 1095 if loop.checkRetArgs { 1096 list = append(list, r.ifNext(syntax.Eql, -2, retStmt(r.useList(r.retVars)))) 1097 } 1098 if loop.checkRet { 1099 list = append(list, r.ifNext(syntax.Eql, -1, retStmt(nil))) 1100 } 1101 } else { 1102 if loop.checkRetArgs || loop.checkRet { 1103 // Note: next < 0 also handles gotos handled by outer loops. 1104 // We set checkRet in that case to trigger this check. 1105 list = append(list, r.ifNext(syntax.Lss, 0, retStmt(r.useVar(r.false)))) 1106 } 1107 if loop.checkBreak { 1108 list = append(list, r.ifNext(syntax.Geq, perLoopStep, retStmt(r.useVar(r.false)))) 1109 } 1110 if loop.checkContinue { 1111 list = append(list, r.ifNext(syntax.Eql, perLoopStep-1, retStmt(r.useVar(r.true)))) 1112 } 1113 } 1114 1115 for _, j := range list { 1116 setPos(j, pos) 1117 } 1118 return list 1119 } 1120 1121 // retStmt returns a return statement returning the given return values. 1122 func retStmt(results syntax.Expr) *syntax.ReturnStmt { 1123 return &syntax.ReturnStmt{Results: results} 1124 } 1125 1126 // ifNext returns the statement: 1127 // 1128 // if #next op c { adjust; then } 1129 // 1130 // When op is >=, adjust is #next -= c. 1131 // When op is == and c is not -1 or -2, adjust is #next = 0. 1132 // Otherwise adjust is omitted. 
1133 func (r *rewriter) ifNext(op syntax.Operator, c int, then syntax.Stmt) syntax.Stmt { 1134 nif := &syntax.IfStmt{ 1135 Cond: &syntax.Operation{Op: op, X: r.next(), Y: r.intConst(c)}, 1136 Then: &syntax.BlockStmt{ 1137 List: []syntax.Stmt{then}, 1138 }, 1139 } 1140 tv := syntax.TypeAndValue{Type: r.bool.Type()} 1141 tv.SetIsValue() 1142 nif.Cond.SetTypeInfo(tv) 1143 1144 if op == syntax.Geq { 1145 sub := &syntax.AssignStmt{ 1146 Op: syntax.Sub, 1147 Lhs: r.next(), 1148 Rhs: r.intConst(c), 1149 } 1150 nif.Then.List = []syntax.Stmt{sub, then} 1151 } 1152 if op == syntax.Eql && c != -1 && c != -2 { 1153 clr := &syntax.AssignStmt{ 1154 Lhs: r.next(), 1155 Rhs: r.intConst(0), 1156 } 1157 nif.Then.List = []syntax.Stmt{clr, then} 1158 } 1159 1160 return nif 1161 } 1162 1163 // setValueType marks x as a value with type typ. 1164 func setValueType(x syntax.Expr, typ syntax.Type) { 1165 tv := syntax.TypeAndValue{Type: typ} 1166 tv.SetIsValue() 1167 x.SetTypeInfo(tv) 1168 } 1169 1170 // assertNotExited returns the statement: 1171 // 1172 // if #exitK { runtime.panicrangeexit() } 1173 // 1174 // where #exitK is the exit guard for loop. 1175 func (r *rewriter) assertNotExited(start syntax.Pos, loop *forLoop) syntax.Stmt { 1176 callPanicExpr := &syntax.CallExpr{ 1177 Fun: runtimeSym(r.info, "panicrangeexit"), 1178 } 1179 setValueType(callPanicExpr, nil) // no result type 1180 1181 callPanic := &syntax.ExprStmt{X: callPanicExpr} 1182 1183 nif := &syntax.IfStmt{ 1184 Cond: r.useVar(loop.exitFlag), 1185 Then: &syntax.BlockStmt{ 1186 List: []syntax.Stmt{callPanic}, 1187 }, 1188 } 1189 setPos(nif, start) 1190 return nif 1191 } 1192 1193 // next returns a reference to the #next variable. 1194 func (r *rewriter) next() *syntax.Name { 1195 if r.nextVar == nil { 1196 r.nextVar = r.declVar("#next", r.int.Type(), nil) 1197 } 1198 return r.useVar(r.nextVar) 1199 } 1200 1201 // forRangeFunc checks whether n is a range-over-func. 1202 // If so, it returns n.(*syntax.ForStmt), true. 
1203 // Otherwise it returns nil, false. 1204 func forRangeFunc(n syntax.Node) (*syntax.ForStmt, bool) { 1205 nfor, ok := n.(*syntax.ForStmt) 1206 if !ok { 1207 return nil, false 1208 } 1209 nrange, ok := nfor.Init.(*syntax.RangeClause) 1210 if !ok { 1211 return nil, false 1212 } 1213 _, ok = types2.CoreType(nrange.X.GetTypeInfo().Type).(*types2.Signature) 1214 if !ok { 1215 return nil, false 1216 } 1217 return nfor, true 1218 } 1219 1220 // intConst returns syntax for an integer literal with the given value. 1221 func (r *rewriter) intConst(c int) *syntax.BasicLit { 1222 lit := &syntax.BasicLit{ 1223 Value: fmt.Sprint(c), 1224 Kind: syntax.IntLit, 1225 } 1226 tv := syntax.TypeAndValue{Type: r.int.Type(), Value: constant.MakeInt64(int64(c))} 1227 tv.SetIsValue() 1228 lit.SetTypeInfo(tv) 1229 return lit 1230 } 1231 1232 // useVar returns syntax for a reference to decl, which should be its declaration. 1233 func (r *rewriter) useVar(obj types2.Object) *syntax.Name { 1234 n := syntax.NewName(nopos, obj.Name()) 1235 tv := syntax.TypeAndValue{Type: obj.Type()} 1236 tv.SetIsValue() 1237 n.SetTypeInfo(tv) 1238 r.info.Uses[n] = obj 1239 return n 1240 } 1241 1242 // useList is useVar for a list of decls. 1243 func (r *rewriter) useList(vars []types2.Object) syntax.Expr { 1244 var new []syntax.Expr 1245 for _, obj := range vars { 1246 new = append(new, r.useVar(obj)) 1247 } 1248 if len(new) == 1 { 1249 return new[0] 1250 } 1251 return &syntax.ListExpr{ElemList: new} 1252 } 1253 1254 // declVar declares a variable with a given name type and initializer value. 
1255 func (r *rewriter) declVar(name string, typ types2.Type, init syntax.Expr) *types2.Var { 1256 if r.declStmt == nil { 1257 r.declStmt = &syntax.DeclStmt{} 1258 } 1259 stmt := r.declStmt 1260 obj := types2.NewVar(stmt.Pos(), r.pkg, name, typ) 1261 n := syntax.NewName(stmt.Pos(), name) 1262 tv := syntax.TypeAndValue{Type: typ} 1263 tv.SetIsValue() 1264 n.SetTypeInfo(tv) 1265 r.info.Defs[n] = obj 1266 stmt.DeclList = append(stmt.DeclList, &syntax.VarDecl{ 1267 NameList: []*syntax.Name{n}, 1268 // Note: Type is ignored 1269 Values: init, 1270 }) 1271 return obj 1272 } 1273 1274 // declType declares a type with the given name and type. 1275 // This is more like "type name = typ" than "type name typ". 1276 func declType(pos syntax.Pos, name string, typ types2.Type) *syntax.Name { 1277 n := syntax.NewName(pos, name) 1278 n.SetTypeInfo(syntax.TypeAndValue{Type: typ}) 1279 return n 1280 } 1281 1282 // runtimePkg is a fake runtime package that contains what we need to refer to in package runtime. 1283 var runtimePkg = func() *types2.Package { 1284 var nopos syntax.Pos 1285 pkg := types2.NewPackage("runtime", "runtime") 1286 anyType := types2.Universe.Lookup("any").Type() 1287 1288 // func deferrangefunc() unsafe.Pointer 1289 obj := types2.NewFunc(nopos, pkg, "deferrangefunc", types2.NewSignatureType(nil, nil, nil, nil, types2.NewTuple(types2.NewParam(nopos, pkg, "extra", anyType)), false)) 1290 pkg.Scope().Insert(obj) 1291 1292 // func panicrangeexit() 1293 obj = types2.NewFunc(nopos, pkg, "panicrangeexit", types2.NewSignatureType(nil, nil, nil, nil, nil, false)) 1294 pkg.Scope().Insert(obj) 1295 1296 return pkg 1297 }() 1298 1299 // runtimeSym returns a reference to a symbol in the fake runtime package. 
1300 func runtimeSym(info *types2.Info, name string) *syntax.Name { 1301 obj := runtimePkg.Scope().Lookup(name) 1302 n := syntax.NewName(nopos, "runtime."+name) 1303 tv := syntax.TypeAndValue{Type: obj.Type()} 1304 tv.SetIsValue() 1305 tv.SetIsRuntimeHelper() 1306 n.SetTypeInfo(tv) 1307 info.Uses[n] = obj 1308 return n 1309 } 1310 1311 // setPos walks the top structure of x that has no position assigned 1312 // and assigns it all to have position pos. 1313 // When setPos encounters a syntax node with a position assigned, 1314 // setPos does not look inside that node. 1315 // setPos only needs to handle syntax we create in this package; 1316 // all other syntax should have positions assigned already. 1317 func setPos(x syntax.Node, pos syntax.Pos) { 1318 if x == nil { 1319 return 1320 } 1321 syntax.Inspect(x, func(n syntax.Node) bool { 1322 if n == nil || n.Pos() != nopos { 1323 return false 1324 } 1325 n.SetPos(pos) 1326 switch n := n.(type) { 1327 case *syntax.BlockStmt: 1328 if n.Rbrace == nopos { 1329 n.Rbrace = pos 1330 } 1331 } 1332 return true 1333 }) 1334 }