github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/json/contains.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package json
    12  
    13  import "sort"
    14  
    15  // Contains returns true if a contains b. This implements the @>, <@ operators.
    16  // See the Postgres docs for the expected semantics of Contains.
    17  // https://www.postgresql.org/docs/10/static/datatype-json.html#JSON-CONTAINMENT
    18  // The naive approach to doing array containment would be to do an O(n^2)
    19  // nested loop through the arrays to check if one is contained in the
    20  // other.  We're out of luck when the arrays contain other arrays or
    21  // objects (there might actually be something fancy we can do, but there's nothing
    22  // obvious).
    23  // When the arrays contain scalars however, we can optimize this by
    24  // pre-sorting both arrays and iterating through them in lockstep.
    25  // To this end, we preprocess the JSON document to sort all of its arrays so
    26  // that when we perform contains we can extract the scalars sorted, and then
    27  // also the arrays and objects in separate arrays, so that we can do the fast
    28  // thing for the subset of the arrays which are scalars.
    29  func Contains(a, b JSON) (bool, error) {
    30  	if a.Type() == ArrayJSONType && b.isScalar() {
    31  		decoded, err := a.tryDecode()
    32  		if err != nil {
    33  			return false, err
    34  		}
    35  		ary := decoded.(jsonArray)
    36  		return checkArrayContainsScalar(ary, b)
    37  	}
    38  
    39  	preA, err := a.preprocessForContains()
    40  	if err != nil {
    41  		return false, err
    42  	}
    43  	preB, err := b.preprocessForContains()
    44  	if err != nil {
    45  		return false, err
    46  	}
    47  	return preA.contains(preB)
    48  }
    49  
    50  // checkArrayContainsScalar performs a unique case of contains (and is
    51  // described as such in the Postgres docs) - a top-level array contains a
    52  // scalar which is an element of it.  This contradicts the general rule of
    53  // contains that the contained object must have the same "shape" as the
    54  // containing object.
    55  func checkArrayContainsScalar(ary jsonArray, s JSON) (bool, error) {
    56  	for _, j := range ary {
    57  		cmp, err := j.Compare(s)
    58  		if err != nil {
    59  			return false, err
    60  		}
    61  		if cmp == 0 {
    62  			return true, nil
    63  		}
    64  	}
    65  	return false, nil
    66  }
    67  
// containsable is an interface used internally for the implementation of @>.
// It is implemented by the preprocessed forms of JSON values.
type containsable interface {
	// contains reports whether the receiver contains other. A non-nil error
	// means the check could not be completed.
	contains(other containsable) (bool, error)
}
    72  
// containsableScalar is a preprocessed JSON scalar. It simply wraps the
// original value; the JSON it holds will never be a JSON object or a JSON
// array.
type containsableScalar struct{ JSON }
    76  
// containsableArray is a preprocessed JSON array whose elements have been
// split up by kind:
//   - scalars will always be scalars and will always be sorted,
//   - arrays will only contain containsableArrays,
//   - objects will only contain containsableObjects.
//
// The last two are stored as interfaces so that the same helper can be reused
// for both.
type containsableArray struct {
	// scalars holds the scalar elements in sorted order.
	scalars []containsableScalar
	// arrays holds the preprocessed nested arrays.
	arrays []containsable
	// objects holds the preprocessed nested objects.
	objects []containsable
}
    87  
// containsableKeyValuePair is a single entry of a containsableObject: a key
// together with the preprocessed form of its value.
type containsableKeyValuePair struct {
	// k is the entry's key.
	k jsonString
	// v is the preprocessed value stored under k.
	v containsable
}
    92  
// containsableObject is a preprocessed JSON object.
// Same as a jsonObject, it is stored as a sorted-by-key list of key-value
// pairs; the containment check relies on that ordering.
type containsableObject []containsableKeyValuePair
    97  
    98  func (j jsonNull) preprocessForContains() (containsable, error)   { return containsableScalar{j}, nil }
    99  func (j jsonFalse) preprocessForContains() (containsable, error)  { return containsableScalar{j}, nil }
   100  func (j jsonTrue) preprocessForContains() (containsable, error)   { return containsableScalar{j}, nil }
   101  func (j jsonNumber) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil }
   102  func (j jsonString) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil }
   103  
   104  func (j jsonArray) preprocessForContains() (containsable, error) {
   105  	result := containsableArray{}
   106  	for _, e := range j {
   107  		switch e.Type() {
   108  		case ArrayJSONType:
   109  			preprocessed, err := e.preprocessForContains()
   110  			if err != nil {
   111  				return nil, err
   112  			}
   113  			result.arrays = append(result.arrays, preprocessed)
   114  		case ObjectJSONType:
   115  			preprocessed, err := e.preprocessForContains()
   116  			if err != nil {
   117  				return nil, err
   118  			}
   119  			result.objects = append(result.objects, preprocessed)
   120  		default:
   121  			preprocessed, err := e.preprocessForContains()
   122  			if err != nil {
   123  				return nil, err
   124  			}
   125  			result.scalars = append(result.scalars, preprocessed.(containsableScalar))
   126  		}
   127  	}
   128  
   129  	var err error
   130  	sort.Slice(result.scalars, func(i, j int) bool {
   131  		if err != nil {
   132  			return false
   133  		}
   134  		var c int
   135  		c, err = result.scalars[i].JSON.Compare(result.scalars[j].JSON)
   136  		return c == -1
   137  	})
   138  
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return result, nil
   144  }
   145  
   146  func (j jsonObject) preprocessForContains() (containsable, error) {
   147  	preprocessed := make(containsableObject, len(j))
   148  
   149  	for i := range preprocessed {
   150  		preprocessed[i].k = j[i].k
   151  		v, err := j[i].v.preprocessForContains()
   152  		if err != nil {
   153  			return nil, err
   154  		}
   155  		preprocessed[i].v = v
   156  	}
   157  
   158  	return preprocessed, nil
   159  }
   160  
   161  func (j containsableScalar) contains(other containsable) (bool, error) {
   162  	if o, ok := other.(containsableScalar); ok {
   163  		result, err := j.JSON.Compare(o.JSON)
   164  		if err != nil {
   165  			return false, err
   166  		}
   167  		return result == 0, nil
   168  	}
   169  	return false, nil
   170  }
   171  
   172  func (j containsableArray) contains(other containsable) (bool, error) {
   173  	if contained, ok := other.(containsableArray); ok {
   174  		// Since both slices of scalars are sorted via the preprocessing, we can
   175  		// step through them together via binary search.
   176  		remainingScalars := j.scalars[:]
   177  		for _, val := range contained.scalars {
   178  			var err error
   179  			found := sort.Search(len(remainingScalars), func(i int) bool {
   180  				if err != nil {
   181  					return false
   182  				}
   183  				var result int
   184  				result, err = remainingScalars[i].JSON.Compare(val.JSON)
   185  				return result >= 0
   186  			})
   187  
   188  			if found == len(remainingScalars) {
   189  				return false, nil
   190  			}
   191  			result, err := remainingScalars[found].JSON.Compare(val.JSON)
   192  			if err != nil {
   193  				return false, err
   194  			}
   195  			if result != 0 {
   196  				return false, nil
   197  			}
   198  			remainingScalars = remainingScalars[found:]
   199  		}
   200  
   201  		// TODO(justin): there's possibly(?) something fancier we can do with the
   202  		// objects and arrays, but for now just do the quadratic check.
   203  		objectsMatch, err := quadraticJSONArrayContains(j.objects, contained.objects)
   204  		if err != nil {
   205  			return false, err
   206  		}
   207  		if !objectsMatch {
   208  			return false, nil
   209  		}
   210  
   211  		arraysMatch, err := quadraticJSONArrayContains(j.arrays, contained.arrays)
   212  		if err != nil {
   213  			return false, err
   214  		}
   215  		if !arraysMatch {
   216  			return false, nil
   217  		}
   218  
   219  		return true, nil
   220  	}
   221  	return false, nil
   222  }
   223  
   224  // quadraticJSONArrayContains does an O(n^2) check to see if every value in
   225  // `other` is contained within a value in `container`. `container` and `other`
   226  // should not contain scalars.
   227  func quadraticJSONArrayContains(container, other []containsable) (bool, error) {
   228  	for _, otherVal := range other {
   229  		found := false
   230  		for _, containerVal := range container {
   231  			c, err := containerVal.contains(otherVal)
   232  			if err != nil {
   233  				return false, err
   234  			}
   235  			if c {
   236  				found = true
   237  				break
   238  			}
   239  		}
   240  		if !found {
   241  			return false, nil
   242  		}
   243  	}
   244  	return true, nil
   245  }
   246  
   247  func (j containsableObject) contains(other containsable) (bool, error) {
   248  	if contained, ok := other.(containsableObject); ok {
   249  		// We can iterate through the keys of `other` and scan through to find the
   250  		// corresponding keys in `j` since they're both sorted.
   251  		objIdx := 0
   252  		for _, rightEntry := range contained {
   253  			for objIdx < len(j) && j[objIdx].k < rightEntry.k {
   254  				objIdx++
   255  			}
   256  			if objIdx >= len(j) ||
   257  				j[objIdx].k != rightEntry.k {
   258  				return false, nil
   259  			}
   260  			c, err := j[objIdx].v.contains(rightEntry.v)
   261  			if err != nil {
   262  				return false, err
   263  			}
   264  			if !c {
   265  				return false, nil
   266  			}
   267  			objIdx++
   268  		}
   269  		return true, nil
   270  	}
   271  	return false, nil
   272  }