github.com/solo-io/cue@v0.4.7/internal/core/adt/disjunct.go (about)

     1  // Copyright 2020 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package adt
    16  
    17  import (
    18  	"github.com/solo-io/cue/cue/errors"
    19  	"github.com/solo-io/cue/cue/token"
    20  )
    21  
    22  // Nodes may not reenter a disjunction.
    23  //
    24  // Copy one layer deep; throw away items on failure.
    25  
    26  // DISJUNCTION ALGORITHM
    27  //
    28  // The basic concept of the algorithm is to use backtracking to find valid
    29  // disjunctions. The algorithm can stop if two matching disjuncts are found
    30  // where one does not subsume the other.
    31  //
    32  // At a later point, we can introduce a filter step to filter out possible
    33  // disjuncts based on, say, discriminator fields or field exclusivity (oneOf
    34  // fields in Protobuf).
    35  //
    36  // To understand the details of the algorithm, it is important to understand
    37  // some properties of disjunction.
    38  //
    39  //
    40  // EVALUATION OF A DISJUNCTION IS SELF CONTAINED
    41  //
    42  // In other words, fields outside of a disjunction cannot bind to values within
    43  // a disjunction whilst evaluating that disjunction. This allows the computation
    44  // of disjunctions to be isolated from side effects.
    45  //
    46  // The intuition behind this is as follows: as a disjunction is not a concrete
    47  // value, it is not possible to lookup a field within a disjunction if it has
    48  // not yet been evaluated. So if a reference within a disjunction that is needed
    49  // to disambiguate that disjunction refers to a field outside the scope of the
    50  // disjunction which, in turn, refers to a field within the disjunction, this
    51  // results in a cycle error. We achieve this by not removing the cycle marker of
    52  // the Vertex of the disjunction until the disjunction is resolved.
    53  //
    54  // Note that the following disjunct is still allowed:
    55  //
    56  //    a: 1
    57  //    b: a
    58  //
    59  // Even though `a` refers to the root of the disjunction, it does not _select
    60  // into_ the disjunction. Implementation-wise, it also doesn't have to, as the
    61  // respective vertex is available within the Environment. Referencing a node
    62  // outside the disjunction that in turn selects the disjunction root, however,
    63  // will result in a detected cycle.
    64  //
    65  // As usual, a detected cycle should be marked as incomplete, so that
    66  // the referring node will not be fixed to an error prematurely.
    67  //
    68  //
    69  // SUBSUMPTION OF AMBIGUOUS DISJUNCTS
    70  //
    71  // A disjunction can be evaluated to a concrete value if only one disjunct
    72  // remains. Aside from disambiguating through unification failure, disjuncts
    73  // may also be disambiguated by taking the least specific of two disjuncts.
    74  // For instance, if a subsumes b, then the result of disjunction may be a.
    75  //
    76  //   NEW ALGORITHM NO LONGER VERIFIES SUBSUMPTION. SUBSUMPTION IS INHERENTLY
    77  //   IMPRECISE (DUE TO BULK OPTIONAL FIELDS). OTHER THAN THAT, FOR SCALAR VALUES
    78  //   IT JUST MEANS THERE IS AMBIGUITY, AND FOR STRUCTS IT CAN LEAD TO STRANGE
    79  //   CONSEQUENCES.
    80  //
    81  //   USE EQUALITY INSTEAD:
    82  //     - Undefined == error for optional fields.
    83  //     - So only need to check exact labels for vertices.
    84  
// envDisjunct records a disjunction that has been added to a nodeContext and
// still has to be expanded, together with the Environment and CloseInfo it was
// added with.
//
// The constructors in this file set exactly one of expr and value:
// addDisjunction sets expr and addDisjunctionValue sets value.
type envDisjunct struct {
	env         *Environment
	cloneID     CloseInfo
	expr        *DisjunctionExpr // set by addDisjunction, nil otherwise
	value       *Disjunction     // set by addDisjunctionValue, nil otherwise
	hasDefaults bool             // whether the disjunction has any default-marked disjunct

	// These are used for book keeping, tracking whether any of the
	// disjuncts marked with a default marker remains after unification.
	// If no default is used, all other elements are treated as "maybeDefault".
	// Otherwise, elements are treated as is.
	parentDefaultUsed bool
	childDefaultUsed  bool
}
    99  
   100  func (n *nodeContext) addDisjunction(env *Environment, x *DisjunctionExpr, cloneID CloseInfo) {
   101  
   102  	// TODO: precompute
   103  	numDefaults := 0
   104  	for _, v := range x.Values {
   105  		isDef := v.Default // || n.hasDefaults(env, v.Val)
   106  		if isDef {
   107  			numDefaults++
   108  		}
   109  	}
   110  
   111  	n.disjunctions = append(n.disjunctions,
   112  		envDisjunct{env, cloneID, x, nil, numDefaults > 0, false, false})
   113  }
   114  
   115  func (n *nodeContext) addDisjunctionValue(env *Environment, x *Disjunction, cloneID CloseInfo) {
   116  	n.disjunctions = append(n.disjunctions,
   117  		envDisjunct{env, cloneID, nil, x, x.HasDefaults, false, false})
   118  
   119  }
   120  
// expandDisjuncts expands the disjunctions queued in n.disjunctions and hands
// the resulting disjuncts and errors to parent.
//
// It first calls n.expandOne until that reports false and snapshots n.node.
// If n has no queued disjunctions, n is completed with postDisjunct and,
// unless that leaves it with a non-incomplete error, accounted for as a single
// result. Otherwise the disjunctions are processed one at a time: every
// disjunct that survived the previous round is cloned once per alternative of
// the current disjunction, the alternative is added as a conjunct, and the
// clone is expanded recursively with n as its parent. Afterwards the default
// bookkeeping (usedDefault, defaultMode) of the surviving disjuncts is
// updated.
//
// Finally, if parent != n, the errors and disjuncts collected in n are moved
// to parent; a disjunct that is Equal to one parent already holds is merged
// with it instead of being appended.
//
// Parameters:
//   - state: status passed to postDisjunct; overridden with Finalized when n
//     has queued disjunctions.
//   - parent: node context receiving the resulting disjuncts and errors.
//   - parentMode: default mode of this disjunct.
//   - recursive: set by the recursive calls in this function. It makes the
//     snapshot a clone, makes n register itself as a result when it has no
//     disjunctions, and makes n be freed when it fails.
//   - last: when set, the equality check against parent's disjuncts also
//     ignores optional fields. The recursive calls set it for the final
//     queued disjunction.
//
// n.node is restored to its value on entry when the function returns.
func (n *nodeContext) expandDisjuncts(
	state VertexStatus,
	parent *nodeContext,
	parentMode defaultMode, // default mode of this disjunct
	recursive, last bool) {

	n.ctx.stats.DisjunctCount++

	// n.node may be redirected to n.result below; always restore it on exit.
	node := n.node
	defer func() {
		n.node = node
	}()

	for n.expandOne() {
	}

	// save node to snapshot in nodeContext
	// save nodeContext.

	// A private clone is only taken when the snapshot is going to be reused
	// (recursive call or pending disjunctions); otherwise a plain copy is used.
	if recursive || len(n.disjunctions) > 0 {
		n.snapshot = clone(*n.node)
	} else {
		n.snapshot = *n.node
	}

	// Entries of usedDefault from this index on belong to this invocation.
	defaultOffset := len(n.usedDefault)

	switch {
	default: // len(n.disjunctions) == 0
		// m is a copy of the node context taken before postDisjunct; it is
		// restored below in the recursive case.
		m := *n
		n.postDisjunct(state)

		switch {
		case n.hasErr():
			// TODO: consider finalizing the node thusly:
			// if recursive {
			// 	n.node.Finalize(n.ctx)
			// }
			x := n.node
			err, ok := x.BaseValue.(*Bottom)
			if !ok {
				err = n.getErr()
			}
			if err == nil {
				// TODO(disjuncts): Is this always correct? Especially for partial
				// evaluation it is okay for child errors to have incomplete errors.
				// Perhaps introduce an Err() method.
				err = x.ChildErrors
			}
			// An incomplete error does not eliminate the disjunct: leave the
			// inner switch and continue as if there were no error.
			if err.IsIncomplete() {
				break
			}
			if err != nil {
				parent.disjunctErrs = append(parent.disjunctErrs, err)
			}
			if recursive {
				n.free()
			}
			return
		}

		if recursive {
			*n = m
			n.result = *n.node // XXX: n.result = snapshotVertex(n.node)?
			n.node = &n.result
			n.disjuncts = append(n.disjuncts, n)
		}
		if n.node.BaseValue == nil {
			n.node.BaseValue = n.getValidators()
		}

		n.usedDefault = append(n.usedDefault, defaultInfo{
			parentMode: parentMode,
			nestedMode: parentMode,
			origMode:   parentMode,
		})

	case len(n.disjunctions) > 0:
		// Process full disjuncts to ensure that erroneous disjuncts are
		// eliminated as early as possible.
		state = Finalized

		// Seed the working set with n itself; each round below replaces the
		// set with the expansions for one disjunction.
		n.disjuncts = append(n.disjuncts, n)

		// NOTE(review): presumably this reference keeps n alive while its
		// clones are processed; it is released by the deferred free.
		n.refCount++
		defer n.free()

		for i, d := range n.disjunctions {
			// a holds the disjuncts produced so far. n.disjuncts is reset to
			// collect this round's results in the storage of n.buffer, and
			// a's storage becomes the buffer for the next round.
			a := n.disjuncts
			n.disjuncts = n.buffer[:0]
			n.buffer = a[:0]

			// Note: this shadows the last parameter within the loop body.
			last := i+1 == len(n.disjunctions)
			skipNonMonotonicChecks := i+1 < len(n.disjunctions)
			if skipNonMonotonicChecks {
				n.ctx.inDisjunct++
			}

			for _, dn := range a {
				switch {
				case d.expr != nil:
					for _, v := range d.expr.Values {
						// Start each alternative from a fresh clone of the
						// snapshot of the disjunct being extended.
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						c := MakeConjunct(d.env, v.Val, d.cloneID)
						cn.addExprConjunct(c)

						newMode := mode(d.hasDefaults, v.Default)

						cn.expandDisjuncts(state, n, newMode, true, last)
					}

				case d.value != nil:
					for i, v := range d.value.Values {
						cn := dn.clone()
						*cn.node = clone(dn.snapshot)
						cn.node.state = cn

						cn.addValueConjunct(d.env, v, d.cloneID)

						// The first NumDefaults values are treated as the
						// default-marked ones.
						newMode := mode(d.hasDefaults, i < d.value.NumDefaults)

						cn.expandDisjuncts(state, n, newMode, true, last)
					}
				}
			}

			if skipNonMonotonicChecks {
				n.ctx.inDisjunct--
			}

			// No alternative survived: turn the node into an error.
			if len(n.disjuncts) == 0 {
				n.makeError()
			}

			// Release the previous round's disjuncts. In the first round of a
			// non-recursive call this is skipped, as a then only contains n.
			if recursive || i > 0 {
				for _, x := range a {
					x.free()
				}
			}

			if len(n.disjuncts) == 0 {
				break
			}
		}

		// Annotate disjunctions with whether any of the default disjunctions
		// was used.
		for _, d := range n.disjuncts {
			for i, info := range d.usedDefault[defaultOffset:] {
				if info.parentMode == isDefault {
					n.disjunctions[i].parentDefaultUsed = true
				}
				if info.origMode == isDefault {
					n.disjunctions[i].childDefaultUsed = true
				}
			}
		}

		// Combine parent and child default markers, considering that a parent
		// "notDefault" is treated as "maybeDefault" if none of the disjuncts
		// marked as default remain.
		//
		// NOTE for a parent marked as "notDefault", a child is *never*
		// considered as default. It may either be "not" or "maybe" default.
		//
		// The result for each disjunction is conjoined into a single value.
		for _, d := range n.disjuncts {
			m := maybeDefault
			orig := maybeDefault
			for i, info := range d.usedDefault[defaultOffset:] {
				parent := info.parentMode

				used := n.disjunctions[i].parentDefaultUsed
				childUsed := n.disjunctions[i].childDefaultUsed
				hasDefaults := n.disjunctions[i].hasDefaults

				orig = combineDefault(orig, info.parentMode)
				orig = combineDefault(orig, info.nestedMode)

				switch {
				case childUsed:
					// One of the children used a default. This is "normal"
					// mode. This may also happen when we are in
					// hasDefaults/notUsed mode. Consider
					//
					//      ("a" | "b") & (*(*"a" | string) | string)
					//
					// Here the doubly nested default is called twice, once
					// for "a" and then for "b", where the second resolves to
					// not using a default. The first does, however, and on that
					// basis the "not default" marker cannot be overridden.
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.origMode)

				case !hasDefaults, used:
					m = combineDefault(m, info.parentMode)
					m = combineDefault(m, info.nestedMode)

				case hasDefaults && !used:
					// No default-marked disjunct remains, so this one is
					// left out of m, i.e. treated as maybeDefault.
					Assertf(parent == notDefault, "unexpected default mode")
				}
			}
			d.defaultMode = m

			// Collapse this invocation's entries into a single one that
			// describes the combined result to the caller.
			d.usedDefault = d.usedDefault[:defaultOffset]
			d.usedDefault = append(d.usedDefault, defaultInfo{
				parentMode: parentMode,
				nestedMode: m,
				origMode:   orig,
			})

		}

		// TODO: this is an old trick that seems no longer necessary for the new
		// implementation. Keep around until we finalize the semantics for
		// defaults, though. The recursion of nested defaults is not entirely
		// proper yet.
		//
		// A better approach, that avoids the need for recursion (semantically),
		// would be to only consider default usage for one level, but then to
		// also allow a default to be passed if only one value is remaining.
		// This means that a nested subsumption would first have to be evaluated
		// in isolation, however, to determine that it is not previous
		// disjunctions that cause the disambiguation.
		//
		// HACK alert: this replaces the hack of the previous algorithm with a
		// slightly less worse hack: instead of dropping the default info when
		// the value was scalar before, we drop this information when there is
		// only one disjunct, while not discarding hard defaults. TODO: a more
		// principled approach would be to recognize that there is only one
		// default at a point where this does not break commutativity.
		//
		// if len(n.disjuncts) == 1 && n.disjuncts[0].defaultMode != isDefault {
		// 	n.disjuncts[0].defaultMode = maybeDefault
		// }
	}

	// Compare to root, but add to this one.
	switch p := parent; {
	case p != n:
		// Hand the collected errors over to the parent.
		p.disjunctErrs = append(p.disjunctErrs, n.disjunctErrs...)
		n.disjunctErrs = n.disjunctErrs[:0]

	outer:
		for _, d := range n.disjuncts {
			for k, v := range p.disjuncts {
				// Only compare disjuncts that are both done; otherwise stop
				// looking for a duplicate and append d.
				if !d.done() || !v.done() {
					break
				}
				flags := CheckStructural
				if last {
					flags |= IgnoreOptional
				}
				if Equal(n.ctx, &v.result, &d.result, flags) {
					// d duplicates v. Keep d in v's place if d is a default,
					// otherwise keep v; the other one is freed.
					m := maybeDefault
					for _, u := range d.usedDefault {
						m = combineDefault(m, u.nestedMode)
					}
					if m == isDefault {
						p.disjuncts[k] = d
						v.free()
					} else {
						d.free()
					}
					continue outer
				}
			}

			p.disjuncts = append(p.disjuncts, d)
		}

		n.disjuncts = n.disjuncts[:0]
	}
}
   397  
   398  func (n *nodeContext) makeError() {
   399  	code := IncompleteError
   400  
   401  	if len(n.disjunctErrs) > 0 {
   402  		code = EvalError
   403  		for _, c := range n.disjunctErrs {
   404  			if c.Code > code {
   405  				code = c.Code
   406  			}
   407  		}
   408  	}
   409  
   410  	b := &Bottom{
   411  		Code: code,
   412  		Err:  n.disjunctError(),
   413  	}
   414  	n.node.SetValue(n.ctx, Finalized, b)
   415  }
   416  
   417  func mode(hasDefault, marked bool) defaultMode {
   418  	var mode defaultMode
   419  	switch {
   420  	case !hasDefault:
   421  		mode = maybeDefault
   422  	case marked:
   423  		mode = isDefault
   424  	default:
   425  		mode = notDefault
   426  	}
   427  	return mode
   428  }
   429  
   430  // clone makes a shallow copy of a Vertex. The purpose is to create different
   431  // disjuncts from the same Vertex under computation. This allows the conjuncts
   432  // of an arc to be reset to a previous position and the reuse of earlier
   433  // computations.
   434  //
   435  // Notes: only Arcs need to be copied recursively. Either the arc is finalized
   436  // and can be used as is, or Structs is assumed to not yet be computed at the
   437  // time that a clone is needed and must be nil. Conjuncts no longer needed and
   438  // can become nil. All other fields can be copied shallowly.
   439  func clone(v Vertex) Vertex {
   440  	v.state = nil
   441  	if a := v.Arcs; len(a) > 0 {
   442  		v.Arcs = make([]*Vertex, len(a))
   443  		for i, arc := range a {
   444  			switch arc.status {
   445  			case Finalized:
   446  				v.Arcs[i] = arc
   447  
   448  			case 0:
   449  				a := *arc
   450  				v.Arcs[i] = &a
   451  
   452  				a.Conjuncts = make([]Conjunct, len(arc.Conjuncts))
   453  				copy(a.Conjuncts, arc.Conjuncts)
   454  
   455  			default:
   456  				a := *arc
   457  				a.state = arc.state.clone()
   458  				a.state.node = &a
   459  				a.state.snapshot = clone(a)
   460  				v.Arcs[i] = &a
   461  			}
   462  		}
   463  	}
   464  
   465  	if a := v.Structs; len(a) > 0 {
   466  		v.Structs = make([]*StructInfo, len(a))
   467  		copy(v.Structs, a)
   468  	}
   469  
   470  	return v
   471  }
   472  
   473  // Default rules from spec:
   474  //
   475  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   476  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   477  //
   478  // D1: (v1, d1) | v2       => (v1|v2, d1)
   479  // D2: (v1, d1) | (v2, d2) => (v1|v2, d1|d2)
   480  //
   481  // M1: *v        => (v, v)
   482  // M2: *(v1, d1) => (v1, d1)
   483  //
   484  // NOTE: M2 cannot be *(v1, d1) => (v1, v1), as this has the weird property
   485  // of making a value less specific. This causes issues, for instance, when
   486  // trimming.
   487  //
   488  // The old implementation does something similar though. It will discard
   489  // default information after first determining if more than one conjunct
   490  // has survived.
   491  //
   492  // def + maybe -> def
   493  // not + maybe -> def
   494  // not + def   -> def
   495  
// defaultMode describes how a disjunct relates to the default markers of the
// disjunction it originates from.
//
// The numeric order of the constants is significant: combineDefault returns
// the larger of two modes, so notDefault takes precedence over isDefault,
// which takes precedence over maybeDefault.
type defaultMode int

const (
	// maybeDefault is the mode of a disjunct in a disjunction without any
	// default markers.
	maybeDefault defaultMode = iota
	// isDefault is the mode of a disjunct that is marked as default.
	isDefault
	// notDefault is the mode of an unmarked disjunct in a disjunction that
	// does have default markers.
	notDefault
)
   503  
   504  // combineDefaults combines default modes for unifying conjuncts.
   505  //
   506  // Default rules from spec:
   507  //
   508  // U1: (v1, d1) & v2       => (v1&v2, d1&v2)
   509  // U2: (v1, d1) & (v2, d2) => (v1&v2, d1&d2)
   510  func combineDefault(a, b defaultMode) defaultMode {
   511  	if a > b {
   512  		return a
   513  	}
   514  	return b
   515  }
   516  
   517  // disjunctError returns a compound error for a failed disjunction.
   518  //
   519  // TODO(perf): the set of errors is now computed during evaluation. Eventually,
   520  // this could be done lazily.
   521  func (n *nodeContext) disjunctError() (errs errors.Error) {
   522  	ctx := n.ctx
   523  
   524  	disjuncts := selectErrors(n.disjunctErrs)
   525  
   526  	if disjuncts == nil {
   527  		errs = ctx.Newf("empty disjunction") // XXX: add space to sort first
   528  	} else {
   529  		disjuncts = errors.Sanitize(disjuncts)
   530  		k := len(errors.Errors(disjuncts))
   531  		// prefix '-' to sort to top
   532  		errs = ctx.Newf("%d errors in empty disjunction:", k)
   533  	}
   534  
   535  	errs = errors.Append(errs, disjuncts)
   536  
   537  	return errs
   538  }
   539  
   540  func selectErrors(a []*Bottom) (errs errors.Error) {
   541  	// return all errors if less than a certain number.
   542  	if len(a) <= 2 {
   543  		for _, b := range a {
   544  			errs = errors.Append(errs, b.Err)
   545  
   546  		}
   547  		return errs
   548  	}
   549  
   550  	// First select only relevant errors.
   551  	isIncomplete := false
   552  	k := 0
   553  	for _, b := range a {
   554  		if !isIncomplete && b.Code >= IncompleteError {
   555  			k = 0
   556  			isIncomplete = true
   557  		}
   558  		a[k] = b
   559  		k++
   560  	}
   561  	a = a[:k]
   562  
   563  	// filter errors
   564  	positions := map[token.Pos]bool{}
   565  
   566  	add := func(b *Bottom, p token.Pos) bool {
   567  		if positions[p] {
   568  			return false
   569  		}
   570  		positions[p] = true
   571  		errs = errors.Append(errs, b.Err)
   572  		return true
   573  	}
   574  
   575  	for _, b := range a {
   576  		// TODO: Should we also distinguish by message type?
   577  		if add(b, b.Err.Position()) {
   578  			continue
   579  		}
   580  		for _, p := range b.Err.InputPositions() {
   581  			if add(b, p) {
   582  				break
   583  			}
   584  		}
   585  	}
   586  
   587  	return errs
   588  }