github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/walk/select.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package walk
     6  
     7  import (
     8  	"github.com/go-asm/go/cmd/compile/base"
     9  	"github.com/go-asm/go/cmd/compile/ir"
    10  	"github.com/go-asm/go/cmd/compile/typecheck"
    11  	"github.com/go-asm/go/cmd/compile/types"
    12  	"github.com/go-asm/go/cmd/src"
    13  )
    14  
    15  func walkSelect(sel *ir.SelectStmt) {
    16  	lno := ir.SetPos(sel)
    17  	if sel.Walked() {
    18  		base.Fatalf("double walkSelect")
    19  	}
    20  	sel.SetWalked(true)
    21  
    22  	init := ir.TakeInit(sel)
    23  
    24  	init = append(init, walkSelectCases(sel.Cases)...)
    25  	sel.Cases = nil
    26  
    27  	sel.Compiled = init
    28  	walkStmtList(sel.Compiled)
    29  
    30  	base.Pos = lno
    31  }
    32  
// walkSelectCases lowers the communication clauses of a select statement
// into a flat list of statements that stands in for the select.
//
// Special shapes are handled first:
//   - zero cases: a single call to the runtime function "block";
//   - one case: the clause's communication (if any) is emitted as an
//     ordinary statement, followed by the clause body;
//   - two cases, one of them default: a single if/else whose condition is
//     a selectnbsend or selectnbrecv call, so the operation never blocks.
//
// In the general case it fills a temporary array of scase values, one per
// non-default clause. Sends occupy the low slots in source order, and
// receives fill the array from the end backwards. It then calls the runtime
// function selectgo and dispatches on the returned index with a chain of
// if statements.
//
// Every emitted clause body is followed by an unlabeled OBREAK statement
// (presumably breaking out of the enclosing select; confirm with the later
// compilation stage). This function moves base.Pos while generating code
// and does not restore it. walkSelect restores it afterwards.
func walkSelectCases(cases []*ir.CommClause) []ir.Node {
	ncas := len(cases)
	// Remember the select's own position so generated setup and the
	// selectgo call can be attributed to it, not to the last case visited.
	sellineno := base.Pos

	// optimization: zero-case select
	if ncas == 0 {
		return []ir.Node{mkcallstmt("block")}
	}

	// optimization: one-case select: single op.
	if ncas == 1 {
		cas := cases[0]
		ir.SetPos(cas)
		l := cas.Init()
		if cas.Comm != nil { // not default:
			n := cas.Comm
			l = append(l, ir.TakeInit(n)...)
			switch n.Op() {
			default:
				base.Fatalf("select %v", n.Op())

			case ir.OSEND:
				// already ok

			case ir.OSELRECV2:
				r := n.(*ir.AssignListStmt)
				// Both destinations blank: emit just the receive
				// expression. This `break` exits the switch, not a loop.
				if ir.IsBlank(r.Lhs[0]) && ir.IsBlank(r.Lhs[1]) {
					n = r.Rhs[0]
					break
				}
				// Otherwise turn it into a plain "v, ok = <-c" assignment.
				r.SetOp(ir.OAS2RECV)
			}

			l = append(l, n)
		}

		l = append(l, cas.Body...)
		l = append(l, ir.NewBranchStmt(base.Pos, ir.OBREAK, nil))
		return l
	}

	// convert case value arguments to addresses.
	// this rewrite is used by both the general code and the next optimization.
	// Sends get &value. Receives get &dest unless dest is blank.
	// This loop also locates the default clause, if any.
	var dflt *ir.CommClause
	for _, cas := range cases {
		ir.SetPos(cas)
		n := cas.Comm
		if n == nil {
			dflt = cas
			continue
		}
		switch n.Op() {
		case ir.OSEND:
			n := n.(*ir.SendStmt)
			n.Value = typecheck.NodAddr(n.Value)
			n.Value = typecheck.Expr(n.Value)

		case ir.OSELRECV2:
			n := n.(*ir.AssignListStmt)
			if !ir.IsBlank(n.Lhs[0]) {
				n.Lhs[0] = typecheck.NodAddr(n.Lhs[0])
				n.Lhs[0] = typecheck.Expr(n.Lhs[0])
			}
		}
	}

	// optimization: two-case select but one is default: single non-blocking op.
	if ncas == 2 && dflt != nil {
		// Pick the non-default clause.
		cas := cases[0]
		if cas == dflt {
			cas = cases[1]
		}

		n := cas.Comm
		ir.SetPos(n)
		r := ir.NewIfStmt(base.Pos, nil, nil, nil)
		r.SetInit(cas.Init())
		var cond ir.Node
		switch n.Op() {
		default:
			base.Fatalf("select %v", n.Op())

		case ir.OSEND:
			// if selectnbsend(c, v) { body } else { default body }
			n := n.(*ir.SendStmt)
			ch := n.Chan
			// n.Value is already &v (see the address rewrite above).
			cond = mkcall1(chanfn("selectnbsend", 2, ch.Type()), types.Types[types.TBOOL], r.PtrInit(), ch, n.Value)

		case ir.OSELRECV2:
			// tmp, ok = selectnbrecv(&v or nil, c)
			// if tmp { body } else { default body }
			n := n.(*ir.AssignListStmt)
			recv := n.Rhs[0].(*ir.UnaryExpr)
			ch := recv.X
			elem := n.Lhs[0]
			if ir.IsBlank(elem) {
				elem = typecheck.NodNil()
			}
			cond = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
			fn := chanfn("selectnbrecv", 2, ch.Type())
			call := mkcall1(fn, fn.Type().ResultsTuple(), r.PtrInit(), elem, ch)
			as := ir.NewAssignListStmt(r.Pos(), ir.OAS2, []ir.Node{cond, n.Lhs[1]}, []ir.Node{call})
			r.PtrInit().Append(typecheck.Stmt(as))
		}

		r.Cond = typecheck.Expr(cond)
		r.Body = cas.Body
		// The default clause's init statements run only on the else path.
		r.Else = append(dflt.Init(), dflt.Body...)
		return []ir.Node{r, ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)}
	}

	// From here on, ncas counts only the non-default clauses, i.e. the
	// number of scase slots needed.
	if dflt != nil {
		ncas--
	}
	// casorder[i] is the clause registered in scase slot i.
	casorder := make([]*ir.CommClause, ncas)
	nsends, nrecvs := 0, 0

	var init []ir.Node

	// generate sel-struct
	base.Pos = sellineno
	// selv is a zero-initialized [ncas]scase temporary.
	selv := typecheck.TempAt(base.Pos, ir.CurFunc, types.NewArray(scasetype(), int64(ncas)))
	init = append(init, typecheck.Stmt(ir.NewAssignStmt(base.Pos, selv, nil)))

	// No initialization for order; runtime.selectgo is responsible for that.
	order := typecheck.TempAt(base.Pos, ir.CurFunc, types.NewArray(types.Types[types.TUINT16], 2*int64(ncas)))

	// Under -race, allocate a [ncas]uintptr array and pass &pcs[0] to
	// selectgo. Otherwise pass nil.
	var pc0, pcs ir.Node
	if base.Flag.Race {
		pcs = typecheck.TempAt(base.Pos, ir.CurFunc, types.NewArray(types.Types[types.TUINTPTR], int64(ncas)))
		pc0 = typecheck.Expr(typecheck.NodAddr(ir.NewIndexExpr(base.Pos, pcs, ir.NewInt(base.Pos, 0))))
	} else {
		pc0 = typecheck.NodNil()
	}

	// register cases
	for _, cas := range cases {
		ir.SetPos(cas)

		// Every clause's init (including default's) is hoisted before
		// the selectgo call.
		init = append(init, ir.TakeInit(cas)...)

		n := cas.Comm
		if n == nil { // default:
			continue
		}

		// i is this clause's scase slot. Sends take slots 0, 1, 2, ...
		// in source order. Receives take ncas-1, ncas-2, ... from the end.
		var i int
		var c, elem ir.Node
		switch n.Op() {
		default:
			base.Fatalf("select %v", n.Op())
		case ir.OSEND:
			n := n.(*ir.SendStmt)
			i = nsends
			nsends++
			c = n.Chan
			elem = n.Value
		case ir.OSELRECV2:
			n := n.(*ir.AssignListStmt)
			nrecvs++
			i = ncas - nrecvs
			recv := n.Rhs[0].(*ir.UnaryExpr)
			c = recv.X
			elem = n.Lhs[0]
		}

		casorder[i] = cas

		// setField emits "selv[i].f = val" into init.
		setField := func(f string, val ir.Node) {
			r := ir.NewAssignStmt(base.Pos, ir.NewSelectorExpr(base.Pos, ir.ODOT, ir.NewIndexExpr(base.Pos, selv, ir.NewInt(base.Pos, int64(i))), typecheck.Lookup(f)), val)
			init = append(init, typecheck.Stmt(r))
		}

		c = typecheck.ConvNop(c, types.Types[types.TUNSAFEPTR])
		setField("c", c)
		// A blank receive destination leaves elem as its zero value
		// (selv was zero-initialized above).
		if !ir.IsBlank(elem) {
			elem = typecheck.ConvNop(elem, types.Types[types.TUNSAFEPTR])
			setField("elem", elem)
		}

		// TODO(mdempsky): There should be a cleaner way to
		// handle this.
		if base.Flag.Race {
			r := mkcallstmt("selectsetpc", typecheck.NodAddr(ir.NewIndexExpr(base.Pos, pcs, ir.NewInt(base.Pos, int64(i)))))
			init = append(init, r)
		}
	}
	// Sanity check: sends and receives must exactly fill the scase slots.
	if nsends+nrecvs != ncas {
		base.Fatalf("walkSelectCases: miscount: %v + %v != %v", nsends, nrecvs, ncas)
	}

	// run the select
	// chosen, recvOK = selectgo(&selv[0], &order[0], pc0, nsends, nrecvs, dflt == nil)
	// The last argument is true only when there is no default clause
	// (presumably selectgo's "block" flag; confirm against runtime/select.go).
	base.Pos = sellineno
	chosen := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TINT])
	recvOK := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
	r := ir.NewAssignListStmt(base.Pos, ir.OAS2, nil, nil)
	r.Lhs = []ir.Node{chosen, recvOK}
	fn := typecheck.LookupRuntime("selectgo")
	var fnInit ir.Nodes
	r.Rhs = []ir.Node{mkcall1(fn, fn.Type().ResultsTuple(), &fnInit, bytePtrToIndex(selv, 0), bytePtrToIndex(order, 0), pc0, ir.NewInt(base.Pos, int64(nsends)), ir.NewInt(base.Pos, int64(nrecvs)), ir.NewBool(base.Pos, dflt == nil))}
	init = append(init, fnInit...)
	init = append(init, typecheck.Stmt(r))

	// selv, order, and pcs (if race) are no longer alive after selectgo.

	// dispatch cases
	// dispatch appends one clause's lowered body to init. If cond is
	// non-nil, the body is guarded by "if cond". Otherwise it is emitted as
	// an unconditional block. Receive clauses with a non-blank ok
	// destination first get "ok = recvOK".
	dispatch := func(cond ir.Node, cas *ir.CommClause) {
		var list ir.Nodes

		if n := cas.Comm; n != nil && n.Op() == ir.OSELRECV2 {
			n := n.(*ir.AssignListStmt)
			if !ir.IsBlank(n.Lhs[1]) {
				x := ir.NewAssignStmt(base.Pos, n.Lhs[1], recvOK)
				list.Append(typecheck.Stmt(x))
			}
		}

		list.Append(cas.Body.Take()...)
		list.Append(ir.NewBranchStmt(base.Pos, ir.OBREAK, nil))

		var r ir.Node
		if cond != nil {
			cond = typecheck.Expr(cond)
			cond = typecheck.DefaultLit(cond, nil)
			r = ir.NewIfStmt(base.Pos, cond, list, nil)
		} else {
			r = ir.NewBlockStmt(base.Pos, list)
		}

		init = append(init, r)
	}

	// The default clause is taken when selectgo returns a negative index.
	if dflt != nil {
		ir.SetPos(dflt)
		dispatch(ir.NewBinaryExpr(base.Pos, ir.OLT, chosen, ir.NewInt(base.Pos, 0)), dflt)
	}
	// Each registered clause is tested as "chosen == slot". The last one is
	// emitted without a test, which assumes selectgo only returns a
	// registered slot index or (with a default) a negative value.
	for i, cas := range casorder {
		ir.SetPos(cas)
		if i == len(casorder)-1 {
			dispatch(nil, cas)
			break
		}
		dispatch(ir.NewBinaryExpr(base.Pos, ir.OEQ, chosen, ir.NewInt(base.Pos, int64(i))), cas)
	}

	return init
}
   278  
   279  // bytePtrToIndex returns a Node representing "(*byte)(&n[i])".
   280  func bytePtrToIndex(n ir.Node, i int64) ir.Node {
   281  	s := typecheck.NodAddr(ir.NewIndexExpr(base.Pos, n, ir.NewInt(base.Pos, i)))
   282  	t := types.NewPtr(types.Types[types.TUINT8])
   283  	return typecheck.ConvNop(s, t)
   284  }
   285  
   286  var scase *types.Type
   287  
   288  // Keep in sync with src/runtime/select.go.
   289  func scasetype() *types.Type {
   290  	if scase == nil {
   291  		n := ir.NewDeclNameAt(src.NoXPos, ir.OTYPE, ir.Pkgs.Runtime.Lookup("scase"))
   292  		scase = types.NewNamed(n)
   293  		n.SetType(scase)
   294  		n.SetTypecheck(1)
   295  
   296  		scase.SetUnderlying(types.NewStruct([]*types.Field{
   297  			types.NewField(base.Pos, typecheck.Lookup("c"), types.Types[types.TUNSAFEPTR]),
   298  			types.NewField(base.Pos, typecheck.Lookup("elem"), types.Types[types.TUNSAFEPTR]),
   299  		}))
   300  	}
   301  	return scase
   302  }