golang.org/x/arch@v0.17.0/internal/unify/env.go (about)

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package unify
     6  
     7  import (
     8  	"fmt"
     9  	"iter"
    10  	"reflect"
    11  	"slices"
    12  	"strings"
    13  )
    14  
    15  // A nonDetEnv is a non-deterministic mapping from [ident]s to [Value]s.
    16  //
    17  // Logically, this is just a set of deterministic environments, where each
    18  // deterministic environment is a complete mapping from each [ident]s to exactly
    19  // one [Value]. In particular, [ident]s are NOT necessarily independent of each
    20  // other. For example, an environment may have both {x: 1, y: 1} and {x: 2, y:
    21  // 2}, but not {x: 1, y: 2}.
    22  //
    23  // A nonDetEnv is immutable.
    24  //
    25  // Often [ident]s are independent of each other, so the representation optimizes
    26  // for this by using a cross-product of environment factors, where each factor
    27  // is a sum of deterministic environments. These operations obey the usual
    28  // distributional laws, so we can always canonicalize into this form. (It MAY be
    29  // worthwhile to allow more general expressions of sums and products.)
    30  //
    31  // For example, to represent {{x: 1, y: 1}, {x: 2, y: 2}}, in which the
    32  // variables x and y are dependent, we need a single factor that covers x and y
    33  // and consists of two terms: {x: 1, y: 1} + {x: 2, y: 2}.
    34  //
    35  // If we add a third variable z that can be 1 or 2, independent of x and y, we
    36  // get four logical environments:
    37  //
    38  //	{x: 1, y: 1, z: 1}
    39  //	{x: 2, y: 2, z: 1}
    40  //	{x: 1, y: 1, z: 2}
    41  //	{x: 2, y: 2, z: 2}
    42  //
    43  // This could be represented as a single factor that is the sum of these four
    44  // detEnvs, but because z is independent, it can be a separate factor. Hence,
    45  // the most compact representation of this environment is:
    46  //
    47  //	({x: 1, y: 1} + {x: 2, y: 2}) тип ({z: 1} + {z: 2})
    48  //
    49  // That is, two factors, where each is the sum of two terms.
    50  type nonDetEnv struct {
    51  	// factors is a list of the multiplicative factors in this environment. The
    52  	// set of deterministic environments is the cross-product of these factors.
    53  	// All factors must have disjoint variables.
    54  	factors []*envSum
    55  }
    56  
    57  // envSum is a sum of deterministic environments, all with the same set of
    58  // variables.
    59  type envSum struct {
    60  	ids   []*ident // TODO: Do we ever use this as a slice? Should it be a map?
    61  	terms []detEnv
    62  }
    63  
    64  type detEnv struct {
    65  	vals []*Value // Indexes correspond to envSum.ids
    66  }
    67  
    68  var (
    69  	// zeroEnvFactor is the "0" value of an [envSum]. It's a a factor with no
    70  	// sum terms. This is easiest to think of as: an empty sum must be the
    71  	// additive identity, 0.
    72  	zeroEnvFactor = &envSum{}
    73  
    74  	// topEnv is the algebraic one value of a [nonDetEnv]. It has no factors
    75  	// because the product of no factors is the multiplicative identity.
    76  	topEnv = nonDetEnv{}
    77  	// bottomEnv is the algebraic zero value of a [nonDetEnv]. The product of
    78  	// bottomEnv with x is bottomEnv, and the sum of bottomEnv with y is y.
    79  	bottomEnv = nonDetEnv{factors: []*envSum{zeroEnvFactor}}
    80  )
    81  
    82  // bind binds id to each of vals in e.
    83  //
    84  // Its panics if id is already bound in e.
    85  //
    86  // Environments are typically initially constructed by starting with [topEnv]
    87  // and calling bind one or more times.
    88  func (e nonDetEnv) bind(id *ident, vals ...*Value) nonDetEnv {
    89  	if e.isBottom() {
    90  		return bottomEnv
    91  	}
    92  
    93  	// TODO: If any of vals are _, should we just not do anything? We're kind of
    94  	// inconsistent about whether an id missing from e means id is invalid or
    95  	// means id is _.
    96  
    97  	// Check that id isn't present in e.
    98  	for _, f := range e.factors {
    99  		if slices.Contains(f.ids, id) {
   100  			panic("id " + id.name + " already present in environment")
   101  		}
   102  	}
   103  
   104  	// Create the new sum term.
   105  	sum := &envSum{ids: []*ident{id}}
   106  	for _, val := range vals {
   107  		sum.terms = append(sum.terms, detEnv{vals: []*Value{val}})
   108  	}
   109  	// Multiply it in.
   110  	factors := append(e.factors[:len(e.factors):len(e.factors)], sum)
   111  	return nonDetEnv{factors}
   112  }
   113  
   114  func (e nonDetEnv) isBottom() bool {
   115  	if len(e.factors) == 0 {
   116  		// This is top.
   117  		return false
   118  	}
   119  	return len(e.factors[0].terms) == 0
   120  }
   121  
   122  func (e nonDetEnv) vars() iter.Seq[*ident] {
   123  	return func(yield func(*ident) bool) {
   124  		for _, t := range e.factors {
   125  			for _, id := range t.ids {
   126  				if !yield(id) {
   127  					return
   128  				}
   129  			}
   130  		}
   131  	}
   132  }
   133  
   134  // all enumerates all deterministic environments in e.
   135  //
   136  // The result slice is in the same order as the slice returned by
   137  // [nonDetEnv2.vars]. The slice is reused between iterations.
   138  func (e nonDetEnv) all() iter.Seq[[]*Value] {
   139  	return func(yield func([]*Value) bool) {
   140  		var vals []*Value
   141  		var walk func(int) bool
   142  		walk = func(i int) bool {
   143  			if i == len(e.factors) {
   144  				return yield(vals)
   145  			}
   146  			start := len(vals)
   147  			for _, term := range e.factors[i].terms {
   148  				vals = append(vals[:start], term.vals...)
   149  				if !walk(i + 1) {
   150  					return false
   151  				}
   152  			}
   153  			return true
   154  		}
   155  		walk(0)
   156  	}
   157  }
   158  
   159  // allOrdered is like all, but idOrder controls the order of the values in the
   160  // resulting slice. Any [ident]s in idOrder that are missing from e are set to
   161  // topValue. The values of idOrder must be a bijection with [0, n).
   162  func (e nonDetEnv) allOrdered(idOrder map[*ident]int) iter.Seq[[]*Value] {
   163  	valsLen := 0
   164  	for _, idx := range idOrder {
   165  		valsLen = max(valsLen, idx+1)
   166  	}
   167  
   168  	return func(yield func([]*Value) bool) {
   169  		vals := make([]*Value, valsLen)
   170  		// e may not have all of the IDs in idOrder. Make sure any missing
   171  		// values are top.
   172  		for i := range vals {
   173  			vals[i] = topValue
   174  		}
   175  		var walk func(int) bool
   176  		walk = func(i int) bool {
   177  			if i == len(e.factors) {
   178  				return yield(vals)
   179  			}
   180  			for _, term := range e.factors[i].terms {
   181  				for j, id := range e.factors[i].ids {
   182  					vals[idOrder[id]] = term.vals[j]
   183  				}
   184  				if !walk(i + 1) {
   185  					return false
   186  				}
   187  			}
   188  			return true
   189  		}
   190  		walk(0)
   191  	}
   192  }
   193  
   194  func crossEnvs(envs ...nonDetEnv) nonDetEnv {
   195  	// Combine the factors of envs
   196  	var factors []*envSum
   197  	haveIDs := map[*ident]struct{}{}
   198  	for _, e := range envs {
   199  		if e.isBottom() {
   200  			// The environment is bottom, so the whole product goes to
   201  			// bottom.
   202  			return bottomEnv
   203  		}
   204  		// Check that all ids are disjoint.
   205  		for _, f := range e.factors {
   206  			for _, id := range f.ids {
   207  				if _, ok := haveIDs[id]; ok {
   208  					panic("conflict on " + id.name)
   209  				}
   210  				haveIDs[id] = struct{}{}
   211  			}
   212  		}
   213  		// Everything checks out. Multiply the factors.
   214  		factors = append(factors, e.factors...)
   215  	}
   216  	return nonDetEnv{factors: factors}
   217  }
   218  
   219  func sumEnvs(envs ...nonDetEnv) nonDetEnv {
   220  	// nonDetEnv is a product at the top level, so we implement summation using
   221  	// the distributive law. We also use associativity to keep as many top-level
   222  	// factors as we can, since those are what keep the environment compact.
   223  	//
   224  	// a * b * c + a * d         (where a, b, c, and d are factors)
   225  	//                           (combine common factors)
   226  	//   = a * (b * c + d)
   227  	//                           (expand factors into their sum terms)
   228  	//   = a * ((b_1 + b_2 + ...) * (c_1 + c_2 + ...) + d)
   229  	//                           (where b_i and c_i are deterministic environments)
   230  	//                           (FOIL)
   231  	//   = a * (b_1 * c_1 + b_1 * c_2 + b_2 * c_1 + b_2 * c2 + d)
   232  	//                           (all factors are now in canonical form)
   233  	//   = a * e
   234  	//
   235  	// The product of two deterministic environments is a deterministic
   236  	// environment, and the sum of deterministic environments is a factor, so
   237  	// this process results in the canonical product-of-sums form.
   238  	//
   239  	// TODO: This is a bit of a one-way process. We could try to factor the
   240  	// environment to reduce the number of sums. I'm not sure how to do this
   241  	// efficiently. It might be possible to guide it by gathering the
   242  	// distributions of each ID's bindings. E.g., if there are 12 deterministic
   243  	// environments in a sum and $x is bound to 4 different values, each 3
   244  	// times, then it *might* be possible to factor out $x into a 4-way sum of
   245  	// its own.
   246  
   247  	factors, toSum := commonFactors(envs)
   248  
   249  	if len(toSum) > 0 {
   250  		// Collect all IDs into a single order.
   251  		var ids []*ident
   252  		idOrder := make(map[*ident]int)
   253  		for _, e := range toSum {
   254  			for v := range e.vars() {
   255  				if _, ok := idOrder[v]; !ok {
   256  					idOrder[v] = len(ids)
   257  					ids = append(ids, v)
   258  				}
   259  			}
   260  		}
   261  
   262  		// Flatten out each term in the sum.
   263  		var summands []detEnv
   264  		for _, env := range toSum {
   265  			for vals := range env.allOrdered(idOrder) {
   266  				summands = append(summands, detEnv{vals: slices.Clone(vals)})
   267  			}
   268  		}
   269  		factors = append(factors, &envSum{ids: ids, terms: summands})
   270  	}
   271  
   272  	return nonDetEnv{factors: factors}
   273  }
   274  
   275  // commonFactors finds common factors that can be factored out of a summation of
   276  // [nonDetEnv]s.
   277  func commonFactors(envs []nonDetEnv) (common []*envSum, toSum []nonDetEnv) {
   278  	// Drop any bottom environments. They don't contribute to the sum and they
   279  	// would complicate some logic below.
   280  	envs = slices.DeleteFunc(envs, func(e nonDetEnv) bool {
   281  		return e.isBottom()
   282  	})
   283  	if len(envs) == 0 {
   284  		return bottomEnv.factors, nil
   285  	}
   286  
   287  	// It's very common that the exact same factor will appear across all envs.
   288  	// Keep those factored out.
   289  	//
   290  	// TODO: Is it also common to have vars that are bound to the same value
   291  	// across all envs? If so, we could also factor those into common terms.
   292  	counts := map[*envSum]int{}
   293  	for _, e := range envs {
   294  		for _, f := range e.factors {
   295  			counts[f]++
   296  		}
   297  	}
   298  	for _, f := range envs[0].factors {
   299  		if counts[f] == len(envs) {
   300  			// Common factor
   301  			common = append(common, f)
   302  		}
   303  	}
   304  
   305  	// Any other factors need to be multiplied out.
   306  	for _, env := range envs {
   307  		var newFactors []*envSum
   308  		for _, f := range env.factors {
   309  			if counts[f] != len(envs) {
   310  				newFactors = append(newFactors, f)
   311  			}
   312  		}
   313  		if len(newFactors) > 0 {
   314  			toSum = append(toSum, nonDetEnv{factors: newFactors})
   315  		}
   316  	}
   317  
   318  	return common, toSum
   319  }
   320  
   321  // envPartition is a subset of an env where id is bound to value in all
   322  // deterministic environments.
   323  type envPartition struct {
   324  	id    *ident
   325  	value *Value
   326  	env   nonDetEnv
   327  }
   328  
   329  func (e nonDetEnv) partitionBy(id *ident) []envPartition {
   330  	if e.isBottom() {
   331  		// Bottom contains all variables
   332  		return []envPartition{{id: id, value: bottomValue, env: e}}
   333  	}
   334  
   335  	// Find the factor containing id and id's index in that factor.
   336  	idFactor, idIndex := -1, -1
   337  	var newIDs []*ident
   338  	for factI, fact := range e.factors {
   339  		idI := slices.Index(fact.ids, id)
   340  		if idI < 0 {
   341  			continue
   342  		} else if idFactor != -1 {
   343  			panic("multiple factors containing id " + id.name)
   344  		} else {
   345  			idFactor, idIndex = factI, idI
   346  			// Drop id from this factor's IDs
   347  			newIDs = without(fact.ids, idI)
   348  		}
   349  	}
   350  	if idFactor == -1 {
   351  		panic("id " + id.name + " not found in environment")
   352  	}
   353  
   354  	// If id is the only term in its factor, then dropping it is equivalent to
   355  	// making the factor be the unit value, so we can just drop the factor. (And
   356  	// if this is the only factor, we'll arrive at [topEnv], which is exactly
   357  	// what we want!). In this case we can use the same nonDetEnv in all of the
   358  	// partitions.
   359  	isUnit := len(newIDs) == 0
   360  	var unitFactors []*envSum
   361  	if isUnit {
   362  		unitFactors = without(e.factors, idFactor)
   363  	}
   364  
   365  	// Create a partition for each distinct value of id.
   366  	var parts []envPartition
   367  	partIndex := map[*Value]int{}
   368  	for _, det := range e.factors[idFactor].terms {
   369  		val := det.vals[idIndex]
   370  		i, ok := partIndex[val]
   371  		if !ok {
   372  			i = len(parts)
   373  			var factors []*envSum
   374  			if isUnit {
   375  				factors = unitFactors
   376  			} else {
   377  				// Copy all other factor
   378  				factors = slices.Clone(e.factors)
   379  				factors[idFactor] = &envSum{ids: newIDs}
   380  			}
   381  			parts = append(parts, envPartition{id: id, value: val, env: nonDetEnv{factors: factors}})
   382  			partIndex[val] = i
   383  		}
   384  
   385  		if !isUnit {
   386  			factor := parts[i].env.factors[idFactor]
   387  			newVals := without(det.vals, idIndex)
   388  			factor.terms = append(factor.terms, detEnv{vals: newVals})
   389  		}
   390  	}
   391  	return parts
   392  }
   393  
   394  type ident struct {
   395  	_    [0]func() // Not comparable (only compare *ident)
   396  	name string
   397  }
   398  
   399  type Var struct {
   400  	id *ident
   401  }
   402  
   403  func (d Var) Exact() bool {
   404  	// These can't appear in concrete Values.
   405  	panic("Exact called on non-concrete Value")
   406  }
   407  
   408  func (d Var) decode(rv reflect.Value) error {
   409  	return &inexactError{"var", rv.Type().String()}
   410  }
   411  
   412  func (d Var) unify(w *Value, e nonDetEnv, swap bool, uf *unifier) (Domain, nonDetEnv, error) {
   413  	// TODO: Vars from !sums in the input can have a huge number of values.
   414  	// Unifying these could be way more efficient with some indexes over any
   415  	// exact values we can pull out, like Def fields that are exact Strings.
   416  	// Maybe we try to produce an array of yes/no/maybe matches and then we only
   417  	// have to do deeper evaluation of the maybes. We could probably cache this
   418  	// on an envTerm. It may also help to special-case Var/Var unification to
   419  	// pick which one to index versus enumerate.
   420  
   421  	if vd, ok := w.Domain.(Var); ok && d.id == vd.id {
   422  		// Unifying $x with $x results in $x. If we descend into this we'll have
   423  		// problems because we strip $x out of the environment to keep ourselves
   424  		// honest and then can't find it on the other side.
   425  		//
   426  		// TODO: I'm not positive this is the right fix.
   427  		return vd, e, nil
   428  	}
   429  
   430  	// We need to unify w with the value of d in each possible environment. We
   431  	// can save some work by grouping environments by the value of d, since
   432  	// there will be a lot of redundancy here.
   433  	var nEnvs []nonDetEnv
   434  	envParts := e.partitionBy(d.id)
   435  	for i, envPart := range envParts {
   436  		exit := uf.enterVar(d.id, i)
   437  		// Each branch logically gets its own copy of the initial environment
   438  		// (narrowed down to just this binding of the variable), and each branch
   439  		// may result in different changes to that starting environment.
   440  		res, e2, err := w.unify(envPart.value, envPart.env, swap, uf)
   441  		exit.exit()
   442  		if err != nil {
   443  			return nil, nonDetEnv{}, err
   444  		}
   445  		if res.Domain == nil {
   446  			// This branch entirely failed to unify, so it's gone.
   447  			continue
   448  		}
   449  		nEnv := e2.bind(d.id, res)
   450  		nEnvs = append(nEnvs, nEnv)
   451  	}
   452  
   453  	if len(nEnvs) == 0 {
   454  		// All branches failed
   455  		return nil, bottomEnv, nil
   456  	}
   457  
   458  	// The effect of this is entirely captured in the environment. We can return
   459  	// back the same Bind node.
   460  	return d, sumEnvs(nEnvs...), nil
   461  }
   462  
   463  // An identPrinter maps [ident]s to unique string names.
   464  type identPrinter struct {
   465  	ids   map[*ident]string
   466  	idGen map[string]int
   467  }
   468  
   469  func (p *identPrinter) unique(id *ident) string {
   470  	if p.ids == nil {
   471  		p.ids = make(map[*ident]string)
   472  		p.idGen = make(map[string]int)
   473  	}
   474  
   475  	name, ok := p.ids[id]
   476  	if !ok {
   477  		gen := p.idGen[id.name]
   478  		p.idGen[id.name]++
   479  		if gen == 0 {
   480  			name = id.name
   481  		} else {
   482  			name = fmt.Sprintf("%s#%d", id.name, gen)
   483  		}
   484  		p.ids[id] = name
   485  	}
   486  
   487  	return name
   488  }
   489  
   490  func (p *identPrinter) slice(ids []*ident) string {
   491  	var strs []string
   492  	for _, id := range ids {
   493  		strs = append(strs, p.unique(id))
   494  	}
   495  	return fmt.Sprintf("[%s]", strings.Join(strs, ", "))
   496  }
   497  
   498  func without[Elt any](s []Elt, i int) []Elt {
   499  	return append(s[:i:i], s[i+1:]...)
   500  }