golang.org/x/arch@v0.17.0/internal/unify/closure.go (about)

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package unify
     6  
     7  import (
     8  	"fmt"
     9  	"iter"
    10  	"maps"
    11  	"slices"
    12  )
    13  
// Closure pairs a Value with the non-deterministic environment that supplies
// the possible bindings of the variables appearing in that Value.
type Closure struct {
	// val is the root Value; IsBottom, Summands and All all start from it.
	val *Value
	// env is the environment consulted when variables reachable from val are
	// partitioned into their possible bindings.
	env nonDetEnv
}
    18  
    19  func NewSum(vs ...*Value) Closure {
    20  	id := &ident{name: "sum"}
    21  	return Closure{NewValue(Var{id}), topEnv.bind(id, vs...)}
    22  }
    23  
    24  // IsBottom returns whether c consists of no values.
    25  func (c Closure) IsBottom() bool {
    26  	return c.val.Domain == nil
    27  }
    28  
    29  // Summands returns the top-level Values of c. This assumes the top-level of c
    30  // was constructed as a sum, and is mostly useful for debugging.
    31  func (c Closure) Summands() iter.Seq[*Value] {
    32  	if v, ok := c.val.Domain.(Var); ok {
    33  		parts := c.env.partitionBy(v.id)
    34  		return func(yield func(*Value) bool) {
    35  			for _, part := range parts {
    36  				if !yield(part.value) {
    37  					return
    38  				}
    39  			}
    40  		}
    41  	}
    42  	return func(yield func(*Value) bool) {
    43  		yield(c.val)
    44  	}
    45  }
    46  
    47  // All enumerates all possible concrete values of c by substituting variables
    48  // from the environment.
    49  //
    50  // E.g., enumerating this Value
    51  //
    52  //	a: !sum [1, 2]
    53  //	b: !sum [3, 4]
    54  //
    55  // results in
    56  //
    57  //   - {a: 1, b: 3}
    58  //   - {a: 1, b: 4}
    59  //   - {a: 2, b: 3}
    60  //   - {a: 2, b: 4}
    61  func (c Closure) All() iter.Seq[*Value] {
    62  	// In order to enumerate all concrete values under all possible variable
    63  	// bindings, we use a "non-deterministic continuation passing style" to
    64  	// implement this. We use CPS to traverse the Value tree, threading the
    65  	// (possibly narrowing) environment through that CPS following an Euler
    66  	// tour. Where the environment permits multiple choices, we invoke the same
    67  	// continuation for each choice. Similar to a yield function, the
    68  	// continuation can return false to stop the non-deterministic walk.
    69  	return func(yield func(*Value) bool) {
    70  		c.val.all1(c.env, func(v *Value, e nonDetEnv) bool {
    71  			return yield(v)
    72  		})
    73  	}
    74  }
    75  
    76  func (v *Value) all1(e nonDetEnv, cont func(*Value, nonDetEnv) bool) bool {
    77  	switch d := v.Domain.(type) {
    78  	default:
    79  		panic(fmt.Sprintf("unknown domain type %T", d))
    80  
    81  	case nil:
    82  		return true
    83  
    84  	case Top, String:
    85  		return cont(v, e)
    86  
    87  	case Def:
    88  		fields := d.keys()
    89  		// We can reuse this parts slice because we're doing a DFS through the
    90  		// state space. (Otherwise, we'd have to do some messy threading of an
    91  		// immutable slice-like value through allElt.)
    92  		parts := make(map[string]*Value, len(fields))
    93  
    94  		// TODO: If there are no Vars or Sums under this Def, then nothing can
    95  		// change the Value or env, so we could just cont(v, e).
    96  		var allElt func(elt int, e nonDetEnv) bool
    97  		allElt = func(elt int, e nonDetEnv) bool {
    98  			if elt == len(fields) {
    99  				// Build a new Def from the concrete parts. Clone parts because
   100  				// we may reuse it on other non-deterministic branches.
   101  				nVal := newValueFrom(Def{maps.Clone(parts)}, v)
   102  				return cont(nVal, e)
   103  			}
   104  
   105  			return d.fields[fields[elt]].all1(e, func(v *Value, e nonDetEnv) bool {
   106  				parts[fields[elt]] = v
   107  				return allElt(elt+1, e)
   108  			})
   109  		}
   110  		return allElt(0, e)
   111  
   112  	case Tuple:
   113  		// Essentially the same as Def.
   114  		if d.repeat != nil {
   115  			// There's nothing we can do with this.
   116  			return cont(v, e)
   117  		}
   118  		parts := make([]*Value, len(d.vs))
   119  		var allElt func(elt int, e nonDetEnv) bool
   120  		allElt = func(elt int, e nonDetEnv) bool {
   121  			if elt == len(d.vs) {
   122  				// Build a new tuple from the concrete parts. Clone parts because
   123  				// we may reuse it on other non-deterministic branches.
   124  				nVal := newValueFrom(Tuple{vs: slices.Clone(parts)}, v)
   125  				return cont(nVal, e)
   126  			}
   127  
   128  			return d.vs[elt].all1(e, func(v *Value, e nonDetEnv) bool {
   129  				parts[elt] = v
   130  				return allElt(elt+1, e)
   131  			})
   132  		}
   133  		return allElt(0, e)
   134  
   135  	case Var:
   136  		// Go each way this variable can be bound.
   137  		for _, ePart := range e.partitionBy(d.id) {
   138  			// d.id is no longer bound in this environment partition. We'll may
   139  			// need it later in the Euler tour, so bind it back to this single
   140  			// value.
   141  			env := ePart.env.bind(d.id, ePart.value)
   142  			if !ePart.value.all1(env, cont) {
   143  				return false
   144  			}
   145  		}
   146  		return true
   147  	}
   148  }