github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/store/types/simplify.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package types 16 17 import ( 18 "sort" 19 20 "github.com/dolthub/dolt/go/store/d" 21 ) 22 23 // simplifyType returns a type that is a super type of the input type but is 24 // much smaller and less complex than a straight union of all those types would 25 // be. 26 // 27 // The resulting type is guaranteed to: 28 // a. be a super type of the input type 29 // b. have all unions flattened (no union inside a union) 30 // c. have all unions folded, which means the union 31 // 1. have at most one element each of kind Ref, Set, List, and Map 32 // 2. have at most one struct element with a given name 33 // e. all named unions are pointing at the same simplified struct, which means 34 // that all named unions with the same name form cycles. 35 // f. all cycle type that can be resolved have been resolved. 36 // g. all types reachable from it also fulfill b-f 37 // 38 // The union folding is created roughly as follows: 39 // 40 // - The input types are deduplicated 41 // - Any unions in the input set are "flattened" into the input set 42 // - The inputs are grouped into categories: 43 // - ref 44 // - list 45 // - set 46 // - map 47 // - struct, by name (each unique struct name will have its own group) 48 // - The ref, set, and list groups are collapsed like so: 49 // {Ref<A>,Ref<B>,...} -> Ref<A|B|...> 50 // - The map group is collapsed like so: 51 // {Map<K1,V1>|Map<K2,V2>...} -> Map<K1|K2,V1|V2> 52 // - Each struct group is collapsed like so: 53 // {struct{foo:number,bar:string}, struct{bar:blob, baz:bool}} -> 54 // struct{foo?:number,bar:string|blob,baz?:bool} 55 // 56 // All the above rules are applied recursively. 57 func simplifyType(t *Type, intersectStructs bool) (*Type, error) { 58 if t.Desc.isSimplifiedForSure() { 59 return t, nil 60 } 61 62 // 1. Clone tree because we are going to mutate it 63 // 1.1 Replace all named structs and cycle types with a single `struct Name {}` 64 // 2. When a union type is found change its elemTypes as needed 65 // 2.1 Merge unnamed structs 66 // 3. Update the fields of all named structs 67 68 namedStructs := map[string]structInfo{} 69 70 clone := cloneTypeTreeAndReplaceNamedStructs(t, namedStructs) 71 folded, err := foldUnions(clone, typeset{}, intersectStructs) 72 73 if err != nil { 74 return nil, err 75 } 76 77 for name, info := range namedStructs { 78 if len(info.sources) == 0 { 79 d.PanicIfTrue(name == "") 80 info.instance.Desc = CycleDesc(name) 81 } else { 82 fields, err := foldStructTypesFieldsOnly(name, info.sources, typeset{}, intersectStructs) 83 84 if err != nil { 85 return nil, err 86 } 87 88 info.instance.Desc = StructDesc{name, fields} 89 } 90 } 91 92 return folded, nil 93 } 94 95 // typeset is a helper that aggregates the unique set of input types for this algorithm, flattening 96 // any unions recursively. 97 type typeset map[*Type]struct{} 98 99 func (ts typeset) add(t *Type) { 100 switch t.TargetKind() { 101 case UnionKind: 102 for _, et := range t.Desc.(CompoundDesc).ElemTypes { 103 ts.add(et) 104 } 105 default: 106 ts[t] = struct{}{} 107 } 108 } 109 110 func (ts typeset) has(t *Type) bool { 111 _, ok := ts[t] 112 return ok 113 } 114 115 type structInfo struct { 116 instance *Type 117 sources typeset 118 } 119 120 func cloneTypeTreeAndReplaceNamedStructs(t *Type, namedStructs map[string]structInfo) *Type { 121 getNamedStruct := func(name string, t *Type) *Type { 122 record := namedStructs[name] 123 if t.TargetKind() == StructKind { 124 record.sources.add(t) 125 } 126 return record.instance 127 } 128 129 ensureInstance := func(name string) { 130 if _, ok := namedStructs[name]; !ok { 131 instance := newType(StructDesc{Name: name}) 132 namedStructs[name] = structInfo{instance, typeset{}} 133 } 134 } 135 136 seenStructs := typeset{} 137 var rec func(t *Type) *Type 138 rec = func(t *Type) *Type { 139 kind := t.TargetKind() 140 141 if IsPrimitiveKind(kind) { 142 return t 143 } 144 145 switch kind { 146 case ListKind, MapKind, RefKind, SetKind, UnionKind, TupleKind, JSONKind: 147 elemTypes := make(typeSlice, len(t.Desc.(CompoundDesc).ElemTypes)) 148 for i, et := range t.Desc.(CompoundDesc).ElemTypes { 149 elemTypes[i] = rec(et) 150 } 151 return newType(CompoundDesc{kind, elemTypes}) 152 case StructKind: 153 desc := t.Desc.(StructDesc) 154 name := desc.Name 155 156 if name != "" { 157 ensureInstance(name) 158 if seenStructs.has(t) { 159 return namedStructs[name].instance 160 } 161 } else if seenStructs.has(t) { 162 // It is OK to use the same unnamed struct type in multiple places. 163 // Do not clone it again. 164 return t 165 } 166 seenStructs.add(t) 167 168 fields := make(structTypeFields, len(desc.fields)) 169 for i, f := range desc.fields { 170 fields[i] = StructField{f.Name, rec(f.Type), f.Optional} 171 } 172 newStruct := newType(StructDesc{name, fields}) 173 if name == "" { 174 return newStruct 175 } 176 177 return getNamedStruct(name, newStruct) 178 179 case CycleKind: 180 name := string(t.Desc.(CycleDesc)) 181 d.PanicIfTrue(name == "") 182 ensureInstance(name) 183 return getNamedStruct(name, t) 184 185 default: 186 panic("Unknown noms kind") 187 } 188 } 189 190 return rec(t) 191 } 192 193 func foldUnions(t *Type, seenStructs typeset, intersectStructs bool) (*Type, error) { 194 var err error 195 196 kind := t.TargetKind() 197 if !IsPrimitiveKind(kind) { 198 switch kind { 199 case CycleKind: 200 break 201 202 case ListKind, MapKind, RefKind, SetKind, TupleKind, JSONKind: 203 elemTypes := t.Desc.(CompoundDesc).ElemTypes 204 for i, et := range elemTypes { 205 elemTypes[i], err = foldUnions(et, seenStructs, intersectStructs) 206 207 if err != nil { 208 return nil, err 209 } 210 } 211 212 case StructKind: 213 if seenStructs.has(t) { 214 return t, nil 215 } 216 seenStructs.add(t) 217 fields := t.Desc.(StructDesc).fields 218 for i, f := range fields { 219 fields[i].Type, err = foldUnions(f.Type, seenStructs, intersectStructs) 220 221 if err != nil { 222 return nil, err 223 } 224 } 225 226 case UnionKind: 227 elemTypes := t.Desc.(CompoundDesc).ElemTypes 228 if len(elemTypes) == 0 { 229 break 230 } 231 ts := make(typeset, len(elemTypes)) 232 for _, t := range elemTypes { 233 ts.add(t) 234 } 235 if len(ts) == 0 { 236 t.Desc = CompoundDesc{UnionKind, nil} 237 return t, nil 238 } 239 return foldUnionImpl(ts, seenStructs, intersectStructs) 240 241 default: 242 panic("Unknown noms kind") 243 } 244 } 245 return t, nil 246 } 247 248 func foldUnionImpl(ts typeset, seenStructs typeset, intersectStructs bool) (*Type, error) { 249 type how struct { 250 k NomsKind 251 n string 252 } 253 out := make(typeSlice, 0, len(ts)) 254 groups := map[how]typeset{} 255 for t := range ts { 256 var h how 257 switch t.TargetKind() { 258 case RefKind, SetKind, ListKind, MapKind, TupleKind, JSONKind: 259 h = how{k: t.TargetKind()} 260 case StructKind: 261 h = how{k: t.TargetKind(), n: t.Desc.(StructDesc).Name} 262 default: 263 out = append(out, t) 264 continue 265 } 266 g := groups[h] 267 if g == nil { 268 g = typeset{} 269 groups[h] = g 270 } 271 g.add(t) 272 } 273 274 for h, ts := range groups { 275 if len(ts) == 1 { 276 for t := range ts { 277 out = append(out, t) 278 } 279 continue 280 } 281 282 var r *Type 283 var err error 284 switch h.k { 285 case ListKind, RefKind, SetKind, TupleKind, JSONKind: 286 r, err = foldCompoundTypesForUnion(h.k, ts, seenStructs, intersectStructs) 287 case MapKind: 288 r, err = foldMapTypesForUnion(ts, seenStructs, intersectStructs) 289 case StructKind: 290 r, err = foldStructTypes(h.n, ts, seenStructs, intersectStructs) 291 } 292 293 if err != nil { 294 return nil, err 295 } 296 297 out = append(out, r) 298 } 299 300 for i, t := range out { 301 var err error 302 out[i], err = foldUnions(t, seenStructs, intersectStructs) 303 304 if err != nil { 305 return nil, err 306 } 307 } 308 309 if len(out) == 1 { 310 return out[0], nil 311 } 312 313 sort.Sort(out) 314 315 return newType(CompoundDesc{UnionKind, out}), nil 316 } 317 318 func foldCompoundTypesForUnion(k NomsKind, ts, seenStructs typeset, intersectStructs bool) (*Type, error) { 319 elemTypes := make(typeset, len(ts)) 320 for t := range ts { 321 d.PanicIfFalse(t.TargetKind() == k) 322 elemTypes.add(t.Desc.(CompoundDesc).ElemTypes[0]) 323 } 324 325 elemType, err := foldUnionImpl(elemTypes, seenStructs, intersectStructs) 326 327 if err != nil { 328 return nil, err 329 } 330 331 return makeCompoundType(k, elemType) 332 } 333 334 func foldMapTypesForUnion(ts, seenStructs typeset, intersectStructs bool) (*Type, error) { 335 keyTypes := make(typeset, len(ts)) 336 valTypes := make(typeset, len(ts)) 337 for t := range ts { 338 d.PanicIfFalse(t.TargetKind() == MapKind) 339 elemTypes := t.Desc.(CompoundDesc).ElemTypes 340 keyTypes.add(elemTypes[0]) 341 valTypes.add(elemTypes[1]) 342 } 343 344 kt, err := foldUnionImpl(keyTypes, seenStructs, intersectStructs) 345 346 if err != nil { 347 return nil, err 348 } 349 350 vt, err := foldUnionImpl(valTypes, seenStructs, intersectStructs) 351 352 if err != nil { 353 return nil, err 354 } 355 356 return makeCompoundType(MapKind, kt, vt) 357 } 358 359 func foldStructTypesFieldsOnly(name string, ts, seenStructs typeset, intersectStructs bool) (structTypeFields, error) { 360 fieldset := make([]structTypeFields, len(ts)) 361 i := 0 362 for t := range ts { 363 desc := t.Desc.(StructDesc) 364 d.PanicIfFalse(desc.Name == name) 365 fieldset[i] = desc.fields 366 i++ 367 } 368 369 return simplifyStructFields(fieldset, seenStructs, intersectStructs) 370 } 371 372 func foldStructTypes(name string, ts, seenStructs typeset, intersectStructs bool) (*Type, error) { 373 fields, err := foldStructTypesFieldsOnly(name, ts, seenStructs, intersectStructs) 374 375 if err != nil { 376 return nil, err 377 } 378 379 return newType(StructDesc{name, fields}), nil 380 } 381 382 func simplifyStructFields(in []structTypeFields, seenStructs typeset, intersectStructs bool) (structTypeFields, error) { 383 // We gather all the fields/types into allFields. If the number of 384 // times a field name is present is less that then number of types we 385 // are simplifying then the field must be optional. 386 // If we see an optional field we do not increment the count for it and 387 // it will be treated as optional in the end. 388 389 // If intersectStructs is true we need to pick the more restrictive version (n: T over n?: T). 390 type fieldTypeInfo struct { 391 anyNonOptional bool 392 count int 393 ts typeSlice 394 } 395 allFields := map[string]fieldTypeInfo{} 396 397 for _, ff := range in { 398 for _, f := range ff { 399 fti, ok := allFields[f.Name] 400 if !ok { 401 fti = fieldTypeInfo{ 402 ts: make(typeSlice, 0, len(in)), 403 } 404 } 405 fti.ts = append(fti.ts, f.Type) 406 if !f.Optional { 407 fti.count++ 408 fti.anyNonOptional = true 409 } 410 allFields[f.Name] = fti 411 } 412 } 413 414 count := len(in) 415 fields := make(structTypeFields, len(allFields)) 416 i := 0 417 for name, fti := range allFields { 418 nt, err := makeUnionType(fti.ts...) 419 420 if err != nil { 421 return nil, err 422 } 423 424 t, err := foldUnions(nt, seenStructs, intersectStructs) 425 426 if err != nil { 427 return nil, err 428 } 429 430 fields[i] = StructField{ 431 Name: name, 432 Type: t, 433 Optional: !(intersectStructs && fti.anyNonOptional) && fti.count < count, 434 } 435 i++ 436 } 437 438 sort.Sort(fields) 439 440 return fields, nil 441 }