github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/util/json/contains.go (about) 1 // Copyright 2017 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package json 12 13 import "sort" 14 15 // Contains returns true if a contains b. This implements the @>, <@ operators. 16 // See the Postgres docs for the expected semantics of Contains. 17 // https://www.postgresql.org/docs/10/static/datatype-json.html#JSON-CONTAINMENT 18 // The naive approach to doing array containment would be to do an O(n^2) 19 // nested loop through the arrays to check if one is contained in the 20 // other. We're out of luck when the arrays contain other arrays or 21 // objects (there might actually be something fancy we can do, but there's nothing 22 // obvious). 23 // When the arrays contain scalars however, we can optimize this by 24 // pre-sorting both arrays and iterating through them in lockstep. 25 // To this end, we preprocess the JSON document to sort all of its arrays so 26 // that when we perform contains we can extract the scalars sorted, and then 27 // also the arrays and objects in separate arrays, so that we can do the fast 28 // thing for the subset of the arrays which are scalars. 29 func Contains(a, b JSON) (bool, error) { 30 if a.Type() == ArrayJSONType && b.isScalar() { 31 decoded, err := a.tryDecode() 32 if err != nil { 33 return false, err 34 } 35 ary := decoded.(jsonArray) 36 return checkArrayContainsScalar(ary, b) 37 } 38 39 preA, err := a.preprocessForContains() 40 if err != nil { 41 return false, err 42 } 43 preB, err := b.preprocessForContains() 44 if err != nil { 45 return false, err 46 } 47 return preA.contains(preB) 48 } 49 50 // checkArrayContainsScalar performs a unique case of contains (and is 51 // described as such in the Postgres docs) - a top-level array contains a 52 // scalar which is an element of it. This contradicts the general rule of 53 // contains that the contained object must have the same "shape" as the 54 // containing object. 55 func checkArrayContainsScalar(ary jsonArray, s JSON) (bool, error) { 56 for _, j := range ary { 57 cmp, err := j.Compare(s) 58 if err != nil { 59 return false, err 60 } 61 if cmp == 0 { 62 return true, nil 63 } 64 } 65 return false, nil 66 } 67 68 // containsable is an interface used internally for the implementation of @>. 69 type containsable interface { 70 contains(other containsable) (bool, error) 71 } 72 73 // containsableScalar is a preprocessed JSON scalar. The JSON it holds will 74 // never be a JSON object or a JSON array. 75 type containsableScalar struct{ JSON } 76 77 // containsableArray is a preprocessed JSON array. 78 // * scalars will always be scalars and will always be sorted, 79 // * arrays will only contain containsableArrays, 80 // * objects will only contain containsableObjects 81 // (the last two are stored interfaces for reuse, though) 82 type containsableArray struct { 83 scalars []containsableScalar 84 arrays []containsable 85 objects []containsable 86 } 87 88 type containsableKeyValuePair struct { 89 k jsonString 90 v containsable 91 } 92 93 // containsableObject is a preprocessed JSON object. 94 // Same as a jsonObject, it is stored as a sorted-by-key list of key-value 95 // pairs. 96 type containsableObject []containsableKeyValuePair 97 98 func (j jsonNull) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil } 99 func (j jsonFalse) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil } 100 func (j jsonTrue) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil } 101 func (j jsonNumber) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil } 102 func (j jsonString) preprocessForContains() (containsable, error) { return containsableScalar{j}, nil } 103 104 func (j jsonArray) preprocessForContains() (containsable, error) { 105 result := containsableArray{} 106 for _, e := range j { 107 switch e.Type() { 108 case ArrayJSONType: 109 preprocessed, err := e.preprocessForContains() 110 if err != nil { 111 return nil, err 112 } 113 result.arrays = append(result.arrays, preprocessed) 114 case ObjectJSONType: 115 preprocessed, err := e.preprocessForContains() 116 if err != nil { 117 return nil, err 118 } 119 result.objects = append(result.objects, preprocessed) 120 default: 121 preprocessed, err := e.preprocessForContains() 122 if err != nil { 123 return nil, err 124 } 125 result.scalars = append(result.scalars, preprocessed.(containsableScalar)) 126 } 127 } 128 129 var err error 130 sort.Slice(result.scalars, func(i, j int) bool { 131 if err != nil { 132 return false 133 } 134 var c int 135 c, err = result.scalars[i].JSON.Compare(result.scalars[j].JSON) 136 return c == -1 137 }) 138 139 if err != nil { 140 return nil, err 141 } 142 143 return result, nil 144 } 145 146 func (j jsonObject) preprocessForContains() (containsable, error) { 147 preprocessed := make(containsableObject, len(j)) 148 149 for i := range preprocessed { 150 preprocessed[i].k = j[i].k 151 v, err := j[i].v.preprocessForContains() 152 if err != nil { 153 return nil, err 154 } 155 preprocessed[i].v = v 156 } 157 158 return preprocessed, nil 159 } 160 161 func (j containsableScalar) contains(other containsable) (bool, error) { 162 if o, ok := other.(containsableScalar); ok { 163 result, err := j.JSON.Compare(o.JSON) 164 if err != nil { 165 return false, err 166 } 167 return result == 0, nil 168 } 169 return false, nil 170 } 171 172 func (j containsableArray) contains(other containsable) (bool, error) { 173 if contained, ok := other.(containsableArray); ok { 174 // Since both slices of scalars are sorted via the preprocessing, we can 175 // step through them together via binary search. 176 remainingScalars := j.scalars[:] 177 for _, val := range contained.scalars { 178 var err error 179 found := sort.Search(len(remainingScalars), func(i int) bool { 180 if err != nil { 181 return false 182 } 183 var result int 184 result, err = remainingScalars[i].JSON.Compare(val.JSON) 185 return result >= 0 186 }) 187 188 if found == len(remainingScalars) { 189 return false, nil 190 } 191 result, err := remainingScalars[found].JSON.Compare(val.JSON) 192 if err != nil { 193 return false, err 194 } 195 if result != 0 { 196 return false, nil 197 } 198 remainingScalars = remainingScalars[found:] 199 } 200 201 // TODO(justin): there's possibly(?) something fancier we can do with the 202 // objects and arrays, but for now just do the quadratic check. 203 objectsMatch, err := quadraticJSONArrayContains(j.objects, contained.objects) 204 if err != nil { 205 return false, err 206 } 207 if !objectsMatch { 208 return false, nil 209 } 210 211 arraysMatch, err := quadraticJSONArrayContains(j.arrays, contained.arrays) 212 if err != nil { 213 return false, err 214 } 215 if !arraysMatch { 216 return false, nil 217 } 218 219 return true, nil 220 } 221 return false, nil 222 } 223 224 // quadraticJSONArrayContains does an O(n^2) check to see if every value in 225 // `other` is contained within a value in `container`. `container` and `other` 226 // should not contain scalars. 227 func quadraticJSONArrayContains(container, other []containsable) (bool, error) { 228 for _, otherVal := range other { 229 found := false 230 for _, containerVal := range container { 231 c, err := containerVal.contains(otherVal) 232 if err != nil { 233 return false, err 234 } 235 if c { 236 found = true 237 break 238 } 239 } 240 if !found { 241 return false, nil 242 } 243 } 244 return true, nil 245 } 246 247 func (j containsableObject) contains(other containsable) (bool, error) { 248 if contained, ok := other.(containsableObject); ok { 249 // We can iterate through the keys of `other` and scan through to find the 250 // corresponding keys in `j` since they're both sorted. 251 objIdx := 0 252 for _, rightEntry := range contained { 253 for objIdx < len(j) && j[objIdx].k < rightEntry.k { 254 objIdx++ 255 } 256 if objIdx >= len(j) || 257 j[objIdx].k != rightEntry.k { 258 return false, nil 259 } 260 c, err := j[objIdx].v.contains(rightEntry.v) 261 if err != nil { 262 return false, err 263 } 264 if !c { 265 return false, nil 266 } 267 objIdx++ 268 } 269 return true, nil 270 } 271 return false, nil 272 }