github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/compare/compare.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package compare contains code for generating comparison
     6  // routines for structs, strings and interfaces.
     7  package compare
     8  
     9  import (
    10  	"fmt"
    11  	"math/bits"
    12  	"sort"
    13  
    14  	"github.com/go-asm/go/cmd/compile/base"
    15  	"github.com/go-asm/go/cmd/compile/ir"
    16  	"github.com/go-asm/go/cmd/compile/typecheck"
    17  	"github.com/go-asm/go/cmd/compile/types"
    18  )
    19  
    20  // IsRegularMemory reports whether t can be compared/hashed as regular memory.
    21  func IsRegularMemory(t *types.Type) bool {
    22  	a, _ := types.AlgType(t)
    23  	return a == types.AMEM
    24  }
    25  
    26  // Memrun finds runs of struct fields for which memory-only algs are appropriate.
    27  // t is the parent struct type, and start is the field index at which to start the run.
    28  // size is the length in bytes of the memory included in the run.
    29  // next is the index just after the end of the memory run.
    30  func Memrun(t *types.Type, start int) (size int64, next int) {
    31  	next = start
    32  	for {
    33  		next++
    34  		if next == t.NumFields() {
    35  			break
    36  		}
    37  		// Stop run after a padded field.
    38  		if types.IsPaddedField(t, next-1) {
    39  			break
    40  		}
    41  		// Also, stop before a blank or non-memory field.
    42  		if f := t.Field(next); f.Sym.IsBlank() || !IsRegularMemory(f.Type) {
    43  			break
    44  		}
    45  		// For issue 46283, don't combine fields if the resulting load would
    46  		// require a larger alignment than the component fields.
    47  		if base.Ctxt.Arch.Alignment > 1 {
    48  			align := t.Alignment()
    49  			if off := t.Field(start).Offset; off&(align-1) != 0 {
    50  				// Offset is less aligned than the containing type.
    51  				// Use offset to determine alignment.
    52  				align = 1 << uint(bits.TrailingZeros64(uint64(off)))
    53  			}
    54  			size := t.Field(next).End() - t.Field(start).Offset
    55  			if size > align {
    56  				break
    57  			}
    58  		}
    59  	}
    60  	return t.Field(next-1).End() - t.Field(start).Offset, next
    61  }
    62  
    63  // EqCanPanic reports whether == on type t could panic (has an interface somewhere).
    64  // t must be comparable.
    65  func EqCanPanic(t *types.Type) bool {
    66  	switch t.Kind() {
    67  	default:
    68  		return false
    69  	case types.TINTER:
    70  		return true
    71  	case types.TARRAY:
    72  		return EqCanPanic(t.Elem())
    73  	case types.TSTRUCT:
    74  		for _, f := range t.Fields() {
    75  			if !f.Sym.IsBlank() && EqCanPanic(f.Type) {
    76  				return true
    77  			}
    78  		}
    79  		return false
    80  	}
    81  }
    82  
    83  // EqStructCost returns the cost of an equality comparison of two structs.
    84  //
    85  // The cost is determined using an algorithm which takes into consideration
    86  // the size of the registers in the current architecture and the size of the
    87  // memory-only fields in the struct.
    88  func EqStructCost(t *types.Type) int64 {
    89  	cost := int64(0)
    90  
    91  	for i, fields := 0, t.Fields(); i < len(fields); {
    92  		f := fields[i]
    93  
    94  		// Skip blank-named fields.
    95  		if f.Sym.IsBlank() {
    96  			i++
    97  			continue
    98  		}
    99  
   100  		n, _, next := eqStructFieldCost(t, i)
   101  
   102  		cost += n
   103  		i = next
   104  	}
   105  
   106  	return cost
   107  }
   108  
   109  // eqStructFieldCost returns the cost of an equality comparison of two struct fields.
   110  // t is the parent struct type, and i is the index of the field in the parent struct type.
   111  // eqStructFieldCost may compute the cost of several adjacent fields at once. It returns
   112  // the cost, the size of the set of fields it computed the cost for (in bytes), and the
   113  // index of the first field not part of the set of fields for which the cost
   114  // has already been calculated.
   115  func eqStructFieldCost(t *types.Type, i int) (int64, int64, int) {
   116  	var (
   117  		cost    = int64(0)
   118  		regSize = int64(types.RegSize)
   119  
   120  		size int64
   121  		next int
   122  	)
   123  
   124  	if base.Ctxt.Arch.CanMergeLoads {
   125  		// If we can merge adjacent loads then we can calculate the cost of the
   126  		// comparison using the size of the memory run and the size of the registers.
   127  		size, next = Memrun(t, i)
   128  		cost = size / regSize
   129  		if size%regSize != 0 {
   130  			cost++
   131  		}
   132  		return cost, size, next
   133  	}
   134  
   135  	// If we cannot merge adjacent loads then we have to use the size of the
   136  	// field and take into account the type to determine how many loads and compares
   137  	// are needed.
   138  	ft := t.Field(i).Type
   139  	size = ft.Size()
   140  	next = i + 1
   141  
   142  	return calculateCostForType(ft), size, next
   143  }
   144  
   145  func calculateCostForType(t *types.Type) int64 {
   146  	var cost int64
   147  	switch t.Kind() {
   148  	case types.TSTRUCT:
   149  		return EqStructCost(t)
   150  	case types.TSLICE:
   151  		// Slices are not comparable.
   152  		base.Fatalf("eqStructFieldCost: unexpected slice type")
   153  	case types.TARRAY:
   154  		elemCost := calculateCostForType(t.Elem())
   155  		cost = t.NumElem() * elemCost
   156  	case types.TSTRING, types.TINTER, types.TCOMPLEX64, types.TCOMPLEX128:
   157  		cost = 2
   158  	case types.TINT64, types.TUINT64:
   159  		cost = 8 / int64(types.RegSize)
   160  	default:
   161  		cost = 1
   162  	}
   163  	return cost
   164  }
   165  
   166  // EqStruct compares two structs np and nq for equality.
   167  // It works by building a list of boolean conditions to satisfy.
   168  // Conditions must be evaluated in the returned order and
   169  // properly short-circuited by the caller.
   170  // The first return value is the flattened list of conditions,
   171  // the second value is a boolean indicating whether any of the
   172  // comparisons could panic.
   173  func EqStruct(t *types.Type, np, nq ir.Node) ([]ir.Node, bool) {
   174  	// The conditions are a list-of-lists. Conditions are reorderable
   175  	// within each inner list. The outer lists must be evaluated in order.
   176  	var conds [][]ir.Node
   177  	conds = append(conds, []ir.Node{})
   178  	and := func(n ir.Node) {
   179  		i := len(conds) - 1
   180  		conds[i] = append(conds[i], n)
   181  	}
   182  
   183  	// Walk the struct using memequal for runs of AMEM
   184  	// and calling specific equality tests for the others.
   185  	for i, fields := 0, t.Fields(); i < len(fields); {
   186  		f := fields[i]
   187  
   188  		// Skip blank-named fields.
   189  		if f.Sym.IsBlank() {
   190  			i++
   191  			continue
   192  		}
   193  
   194  		typeCanPanic := EqCanPanic(f.Type)
   195  
   196  		// Compare non-memory fields with field equality.
   197  		if !IsRegularMemory(f.Type) {
   198  			if typeCanPanic {
   199  				// Enforce ordering by starting a new set of reorderable conditions.
   200  				conds = append(conds, []ir.Node{})
   201  			}
   202  			switch {
   203  			case f.Type.IsString():
   204  				p := typecheck.DotField(base.Pos, typecheck.Expr(np), i)
   205  				q := typecheck.DotField(base.Pos, typecheck.Expr(nq), i)
   206  				eqlen, eqmem := EqString(p, q)
   207  				and(eqlen)
   208  				and(eqmem)
   209  			default:
   210  				and(eqfield(np, nq, i))
   211  			}
   212  			if typeCanPanic {
   213  				// Also enforce ordering after something that can panic.
   214  				conds = append(conds, []ir.Node{})
   215  			}
   216  			i++
   217  			continue
   218  		}
   219  
   220  		cost, size, next := eqStructFieldCost(t, i)
   221  		if cost <= 4 {
   222  			// Cost of 4 or less: use plain field equality.
   223  			for j := i; j < next; j++ {
   224  				and(eqfield(np, nq, j))
   225  			}
   226  		} else {
   227  			// Higher cost: use memequal.
   228  			cc := eqmem(np, nq, i, size)
   229  			and(cc)
   230  		}
   231  		i = next
   232  	}
   233  
   234  	// Sort conditions to put runtime calls last.
   235  	// Preserve the rest of the ordering.
   236  	var flatConds []ir.Node
   237  	for _, c := range conds {
   238  		isCall := func(n ir.Node) bool {
   239  			return n.Op() == ir.OCALL || n.Op() == ir.OCALLFUNC
   240  		}
   241  		sort.SliceStable(c, func(i, j int) bool {
   242  			return !isCall(c[i]) && isCall(c[j])
   243  		})
   244  		flatConds = append(flatConds, c...)
   245  	}
   246  	return flatConds, len(conds) > 1
   247  }
   248  
   249  // EqString returns the nodes
   250  //
   251  //	len(s) == len(t)
   252  //
   253  // and
   254  //
   255  //	memequal(s.ptr, t.ptr, len(s))
   256  //
   257  // which can be used to construct string equality comparison.
   258  // eqlen must be evaluated before eqmem, and shortcircuiting is required.
   259  func EqString(s, t ir.Node) (eqlen *ir.BinaryExpr, eqmem *ir.CallExpr) {
   260  	s = typecheck.Conv(s, types.Types[types.TSTRING])
   261  	t = typecheck.Conv(t, types.Types[types.TSTRING])
   262  	sptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, s)
   263  	tptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, t)
   264  	slen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, s), types.Types[types.TUINTPTR])
   265  	tlen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, t), types.Types[types.TUINTPTR])
   266  
   267  	// Pick the 3rd arg to memequal. Both slen and tlen are fine to use, because we short
   268  	// circuit the memequal call if they aren't the same. But if one is a constant some
   269  	// memequal optimizations are easier to apply.
   270  	probablyConstant := func(n ir.Node) bool {
   271  		if n.Op() == ir.OCONVNOP {
   272  			n = n.(*ir.ConvExpr).X
   273  		}
   274  		if n.Op() == ir.OLITERAL {
   275  			return true
   276  		}
   277  		if n.Op() != ir.ONAME {
   278  			return false
   279  		}
   280  		name := n.(*ir.Name)
   281  		if name.Class != ir.PAUTO {
   282  			return false
   283  		}
   284  		if def := name.Defn; def == nil {
   285  			// n starts out as the empty string
   286  			return true
   287  		} else if def.Op() == ir.OAS && (def.(*ir.AssignStmt).Y == nil || def.(*ir.AssignStmt).Y.Op() == ir.OLITERAL) {
   288  			// n starts out as a constant string
   289  			return true
   290  		}
   291  		return false
   292  	}
   293  	cmplen := slen
   294  	if probablyConstant(t) && !probablyConstant(s) {
   295  		cmplen = tlen
   296  	}
   297  
   298  	fn := typecheck.LookupRuntime("memequal", types.Types[types.TUINT8], types.Types[types.TUINT8])
   299  	call := typecheck.Call(base.Pos, fn, []ir.Node{sptr, tptr, ir.Copy(cmplen)}, false).(*ir.CallExpr)
   300  
   301  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, slen, tlen)
   302  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   303  	cmp.SetType(types.Types[types.TBOOL])
   304  	return cmp, call
   305  }
   306  
   307  // EqInterface returns the nodes
   308  //
   309  //	s.tab == t.tab (or s.typ == t.typ, as appropriate)
   310  //
   311  // and
   312  //
   313  //	ifaceeq(s.tab, s.data, t.data) (or efaceeq(s.typ, s.data, t.data), as appropriate)
   314  //
   315  // which can be used to construct interface equality comparison.
   316  // eqtab must be evaluated before eqdata, and shortcircuiting is required.
   317  func EqInterface(s, t ir.Node) (eqtab *ir.BinaryExpr, eqdata *ir.CallExpr) {
   318  	if !types.Identical(s.Type(), t.Type()) {
   319  		base.Fatalf("EqInterface %v %v", s.Type(), t.Type())
   320  	}
   321  	// func ifaceeq(tab *uintptr, x, y unsafe.Pointer) (ret bool)
   322  	// func efaceeq(typ *uintptr, x, y unsafe.Pointer) (ret bool)
   323  	var fn ir.Node
   324  	if s.Type().IsEmptyInterface() {
   325  		fn = typecheck.LookupRuntime("efaceeq")
   326  	} else {
   327  		fn = typecheck.LookupRuntime("ifaceeq")
   328  	}
   329  
   330  	stab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s)
   331  	ttab := ir.NewUnaryExpr(base.Pos, ir.OITAB, t)
   332  	sdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s)
   333  	tdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, t)
   334  	sdata.SetType(types.Types[types.TUNSAFEPTR])
   335  	tdata.SetType(types.Types[types.TUNSAFEPTR])
   336  	sdata.SetTypecheck(1)
   337  	tdata.SetTypecheck(1)
   338  
   339  	call := typecheck.Call(base.Pos, fn, []ir.Node{stab, sdata, tdata}, false).(*ir.CallExpr)
   340  
   341  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, stab, ttab)
   342  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   343  	cmp.SetType(types.Types[types.TBOOL])
   344  	return cmp, call
   345  }
   346  
   347  // eqfield returns the node
   348  //
   349  //	p.field == q.field
   350  func eqfield(p, q ir.Node, field int) ir.Node {
   351  	nx := typecheck.DotField(base.Pos, typecheck.Expr(p), field)
   352  	ny := typecheck.DotField(base.Pos, typecheck.Expr(q), field)
   353  	return typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, nx, ny))
   354  }
   355  
   356  // eqmem returns the node
   357  //
   358  //	memequal(&p.field, &q.field, size)
   359  func eqmem(p, q ir.Node, field int, size int64) ir.Node {
   360  	nx := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, p, field)))
   361  	ny := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, q, field)))
   362  
   363  	fn, needsize := eqmemfunc(size, nx.Type().Elem())
   364  	call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
   365  	call.Args.Append(nx)
   366  	call.Args.Append(ny)
   367  	if needsize {
   368  		call.Args.Append(ir.NewInt(base.Pos, size))
   369  	}
   370  
   371  	return call
   372  }
   373  
   374  func eqmemfunc(size int64, t *types.Type) (fn *ir.Name, needsize bool) {
   375  	switch size {
   376  	case 1, 2, 4, 8, 16:
   377  		buf := fmt.Sprintf("memequal%d", int(size)*8)
   378  		return typecheck.LookupRuntime(buf, t, t), false
   379  	}
   380  
   381  	return typecheck.LookupRuntime("memequal", t, t), true
   382  }