cuelang.org/go@v0.13.0/internal/core/adt/disjunct.go (about)

     1  // Copyright 2020 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package adt
    16  
    17  import (
    18  	"slices"
    19  
    20  	"cuelang.org/go/cue/errors"
    21  	"cuelang.org/go/cue/token"
    22  )
    23  
// Nodes may not reenter a disjunction.
    25  //
    26  // Copy one layer deep; throw away items on failure.
    27  
    28  // DISJUNCTION ALGORITHM
    29  //
    30  // The basic concept of the algorithm is to use backtracking to find valid
    31  // disjunctions. The algorithm can stop if two matching disjuncts are found
    32  // where one does not subsume the other.
    33  //
    34  // At a later point, we can introduce a filter step to filter out possible
    35  // disjuncts based on, say, discriminator fields or field exclusivity (oneOf
    36  // fields in Protobuf).
    37  //
    38  // To understand the details of the algorithm, it is important to understand
    39  // some properties of disjunction.
    40  //
    41  //
    42  // EVALUATION OF A DISJUNCTION IS SELF CONTAINED
    43  //
    44  // In other words, fields outside of a disjunction cannot bind to values within
    45  // a disjunction whilst evaluating that disjunction. This allows the computation
    46  // of disjunctions to be isolated from side effects.
    47  //
    48  // The intuition behind this is as follows: as a disjunction is not a concrete
    49  // value, it is not possible to lookup a field within a disjunction if it has
    50  // not yet been evaluated. So if a reference within a disjunction that is needed
    51  // to disambiguate that disjunction refers to a field outside the scope of the
    52  // disjunction which, in turn, refers to a field within the disjunction, this
    53  // results in a cycle error. We achieve this by not removing the cycle marker of
    54  // the Vertex of the disjunction until the disjunction is resolved.
    55  //
    56  // Note that the following disjunct is still allowed:
    57  //
    58  //    a: 1
    59  //    b: a
    60  //
    61  // Even though `a` refers to the root of the disjunction, it does not _select
    62  // into_ the disjunction. Implementation-wise, it also doesn't have to, as the
    63  // respective vertex is available within the Environment. Referencing a node
    64  // outside the disjunction that in turn selects the disjunction root, however,
    65  // will result in a detected cycle.
    66  //
// As usual, a detected cycle should be marked as incomplete, so that
// the referring node will not be fixed to an error prematurely.
    69  //
    70  //
    71  // SUBSUMPTION OF AMBIGUOUS DISJUNCTS
    72  //
    73  // A disjunction can be evaluated to a concrete value if only one disjunct
    74  // remains. Aside from disambiguating through unification failure, disjuncts
    75  // may also be disambiguated by taking the least specific of two disjuncts.
    76  // For instance, if a subsumes b, then the result of disjunction may be a.
    77  //
    78  //   NEW ALGORITHM NO LONGER VERIFIES SUBSUMPTION. SUBSUMPTION IS INHERENTLY
    79  //   IMPRECISE (DUE TO BULK OPTIONAL FIELDS). OTHER THAN THAT, FOR SCALAR VALUES
    80  //   IT JUST MEANS THERE IS AMBIGUITY, AND FOR STRUCTS IT CAN LEAD TO STRANGE
    81  //   CONSEQUENCES.
    82  //
    83  //   USE EQUALITY INSTEAD:
    84  //     - Undefined == error for optional fields.
    85  //     - So only need to check exact labels for vertices.
    86  
// envDisjunct is a disjunction recorded on a nodeContext, together with the
// environment and close info that are used when its disjuncts are later added
// as conjuncts. Entries are appended to nodeContext.disjunctions by
// addDisjunction and addDisjunctionValue and are consumed by expandDisjuncts.
type envDisjunct struct {
	// env is the environment passed along when a disjunct is turned into a
	// conjunct.
	env *Environment
	// cloneID is the CloseInfo passed along with each disjunct.
	cloneID CloseInfo
	// NOTE(review): holeID is not referenced in this file; presumably used
	// by the new evaluator — confirm.
	holeID int

	// fields for new evaluator

	src       Node
	disjuncts []disjunct

	// fields for old evaluator

	// expr is set by addDisjunction; value is set by addDisjunctionValue.
	// expandDisjuncts dispatches on whichever of the two is non-nil.
	expr  *DisjunctionExpr
	value *Disjunction
	// hasDefaults reports whether at least one disjunct carries a default
	// marker.
	hasDefaults bool

	// These are used for book keeping, tracking whether any of the
	// disjuncts marked with a default marker remains after unification.
	// If no default is used, all other elements are treated as "maybeDefault".
	// Otherwise, elements are treated as is.
	parentDefaultUsed bool
	childDefaultUsed  bool
}
   110  
   111  func (n *nodeContext) addDisjunction(env *Environment, x *DisjunctionExpr, cloneID CloseInfo) {
   112  
   113  	// TODO: precompute
   114  	numDefaults := 0
   115  	for _, v := range x.Values {
   116  		isDef := v.Default // || n.hasDefaults(env, v.Val)
   117  		if isDef {
   118  			numDefaults++
   119  		}
   120  	}
   121  
   122  	n.disjunctions = append(n.disjunctions, envDisjunct{
   123  		env:         env,
   124  		cloneID:     cloneID,
   125  		expr:        x,
   126  		hasDefaults: numDefaults > 0,
   127  	})
   128  }
   129  
   130  func (n *nodeContext) addDisjunctionValue(env *Environment, x *Disjunction, cloneID CloseInfo) {
   131  	n.disjunctions = append(n.disjunctions, envDisjunct{
   132  		env:         env,
   133  		cloneID:     cloneID,
   134  		value:       x,
   135  		hasDefaults: x.HasDefaults,
   136  	})
   137  
   138  }
   139  
// expandDisjuncts expands the disjunctions queued in n.disjunctions and
// hands the surviving disjuncts, and the errors of the failed ones, to parent.
//
// If n has no queued disjunctions, n itself is the single candidate: it is
// completed with postDisjunct and either reported as an error to parent or
// kept as a result. Otherwise, for each queued disjunction in turn, every
// disjunct accumulated so far is cloned once per alternative, the alternative
// is added as a conjunct, and the clone is expanded recursively with n as its
// parent. After that the default modes of the results are combined.
//
// Parameters:
//   - state is passed to postDisjunct, getValidators and addExprConjunct. It
//     is overwritten with finalized when n has queued disjunctions.
//   - parent receives the errors and resulting disjuncts. When parent == n,
//     the results are left in n.disjuncts.
//   - parentMode is the default mode of this disjunct.
//   - recursive is true for the calls this function makes on cloned
//     disjuncts.
//   - last indicates that the caller is processing its final disjunction; in
//     that case results are compared with IgnoreOptional when deduplicating.
func (n *nodeContext) expandDisjuncts(
	state vertexStatus,
	parent *nodeContext,
	parentMode defaultMode, // default mode of this disjunct
	recursive, last bool) {

	// NOTE(review): presumably guards against this old-evaluator code path
	// being reached under the new evaluator — confirm against the helper.
	unreachableForDev(n.ctx)

	n.ctx.stats.Disjuncts++

	// refNode is used to collect cyclicReferences for all disjuncts to be
	// passed up to the parent node. Note that because the node in the parent
	// context is overwritten in the course of expanding disjunction to retain
	// pointer identity, it is not possible to simply record the refNodes in the
	// parent directly.
	var refNode *RefNode

	// n.node may be repointed at n.result below; restore the original
	// pointer when this function returns.
	node := n.node
	defer func() {
		n.node = node
	}()

	for n.expandOne(partial) {
	}

	// save node to snapShot in nodeContext
	// save nodeContext.

	// A clone is taken when the snapshot may be used as the starting point
	// for several disjuncts; otherwise a plain copy is used.
	if recursive || len(n.disjunctions) > 0 {
		n.snapshot = clone(*n.node)
	} else {
		n.snapshot = *n.node
	}

	// Entries of usedDefault at or beyond this offset are added while
	// expanding this node.
	defaultOffset := len(n.usedDefault)

	switch {
	default: // len(n.disjunctions) == 0
		// Keep a copy of the nodeContext as it is before postDisjunct, so
		// that it can be restored for the recursive case below.
		m := *n
		n.postDisjunct(state)

		switch {
		case n.hasErr():
			// TODO: consider finalizing the node thusly:
			// if recursive {
			// 	n.node.Finalize(n.ctx)
			// }
			x := n.node
			err, ok := x.BaseValue.(*Bottom)
			if !ok {
				err = n.getErr()
			}
			if err == nil {
				// TODO(disjuncts): Is this always correct? Especially for partial
				// evaluation it is okay for child errors to have incomplete errors.
				// Perhaps introduce an Err() method.
				err = x.ChildErrors
			}
			if err != nil {
				parent.disjunctErrs = append(parent.disjunctErrs, err)
			}
			if recursive {
				n.free()
			}
			return
		}

		if recursive {
			*n = m
			n.result = *n.node // XXX: n.result = snapshotVertex(n.node)?
			n.node = &n.result
			n.disjuncts = append(n.disjuncts, n)
		}
		if n.node.BaseValue == nil {
			n.setBaseValue(n.getValidators(state))
		}

		// Without nested disjunctions, all modes equal the mode handed
		// down by the caller.
		n.usedDefault = append(n.usedDefault, defaultInfo{
			parentMode: parentMode,
			nestedMode: parentMode,
			origMode:   parentMode,
		})

	case len(n.disjunctions) > 0:
		// Process full disjuncts to ensure that erroneous disjuncts are
		// eliminated as early as possible.
		state = finalized

		// Seed the list of candidates with n itself.
		n.disjuncts = append(n.disjuncts, n)

		n.refCount++
		defer n.free()

		for i, d := range n.disjunctions {
			// Swap the result list and the scratch buffer: the disjuncts
			// accumulated so far (a) are the input for this disjunction and
			// the recursive calls below append their results to the now
			// empty n.disjuncts.
			a := n.disjuncts
			n.disjuncts = n.buffer[:0]
			n.buffer = a[:0]

			// NOTE: this shadows the parameter last for the remainder of
			// the loop body. The deduplication code after the outer switch
			// uses the parameter.
			last := i+1 == len(n.disjunctions)
			skipNonMonotonicChecks := i+1 < len(n.disjunctions)
			if skipNonMonotonicChecks {
				n.ctx.inDisjunct++
			}

			for _, dn := range a {
				switch {
				case d.expr != nil:
					for _, v := range d.expr.Values {
						// Start each alternative from a fresh clone of the
						// snapshot of the disjunct being extended.
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						c := MakeConjunct(d.env, v.Val, d.cloneID)
						cn.addExprConjunct(c, state)

						newMode := mode(d.hasDefaults, v.Default)

						cn.expandDisjuncts(state, n, newMode, true, last)

						// Record the cyclicReferences of the conjunct in the
						// parent list.
						// TODO: avoid the copy. It should be okay to "steal"
						// this list and avoid the copy. But this change is best
						// done in a separate CL.
						for r := n.node.cyclicReferences; r != nil; r = r.Next {
							s := *r
							s.Next = refNode
							refNode = &s
						}
					}

				case d.value != nil:
					// NOTE: i shadows the index of the outer loop here.
					for i, v := range d.value.Values {
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						cn.addValueConjunct(d.env, v, d.cloneID)

						// The first NumDefaults values are treated as the
						// ones marked as default.
						newMode := mode(d.hasDefaults, i < d.value.NumDefaults)

						cn.expandDisjuncts(state, n, newMode, true, last)

						// See comment above.
						for r := n.node.cyclicReferences; r != nil; r = r.Next {
							s := *r
							s.Next = refNode
							refNode = &s
						}
					}
				}
			}

			if skipNonMonotonicChecks {
				n.ctx.inDisjunct--
			}

			// No alternative survived: the node becomes an error.
			if len(n.disjuncts) == 0 {
				n.makeError()
			}

			// Release the input contexts. NOTE(review): for a non-recursive
			// call the first input list holds only n itself, which is
			// presumably why it is left to the deferred n.free() — confirm.
			if recursive || i > 0 {
				for _, x := range a {
					x.free()
				}
			}

			if len(n.disjuncts) == 0 {
				break
			}
		}

		// Annotate disjunctions with whether any of the default disjunctions
		// was used.
		for _, d := range n.disjuncts {
			for i, info := range d.usedDefault[defaultOffset:] {
				if info.parentMode == isDefault {
					n.disjunctions[i].parentDefaultUsed = true
				}
				if info.origMode == isDefault {
					n.disjunctions[i].childDefaultUsed = true
				}
			}
		}

		// Combine parent and child default markers, considering that a parent
		// "notDefault" is treated as "maybeDefault" if none of the disjuncts
		// marked as default remain.
		//
		// NOTE for a parent marked as "notDefault", a child is *never*
		// considered as default. It may either be "not" or "maybe" default.
		//
		// The result for each disjunction is conjoined into a single value.
		for _, d := range n.disjuncts {
			m := maybeDefault
			orig := maybeDefault
			for i, info := range d.usedDefault[defaultOffset:] {
				parent := info.parentMode

				used := n.disjunctions[i].parentDefaultUsed
				childUsed := n.disjunctions[i].childDefaultUsed
				hasDefaults := n.disjunctions[i].hasDefaults

				orig = combineDefault(orig, info.parentMode)
				orig = combineDefault(orig, info.nestedMode)

				switch {
				case childUsed:
					// One of the children used a default. This is "normal"
					// mode. This may also happen when we are in
					// hasDefaults/notUsed mode. Consider
					//
					//      ("a" | "b") & (*(*"a" | string) | string)
					//
					// Here the doubly nested default is called twice, once
					// for "a" and then for "b", where the second resolves to
					// not using a default. The first does, however, and on that
					// basis the "not default" marker cannot be overridden.
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.origMode)

				case !hasDefaults, used:
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.nestedMode)

				case hasDefaults && !used:
					// The mode is deliberately not folded into m in this
					// case, which leaves the contribution at maybeDefault.
					Assertf(n.ctx, parent == notDefault, "unexpected default mode")
				}
			}
			d.defaultMode = m

			// Replace the entries added while expanding this node by a
			// single entry that summarizes them.
			d.usedDefault = d.usedDefault[:defaultOffset]
			d.usedDefault = append(d.usedDefault, defaultInfo{
				parentMode: parentMode,
				nestedMode: m,
				origMode:   orig,
			})

		}

		// TODO: this is an old trick that seems no longer necessary for the new
		// implementation. Keep around until we finalize the semantics for
		// defaults, though. The recursion of nested defaults is not entirely
		// proper yet.
		//
		// A better approach, that avoids the need for recursion (semantically),
		// would be to only consider default usage for one level, but then to
		// also allow a default to be passed if only one value is remaining.
		// This means that a nested subsumption would first have to be evaluated
		// in isolation, however, to determine that it is not previous
		// disjunctions that cause the disambiguation.
		//
		// HACK alert: this replaces the hack of the previous algorithm with a
		// slightly less worse hack: instead of dropping the default info when
		// the value was scalar before, we drop this information when there is
		// only one disjunct, while not discarding hard defaults. TODO: a more
		// principled approach would be to recognize that there is only one
		// default at a point where this does not break commutativity.
		// if len(n.disjuncts) == 1 && n.disjuncts[0].defaultMode != isDefault {
		// 	n.disjuncts[0].defaultMode = maybeDefault
		// }
	}

	// Compare to root, but add to this one.
	//
	// When n has a distinct parent, move the errors and results of n to the
	// parent, dropping results that are equal to one the parent already has.
	switch p := parent; {
	case p != n:
		p.disjunctErrs = append(p.disjunctErrs, n.disjunctErrs...)
		n.disjunctErrs = n.disjunctErrs[:0]

	outer:
		for _, d := range n.disjuncts {
			for k, v := range p.disjuncts {
				// As long as a vertex isn't finalized, it may be that potential
				// errors are not yet detected. This may lead two structs that
				// are identical except for closedness information,
				// for instance, to appear identical.
				if v.result.status < finalized || d.result.status < finalized {
					break
				}
				// Even if a node is finalized, it may still have an
				// "incomplete" component that may change down the line.
				if !d.done() || !v.done() {
					break
				}
				flags := CheckStructural
				// last refers to the function parameter here.
				if last {
					flags |= IgnoreOptional
				}
				if Equal(n.ctx, &v.result, &d.result, flags) {
					// Of two equal results, keep the new one only if it is
					// a default; free the one that is dropped.
					m := maybeDefault
					for _, u := range d.usedDefault {
						m = combineDefault(m, u.nestedMode)
					}
					if m == isDefault {
						p.disjuncts[k] = d
						v.free()
					} else {
						d.free()
					}
					continue outer
				}
			}

			p.disjuncts = append(p.disjuncts, d)
		}

		n.disjuncts = n.disjuncts[:0]
	}

	// Record the refNodes in the parent.
	// Each collected node is prepended to the parent's list.
	for r := refNode; r != nil; {
		next := r.Next
		r.Next = parent.node.cyclicReferences
		parent.node.cyclicReferences = r
		r = next
	}
}
   457  
   458  func (n *nodeContext) makeError() {
   459  	code := IncompleteError
   460  
   461  	if len(n.disjunctErrs) > 0 {
   462  		code = EvalError
   463  		for _, c := range n.disjunctErrs {
   464  			if c.Code > code {
   465  				code = c.Code
   466  			}
   467  		}
   468  	}
   469  
   470  	b := &Bottom{
   471  		Code: code,
   472  		Err:  n.disjunctError(),
   473  		Node: n.node,
   474  	}
   475  	n.node.SetValue(n.ctx, b)
   476  }
   477  
   478  func mode(hasDefault, marked bool) defaultMode {
   479  	var mode defaultMode
   480  	switch {
   481  	case !hasDefault:
   482  		mode = maybeDefault
   483  	case marked:
   484  		mode = isDefault
   485  	default:
   486  		mode = notDefault
   487  	}
   488  	return mode
   489  }
   490  
   491  // clone makes a shallow copy of a Vertex. The purpose is to create different
   492  // disjuncts from the same Vertex under computation. This allows the conjuncts
   493  // of an arc to be reset to a previous position and the reuse of earlier
   494  // computations.
   495  //
   496  // Notes: only Arcs need to be copied recursively. Either the arc is finalized
   497  // and can be used as is, or Structs is assumed to not yet be computed at the
   498  // time that a clone is needed and must be nil. Conjuncts no longer needed and
   499  // can become nil. All other fields can be copied shallowly.
   500  func clone(v Vertex) Vertex {
   501  	v.state = nil
   502  	if a := v.Arcs; len(a) > 0 {
   503  		v.Arcs = make([]*Vertex, len(a))
   504  		for i, arc := range a {
   505  			switch arc.status {
   506  			case finalized:
   507  				v.Arcs[i] = arc
   508  
   509  			case unprocessed:
   510  				a := *arc
   511  				v.Arcs[i] = &a
   512  				a.Conjuncts = slices.Clone(arc.Conjuncts)
   513  
   514  			default:
   515  				a := *arc
   516  				a.state = arc.state.clone()
   517  				a.state.node = &a
   518  				a.state.snapshot = clone(a)
   519  				v.Arcs[i] = &a
   520  			}
   521  		}
   522  	}
   523  
   524  	if a := v.Structs; len(a) > 0 {
   525  		v.Structs = slices.Clone(a)
   526  	}
   527  
   528  	return v
   529  }
   530  
   531  // Default rules from spec:
   532  //
   533  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   534  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   535  //
   536  // D1: (v1, d1) | v2       => (v1|v2, d1)
   537  // D2: (v1, d1) | (v2, d2) => (v1|v2, d1|d2)
   538  //
   539  // M1: *v        => (v, v)
   540  // M2: *(v1, d1) => (v1, d1)
   541  //
   542  // NOTE: M2 cannot be *(v1, d1) => (v1, v1), as this has the weird property
   543  // of making a value less specific. This causes issues, for instance, when
   544  // trimming.
   545  //
   546  // The old implementation does something similar though. It will discard
   547  // default information after first determining if more than one conjunct
   548  // has survived.
   549  //
   550  // def + maybe -> def
   551  // not + maybe -> def
   552  // not + def   -> def
   553  
   554  type defaultMode int
   555  
   556  const (
   557  	maybeDefault defaultMode = iota
   558  	isDefault
   559  	notDefault
   560  )
   561  
   562  // combineDefaults combines default modes for unifying conjuncts.
   563  //
   564  // Default rules from spec:
   565  //
   566  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   567  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   568  func combineDefault(a, b defaultMode) defaultMode {
   569  	if a > b {
   570  		return a
   571  	}
   572  	return b
   573  }
   574  
   575  // disjunctError returns a compound error for a failed disjunction.
   576  //
   577  // TODO(perf): the set of errors is now computed during evaluation. Eventually,
   578  // this could be done lazily.
   579  func (n *nodeContext) disjunctError() (errs errors.Error) {
   580  	ctx := n.ctx
   581  
   582  	disjuncts := selectErrors(n.disjunctErrs)
   583  
   584  	if disjuncts == nil {
   585  		errs = ctx.Newf("empty disjunction") // XXX: add space to sort first
   586  	} else {
   587  		disjuncts = errors.Sanitize(disjuncts)
   588  		k := len(errors.Errors(disjuncts))
   589  		if k == 1 {
   590  			return disjuncts
   591  		}
   592  		// prefix '-' to sort to top
   593  		errs = ctx.Newf("%d errors in empty disjunction:", k)
   594  		errs = errors.Append(errs, disjuncts)
   595  	}
   596  
   597  	return errs
   598  }
   599  
   600  func selectErrors(a []*Bottom) (errs errors.Error) {
   601  	// return all errors if less than a certain number.
   602  	if len(a) <= 2 {
   603  		for _, b := range a {
   604  			errs = errors.Append(errs, b.Err)
   605  
   606  		}
   607  		return errs
   608  	}
   609  
   610  	// First select only relevant errors.
   611  	isIncomplete := false
   612  	k := 0
   613  	for _, b := range a {
   614  		if !isIncomplete && b.Code >= IncompleteError {
   615  			k = 0
   616  			isIncomplete = true
   617  		}
   618  		a[k] = b
   619  		k++
   620  	}
   621  	a = a[:k]
   622  
   623  	// filter errors
   624  	positions := map[token.Pos]bool{}
   625  
   626  	add := func(b *Bottom, p token.Pos) bool {
   627  		if positions[p] {
   628  			return false
   629  		}
   630  		positions[p] = true
   631  		errs = errors.Append(errs, b.Err)
   632  		return true
   633  	}
   634  
   635  	for _, b := range a {
   636  		// TODO: Should we also distinguish by message type?
   637  		if add(b, b.Err.Position()) {
   638  			continue
   639  		}
   640  		for _, p := range b.Err.InputPositions() {
   641  			if add(b, p) {
   642  				break
   643  			}
   644  		}
   645  	}
   646  
   647  	return errs
   648  }