cuelang.org/go@v0.10.1/internal/core/adt/disjunct.go (about)

     1  // Copyright 2020 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package adt
    16  
    17  import (
    18  	"slices"
    19  
    20  	"cuelang.org/go/cue/errors"
    21  	"cuelang.org/go/cue/token"
    22  )
    23  
    24  // Nodes may not reenter a disjunction.
    25  //
    26  // Copy one layer deep; throw away items on failure.
    27  
    28  // DISJUNCTION ALGORITHM
    29  //
    30  // The basic concept of the algorithm is to use backtracking to find valid
    31  // disjunctions. The algorithm can stop if two matching disjuncts are found
    32  // where one does not subsume the other.
    33  //
    34  // At a later point, we can introduce a filter step to filter out possible
    35  // disjuncts based on, say, discriminator fields or field exclusivity (oneOf
    36  // fields in Protobuf).
    37  //
    38  // To understand the details of the algorithm, it is important to understand
    39  // some properties of disjunction.
    40  //
    41  //
    42  // EVALUATION OF A DISJUNCTION IS SELF CONTAINED
    43  //
    44  // In other words, fields outside of a disjunction cannot bind to values within
    45  // a disjunction whilst evaluating that disjunction. This allows the computation
    46  // of disjunctions to be isolated from side effects.
    47  //
    48  // The intuition behind this is as follows: as a disjunction is not a concrete
    49  // value, it is not possible to look up a field within a disjunction if it has
    50  // not yet been evaluated. So if a reference within a disjunction that is needed
    51  // to disambiguate that disjunction refers to a field outside the scope of the
    52  // disjunction which, in turn, refers to a field within the disjunction, this
    53  // results in a cycle error. We achieve this by not removing the cycle marker of
    54  // the Vertex of the disjunction until the disjunction is resolved.
    55  //
    56  // Note that the following disjunct is still allowed:
    57  //
    58  //    a: 1
    59  //    b: a
    60  //
    61  // Even though `a` refers to the root of the disjunction, it does not _select
    62  // into_ the disjunction. Implementation-wise, it also doesn't have to, as the
    63  // respective vertex is available within the Environment. Referencing a node
    64  // outside the disjunction that in turn selects the disjunction root, however,
    65  // will result in a detected cycle.
    66  //
    67  // As usual, a detected cycle should be marked as incomplete, so that
    68  // the referring node will not be fixed to an error prematurely.
    69  //
    70  //
    71  // SUBSUMPTION OF AMBIGUOUS DISJUNCTS
    72  //
    73  // A disjunction can be evaluated to a concrete value if only one disjunct
    74  // remains. Aside from disambiguating through unification failure, disjuncts
    75  // may also be disambiguated by taking the least specific of two disjuncts.
    76  // For instance, if a subsumes b, then the result of disjunction may be a.
    77  //
    78  //   NEW ALGORITHM NO LONGER VERIFIES SUBSUMPTION. SUBSUMPTION IS INHERENTLY
    79  //   IMPRECISE (DUE TO BULK OPTIONAL FIELDS). OTHER THAN THAT, FOR SCALAR VALUES
    80  //   IT JUST MEANS THERE IS AMBIGUITY, AND FOR STRUCTS IT CAN LEAD TO STRANGE
    81  //   CONSEQUENCES.
    82  //
    83  //   USE EQUALITY INSTEAD:
    84  //     - Undefined == error for optional fields.
    85  //     - So only need to check exact labels for vertices.
    86  
// envDisjunct records a pending disjunction of a nodeContext together with
// the Environment and CloseInfo with which its disjuncts are added as
// conjuncts.
type envDisjunct struct {
	// env and cloneID are passed along whenever a disjunct of this
	// disjunction is added as a conjunct (MakeConjunct / addValueConjunct in
	// expandDisjuncts).
	env     *Environment
	cloneID CloseInfo

	// fields for new evaluator

	src       Node
	disjuncts []disjunct

	// fields for old evaluator
	//
	// addDisjunction sets expr; addDisjunctionValue sets value. hasDefaults
	// reports whether any disjunct of the disjunction carries a default
	// marker.

	expr        *DisjunctionExpr
	value       *Disjunction
	hasDefaults bool

	// These are used for bookkeeping, tracking whether any of the
	// disjuncts marked with a default marker remains after unification.
	// If no default is used, all other elements are treated as "maybeDefault".
	// Otherwise, elements are treated as is.
	parentDefaultUsed bool
	childDefaultUsed  bool
}
   109  
   110  func (n *nodeContext) addDisjunction(env *Environment, x *DisjunctionExpr, cloneID CloseInfo) {
   111  
   112  	// TODO: precompute
   113  	numDefaults := 0
   114  	for _, v := range x.Values {
   115  		isDef := v.Default // || n.hasDefaults(env, v.Val)
   116  		if isDef {
   117  			numDefaults++
   118  		}
   119  	}
   120  
   121  	n.disjunctions = append(n.disjunctions, envDisjunct{
   122  		env:         env,
   123  		cloneID:     cloneID,
   124  		expr:        x,
   125  		hasDefaults: numDefaults > 0,
   126  	})
   127  }
   128  
   129  func (n *nodeContext) addDisjunctionValue(env *Environment, x *Disjunction, cloneID CloseInfo) {
   130  	n.disjunctions = append(n.disjunctions, envDisjunct{
   131  		env:         env,
   132  		cloneID:     cloneID,
   133  		value:       x,
   134  		hasDefaults: x.HasDefaults,
   135  	})
   136  
   137  }
   138  
// expandDisjuncts evaluates the disjunct n and expands all of its pending
// disjunctions (n.disjunctions), handing the resulting disjuncts to parent.
//
// This works on the old-evaluator fields of envDisjunct (expr and value); the
// unreachableForDev call presumably asserts that the new evaluator never gets
// here.
//
// If n has no disjunctions, an evaluation error is added to
// parent.disjunctErrs and the call returns early. On success, a recursive call
// resets n to its state from before postDisjunct, moves its vertex into
// n.result and records n as a disjunct. Otherwise each disjunction in turn is
// unified with every disjunct that survived the previous disjunctions. This is
// done by cloning that disjunct from its snapshot and recursing with n as
// parent.
//
// When parent != n, n's disjuncts and disjunct errors are then moved to
// parent, dropping disjuncts that are equal to one already there. Finally,
// cyclic references collected from the nested disjuncts are prepended to
// parent.node.cyclicReferences.
//
// Parameters:
//   - state: passed on to postDisjunct, addExprConjunct, getValidators and the
//     recursive calls; it is raised to finalized when n has disjunctions to
//     expand.
//   - parent: the nodeContext that receives the results.
//   - parentMode: the default mode of this disjunct; it is recorded in
//     n.usedDefault.
//   - recursive: true for the nested calls made on cloned disjuncts below.
//   - last: whether this disjunct stems from the last disjunction expanded by
//     the caller; if so, optional fields are ignored when comparing disjuncts
//     for equality.
func (n *nodeContext) expandDisjuncts(
	state vertexStatus,
	parent *nodeContext,
	parentMode defaultMode, // default mode of this disjunct
	recursive, last bool) {

	unreachableForDev(n.ctx)

	n.ctx.stats.Disjuncts++

	// refNode is used to collect cyclicReferences for all disjuncts to be
	// passed up to the parent node. Note that because the node in the parent
	// context is overwritten in the course of expanding disjunction to retain
	// pointer identity, it is not possible to simply record the refNodes in the
	// parent directly.
	var refNode *RefNode

	// Restore n.node on return: the recursive success path below repoints it
	// to &n.result.
	node := n.node
	defer func() {
		n.node = node
	}()

	// Call expandOne until it reports false.
	for n.expandOne(partial) {
	}

	// save node to snapShot in nodeContext
	// save nodeContext.

	if recursive || len(n.disjunctions) > 0 {
		n.snapshot = clone(*n.node)
	} else {
		n.snapshot = *n.node
	}

	// Entries of n.usedDefault from defaultOffset onwards are the ones added
	// while expanding this node.
	defaultOffset := len(n.usedDefault)

	switch {
	default: // len(n.disjunctions) == 0
		// Copy of n from before postDisjunct; a successful recursive disjunct
		// is reset to it below.
		m := *n
		n.postDisjunct(state)

		switch {
		case n.hasErr():
			// TODO: consider finalizing the node thusly:
			// if recursive {
			// 	n.node.Finalize(n.ctx)
			// }
			x := n.node
			err, ok := x.BaseValue.(*Bottom)
			if !ok {
				err = n.getErr()
			}
			if err == nil {
				// TODO(disjuncts): Is this always correct? Especially for partial
				// evaluation it is okay for child errors to have incomplete errors.
				// Perhaps introduce an Err() method.
				err = x.ChildErrors
			}
			if err != nil {
				parent.disjunctErrs = append(parent.disjunctErrs, err)
			}
			if recursive {
				n.free()
			}
			return
		}

		if recursive {
			*n = m
			n.result = *n.node // XXX: n.result = snapshotVertex(n.node)?
			n.node = &n.result
			n.disjuncts = append(n.disjuncts, n)
		}
		if n.node.BaseValue == nil {
			n.node.BaseValue = n.getValidators(state)
		}

		// Record the default mode of this (leaf) disjunct.
		n.usedDefault = append(n.usedDefault, defaultInfo{
			parentMode: parentMode,
			nestedMode: parentMode,
			origMode:   parentMode,
		})

	case len(n.disjunctions) > 0:
		// Process full disjuncts to ensure that erroneous disjuncts are
		// eliminated as early as possible.
		state = finalized

		n.disjuncts = append(n.disjuncts, n)

		n.refCount++
		defer n.free()

		// Each disjunction is expanded against every disjunct that survived
		// the disjunctions before it.
		for i, d := range n.disjunctions {
			// Swap n.disjuncts and n.buffer: a holds the disjuncts of the
			// previous round, while n.disjuncts is refilled by the recursive
			// expandDisjuncts calls below, which use n as their parent.
			a := n.disjuncts
			n.disjuncts = n.buffer[:0]
			n.buffer = a[:0]

			// NOTE: this shadows the last parameter for the rest of the loop
			// body; nested disjuncts are told whether they stem from the last
			// disjunction of n.
			last := i+1 == len(n.disjunctions)
			skipNonMonotonicChecks := i+1 < len(n.disjunctions)
			if skipNonMonotonicChecks {
				n.ctx.inDisjunct++
			}

			for _, dn := range a {
				switch {
				case d.expr != nil:
					for _, v := range d.expr.Values {
						// Derive a fresh disjunct from dn's snapshot.
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						c := MakeConjunct(d.env, v.Val, d.cloneID)
						cn.addExprConjunct(c, state)

						newMode := mode(d.hasDefaults, v.Default)

						cn.expandDisjuncts(state, n, newMode, true, last)

						// Record the cyclicReferences of the conjunct in the
						// parent list.
						// TODO: avoid the copy. It should be okay to "steal"
						// this list and avoid the copy. But this change is best
						// done in a separate CL.
						for r := n.node.cyclicReferences; r != nil; r = r.Next {
							s := *r
							s.Next = refNode
							refNode = &s
						}
					}

				case d.value != nil:
					for i, v := range d.value.Values {
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						cn.addValueConjunct(d.env, v, d.cloneID)

						// Values at an index below NumDefaults are treated as
						// marked defaults.
						newMode := mode(d.hasDefaults, i < d.value.NumDefaults)

						cn.expandDisjuncts(state, n, newMode, true, last)

						// See comment above.
						for r := n.node.cyclicReferences; r != nil; r = r.Next {
							s := *r
							s.Next = refNode
							refNode = &s
						}
					}
				}
			}

			if skipNonMonotonicChecks {
				n.ctx.inDisjunct--
			}

			// No disjunct survived this disjunction: record the failure on n.
			if len(n.disjuncts) == 0 {
				n.makeError()
			}

			// Release the disjuncts of the previous round. In the first round
			// of a non-recursive call they are kept (a then includes n itself,
			// appended above).
			if recursive || i > 0 {
				for _, x := range a {
					x.free()
				}
			}

			if len(n.disjuncts) == 0 {
				break
			}
		}

		// Annotate disjunctions with whether any of the default disjunctions
		// was used.
		for _, d := range n.disjuncts {
			for i, info := range d.usedDefault[defaultOffset:] {
				if info.parentMode == isDefault {
					n.disjunctions[i].parentDefaultUsed = true
				}
				if info.origMode == isDefault {
					n.disjunctions[i].childDefaultUsed = true
				}
			}
		}

		// Combine parent and child default markers, considering that a parent
		// "notDefault" is treated as "maybeDefault" if none of the disjuncts
		// marked as default remain.
		//
		// NOTE for a parent marked as "notDefault", a child is *never*
		// considered as default. It may either be "not" or "maybe" default.
		//
		// The result for each disjunction is conjoined into a single value.
		for _, d := range n.disjuncts {
			m := maybeDefault
			orig := maybeDefault
			for i, info := range d.usedDefault[defaultOffset:] {
				parent := info.parentMode

				used := n.disjunctions[i].parentDefaultUsed
				childUsed := n.disjunctions[i].childDefaultUsed
				hasDefaults := n.disjunctions[i].hasDefaults

				orig = combineDefault(orig, info.parentMode)
				orig = combineDefault(orig, info.nestedMode)

				switch {
				case childUsed:
					// One of the children used a default. This is "normal"
					// mode. This may also happen when we are in
					// hasDefaults/notUsed mode. Consider
					//
					//      ("a" | "b") & (*(*"a" | string) | string)
					//
					// Here the doubly nested default is called twice, once
					// for "a" and then for "b", where the second resolves to
					// not using a default. The first does, however, and on that
					// basis the default marker cannot be overridden.
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.origMode)

				case !hasDefaults, used:
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.nestedMode)

				case hasDefaults && !used:
					Assertf(n.ctx, parent == notDefault, "unexpected default mode")
				}
			}
			d.defaultMode = m

			// Collapse the entries added during this expansion into a single
			// entry for this level.
			d.usedDefault = d.usedDefault[:defaultOffset]
			d.usedDefault = append(d.usedDefault, defaultInfo{
				parentMode: parentMode,
				nestedMode: m,
				origMode:   orig,
			})

		}

		// TODO: this is an old trick that seems no longer necessary for the new
		// implementation. Keep around until we finalize the semantics for
		// defaults, though. The recursion of nested defaults is not entirely
		// proper yet.
		//
		// A better approach, that avoids the need for recursion (semantically),
		// would be to only consider default usage for one level, but then to
		// also allow a default to be passed if only one value is remaining.
		// This means that a nested subsumption would first have to be evaluated
		// in isolation, however, to determine that it is not previous
		// disjunctions that cause the disambiguation.
		//
		// HACK alert: this replaces the hack of the previous algorithm with a
		// slightly less worse hack: instead of dropping the default info when
		// the value was scalar before, we drop this information when there is
		// only one disjunct, while not discarding hard defaults. TODO: a more
		// principled approach would be to recognize that there is only one
		// default at a point where this does not break commutativity.
		// if len(n.disjuncts) == 1 && n.disjuncts[0].defaultMode != isDefault {
		// 	n.disjuncts[0].defaultMode = maybeDefault
		// }
	}

	// Compare to root, but add to this one.
	//
	// Each disjunct d of n is moved to parent. If d equals a disjunct v that
	// is already there, d replaces v only when d's combined nested default
	// mode is isDefault; otherwise d is dropped. Comparison against the
	// remaining parent disjuncts stops (and d is simply appended) as soon as
	// either side is not yet finalized or done.
	switch p := parent; {
	case p != n:
		p.disjunctErrs = append(p.disjunctErrs, n.disjunctErrs...)
		n.disjunctErrs = n.disjunctErrs[:0]

	outer:
		for _, d := range n.disjuncts {
			for k, v := range p.disjuncts {
				// As long as a vertex isn't finalized, it may be that potential
				// errors are not yet detected. This may lead two structs that
				// are identical except for closedness information,
				// for instance, to appear identical.
				if v.result.status < finalized || d.result.status < finalized {
					break
				}
				// Even if a node is finalized, it may still have an
				// "incomplete" component that may change down the line.
				if !d.done() || !v.done() {
					break
				}
				flags := CheckStructural
				if last {
					flags |= IgnoreOptional
				}
				if Equal(n.ctx, &v.result, &d.result, flags) {
					m := maybeDefault
					for _, u := range d.usedDefault {
						m = combineDefault(m, u.nestedMode)
					}
					if m == isDefault {
						p.disjuncts[k] = d
						v.free()
					} else {
						d.free()
					}
					continue outer
				}
			}

			p.disjuncts = append(p.disjuncts, d)
		}

		n.disjuncts = n.disjuncts[:0]
	}

	// Record the refNodes in the parent.
	for r := refNode; r != nil; {
		next := r.Next
		r.Next = parent.node.cyclicReferences
		parent.node.cyclicReferences = r
		r = next
	}
}
   456  
   457  func (n *nodeContext) makeError() {
   458  	code := IncompleteError
   459  
   460  	if len(n.disjunctErrs) > 0 {
   461  		code = EvalError
   462  		for _, c := range n.disjunctErrs {
   463  			if c.Code > code {
   464  				code = c.Code
   465  			}
   466  		}
   467  	}
   468  
   469  	b := &Bottom{
   470  		Code: code,
   471  		Err:  n.disjunctError(),
   472  	}
   473  	n.node.SetValue(n.ctx, b)
   474  }
   475  
   476  func mode(hasDefault, marked bool) defaultMode {
   477  	var mode defaultMode
   478  	switch {
   479  	case !hasDefault:
   480  		mode = maybeDefault
   481  	case marked:
   482  		mode = isDefault
   483  	default:
   484  		mode = notDefault
   485  	}
   486  	return mode
   487  }
   488  
   489  // clone makes a shallow copy of a Vertex. The purpose is to create different
   490  // disjuncts from the same Vertex under computation. This allows the conjuncts
   491  // of an arc to be reset to a previous position and the reuse of earlier
   492  // computations.
   493  //
   494  // Notes: only Arcs need to be copied recursively. Either the arc is finalized
   495  // and can be used as is, or Structs is assumed to not yet be computed at the
   496  // time that a clone is needed and must be nil. Conjuncts no longer needed and
   497  // can become nil. All other fields can be copied shallowly.
   498  func clone(v Vertex) Vertex {
   499  	v.state = nil
   500  	if a := v.Arcs; len(a) > 0 {
   501  		v.Arcs = make([]*Vertex, len(a))
   502  		for i, arc := range a {
   503  			switch arc.status {
   504  			case finalized:
   505  				v.Arcs[i] = arc
   506  
   507  			case unprocessed:
   508  				a := *arc
   509  				v.Arcs[i] = &a
   510  				a.Conjuncts = slices.Clone(arc.Conjuncts)
   511  
   512  			default:
   513  				a := *arc
   514  				a.state = arc.state.clone()
   515  				a.state.node = &a
   516  				a.state.snapshot = clone(a)
   517  				v.Arcs[i] = &a
   518  			}
   519  		}
   520  	}
   521  
   522  	if a := v.Structs; len(a) > 0 {
   523  		v.Structs = slices.Clone(a)
   524  	}
   525  
   526  	return v
   527  }
   528  
   529  // Default rules from spec:
   530  //
   531  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   532  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   533  //
   534  // D1: (v1, d1) | v2       => (v1|v2, d1)
   535  // D2: (v1, d1) | (v2, d2) => (v1|v2, d1|d2)
   536  //
   537  // M1: *v        => (v, v)
   538  // M2: *(v1, d1) => (v1, d1)
   539  //
   540  // NOTE: M2 cannot be *(v1, d1) => (v1, v1), as this has the weird property
   541  // of making a value less specific. This causes issues, for instance, when
   542  // trimming.
   543  //
   544  // The old implementation does something similar though. It will discard
   545  // default information after first determining if more than one conjunct
   546  // has survived.
   547  //
   548  // def + maybe -> def
   549  // not + maybe -> def
   550  // not + def   -> def
   551  
   552  type defaultMode int
   553  
   554  const (
   555  	maybeDefault defaultMode = iota
   556  	isDefault
   557  	notDefault
   558  )
   559  
   560  // combineDefaults combines default modes for unifying conjuncts.
   561  //
   562  // Default rules from spec:
   563  //
   564  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   565  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   566  func combineDefault(a, b defaultMode) defaultMode {
   567  	if a > b {
   568  		return a
   569  	}
   570  	return b
   571  }
   572  
   573  // disjunctError returns a compound error for a failed disjunction.
   574  //
   575  // TODO(perf): the set of errors is now computed during evaluation. Eventually,
   576  // this could be done lazily.
   577  func (n *nodeContext) disjunctError() (errs errors.Error) {
   578  	ctx := n.ctx
   579  
   580  	disjuncts := selectErrors(n.disjunctErrs)
   581  
   582  	if disjuncts == nil {
   583  		errs = ctx.Newf("empty disjunction") // XXX: add space to sort first
   584  	} else {
   585  		disjuncts = errors.Sanitize(disjuncts)
   586  		k := len(errors.Errors(disjuncts))
   587  		if k == 1 {
   588  			return disjuncts
   589  		}
   590  		// prefix '-' to sort to top
   591  		errs = ctx.Newf("%d errors in empty disjunction:", k)
   592  		errs = errors.Append(errs, disjuncts)
   593  	}
   594  
   595  	return errs
   596  }
   597  
   598  func selectErrors(a []*Bottom) (errs errors.Error) {
   599  	// return all errors if less than a certain number.
   600  	if len(a) <= 2 {
   601  		for _, b := range a {
   602  			errs = errors.Append(errs, b.Err)
   603  
   604  		}
   605  		return errs
   606  	}
   607  
   608  	// First select only relevant errors.
   609  	isIncomplete := false
   610  	k := 0
   611  	for _, b := range a {
   612  		if !isIncomplete && b.Code >= IncompleteError {
   613  			k = 0
   614  			isIncomplete = true
   615  		}
   616  		a[k] = b
   617  		k++
   618  	}
   619  	a = a[:k]
   620  
   621  	// filter errors
   622  	positions := map[token.Pos]bool{}
   623  
   624  	add := func(b *Bottom, p token.Pos) bool {
   625  		if positions[p] {
   626  			return false
   627  		}
   628  		positions[p] = true
   629  		errs = errors.Append(errs, b.Err)
   630  		return true
   631  	}
   632  
   633  	for _, b := range a {
   634  		// TODO: Should we also distinguish by message type?
   635  		if add(b, b.Err.Position()) {
   636  			continue
   637  		}
   638  		for _, p := range b.Err.InputPositions() {
   639  			if add(b, p) {
   640  				break
   641  			}
   642  		}
   643  	}
   644  
   645  	return errs
   646  }