github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/store/types/simplify.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package types
    16  
    17  import (
    18  	"sort"
    19  
    20  	"github.com/dolthub/dolt/go/store/d"
    21  )
    22  
    23  // simplifyType returns a type that is a super type of the input type but is
    24  // much smaller and less complex than a straight union of all those types would
    25  // be.
    26  //
    27  // The resulting type is guaranteed to:
    28  // a. be a super type of the input type
    29  // b. have all unions flattened (no union inside a union)
    30  // c. have all unions folded, which means the union
    31  //    1. have at most one element each of kind Ref, Set, List, and Map
    32  //    2. have at most one struct element with a given name
    33  // e. all named unions are pointing at the same simplified struct, which means
    34  //    that all named unions with the same name form cycles.
    35  // f. all cycle type that can be resolved have been resolved.
    36  // g. all types reachable from it also fulfill b-f
    37  //
    38  // The union folding is created roughly as follows:
    39  //
    40  // - The input types are deduplicated
    41  // - Any unions in the input set are "flattened" into the input set
    42  // - The inputs are grouped into categories:
    43  //    - ref
    44  //    - list
    45  //    - set
    46  //    - map
    47  //    - struct, by name (each unique struct name will have its own group)
    48  // - The ref, set, and list groups are collapsed like so:
    49  //     {Ref<A>,Ref<B>,...} -> Ref<A|B|...>
    50  // - The map group is collapsed like so:
    51  //     {Map<K1,V1>|Map<K2,V2>...} -> Map<K1|K2,V1|V2>
    52  // - Each struct group is collapsed like so:
    53  //     {struct{foo:number,bar:string}, struct{bar:blob, baz:bool}} ->
    54  //       struct{foo?:number,bar:string|blob,baz?:bool}
    55  //
    56  // All the above rules are applied recursively.
    57  func simplifyType(t *Type, intersectStructs bool) (*Type, error) {
    58  	if t.Desc.isSimplifiedForSure() {
    59  		return t, nil
    60  	}
    61  
    62  	// 1. Clone tree because we are going to mutate it
    63  	//    1.1 Replace all named structs and cycle types with a single `struct Name {}`
    64  	// 2. When a union type is found change its elemTypes as needed
    65  	//    2.1 Merge unnamed structs
    66  	// 3. Update the fields of all named structs
    67  
    68  	namedStructs := map[string]structInfo{}
    69  
    70  	clone := cloneTypeTreeAndReplaceNamedStructs(t, namedStructs)
    71  	folded, err := foldUnions(clone, typeset{}, intersectStructs)
    72  
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	for name, info := range namedStructs {
    78  		if len(info.sources) == 0 {
    79  			d.PanicIfTrue(name == "")
    80  			info.instance.Desc = CycleDesc(name)
    81  		} else {
    82  			fields, err := foldStructTypesFieldsOnly(name, info.sources, typeset{}, intersectStructs)
    83  
    84  			if err != nil {
    85  				return nil, err
    86  			}
    87  
    88  			info.instance.Desc = StructDesc{name, fields}
    89  		}
    90  	}
    91  
    92  	return folded, nil
    93  }
    94  
// typeset is a helper that aggregates the unique set of input types for this
// algorithm, flattening any unions recursively (see add). Entries are keyed on
// the *Type pointer itself, so uniqueness is by pointer identity.
type typeset map[*Type]struct{}
    98  
    99  func (ts typeset) add(t *Type) {
   100  	switch t.TargetKind() {
   101  	case UnionKind:
   102  		for _, et := range t.Desc.(CompoundDesc).ElemTypes {
   103  			ts.add(et)
   104  		}
   105  	default:
   106  		ts[t] = struct{}{}
   107  	}
   108  }
   109  
   110  func (ts typeset) has(t *Type) bool {
   111  	_, ok := ts[t]
   112  	return ok
   113  }
   114  
// structInfo is the bookkeeping kept per struct name while a type tree is
// simplified.
//
// instance is the single shared type that every occurrence of the name is
// replaced with when the tree is cloned; simplifyType overwrites its Desc once
// folding is done.
//
// sources collects the cloned struct types that carry the name. It stays empty
// when the name was only encountered through cycle types.
type structInfo struct {
	instance *Type
	sources  typeset
}
   119  
   120  func cloneTypeTreeAndReplaceNamedStructs(t *Type, namedStructs map[string]structInfo) *Type {
   121  	getNamedStruct := func(name string, t *Type) *Type {
   122  		record := namedStructs[name]
   123  		if t.TargetKind() == StructKind {
   124  			record.sources.add(t)
   125  		}
   126  		return record.instance
   127  	}
   128  
   129  	ensureInstance := func(name string) {
   130  		if _, ok := namedStructs[name]; !ok {
   131  			instance := newType(StructDesc{Name: name})
   132  			namedStructs[name] = structInfo{instance, typeset{}}
   133  		}
   134  	}
   135  
   136  	seenStructs := typeset{}
   137  	var rec func(t *Type) *Type
   138  	rec = func(t *Type) *Type {
   139  		kind := t.TargetKind()
   140  
   141  		if IsPrimitiveKind(kind) {
   142  			return t
   143  		}
   144  
   145  		switch kind {
   146  		case ListKind, MapKind, RefKind, SetKind, UnionKind, TupleKind, JSONKind:
   147  			elemTypes := make(typeSlice, len(t.Desc.(CompoundDesc).ElemTypes))
   148  			for i, et := range t.Desc.(CompoundDesc).ElemTypes {
   149  				elemTypes[i] = rec(et)
   150  			}
   151  			return newType(CompoundDesc{kind, elemTypes})
   152  		case StructKind:
   153  			desc := t.Desc.(StructDesc)
   154  			name := desc.Name
   155  
   156  			if name != "" {
   157  				ensureInstance(name)
   158  				if seenStructs.has(t) {
   159  					return namedStructs[name].instance
   160  				}
   161  			} else if seenStructs.has(t) {
   162  				// It is OK to use the same unnamed struct type in multiple places.
   163  				// Do not clone it again.
   164  				return t
   165  			}
   166  			seenStructs.add(t)
   167  
   168  			fields := make(structTypeFields, len(desc.fields))
   169  			for i, f := range desc.fields {
   170  				fields[i] = StructField{f.Name, rec(f.Type), f.Optional}
   171  			}
   172  			newStruct := newType(StructDesc{name, fields})
   173  			if name == "" {
   174  				return newStruct
   175  			}
   176  
   177  			return getNamedStruct(name, newStruct)
   178  
   179  		case CycleKind:
   180  			name := string(t.Desc.(CycleDesc))
   181  			d.PanicIfTrue(name == "")
   182  			ensureInstance(name)
   183  			return getNamedStruct(name, t)
   184  
   185  		default:
   186  			panic("Unknown noms kind")
   187  		}
   188  	}
   189  
   190  	return rec(t)
   191  }
   192  
   193  func foldUnions(t *Type, seenStructs typeset, intersectStructs bool) (*Type, error) {
   194  	var err error
   195  
   196  	kind := t.TargetKind()
   197  	if !IsPrimitiveKind(kind) {
   198  		switch kind {
   199  		case CycleKind:
   200  			break
   201  
   202  		case ListKind, MapKind, RefKind, SetKind, TupleKind, JSONKind:
   203  			elemTypes := t.Desc.(CompoundDesc).ElemTypes
   204  			for i, et := range elemTypes {
   205  				elemTypes[i], err = foldUnions(et, seenStructs, intersectStructs)
   206  
   207  				if err != nil {
   208  					return nil, err
   209  				}
   210  			}
   211  
   212  		case StructKind:
   213  			if seenStructs.has(t) {
   214  				return t, nil
   215  			}
   216  			seenStructs.add(t)
   217  			fields := t.Desc.(StructDesc).fields
   218  			for i, f := range fields {
   219  				fields[i].Type, err = foldUnions(f.Type, seenStructs, intersectStructs)
   220  
   221  				if err != nil {
   222  					return nil, err
   223  				}
   224  			}
   225  
   226  		case UnionKind:
   227  			elemTypes := t.Desc.(CompoundDesc).ElemTypes
   228  			if len(elemTypes) == 0 {
   229  				break
   230  			}
   231  			ts := make(typeset, len(elemTypes))
   232  			for _, t := range elemTypes {
   233  				ts.add(t)
   234  			}
   235  			if len(ts) == 0 {
   236  				t.Desc = CompoundDesc{UnionKind, nil}
   237  				return t, nil
   238  			}
   239  			return foldUnionImpl(ts, seenStructs, intersectStructs)
   240  
   241  		default:
   242  			panic("Unknown noms kind")
   243  		}
   244  	}
   245  	return t, nil
   246  }
   247  
   248  func foldUnionImpl(ts typeset, seenStructs typeset, intersectStructs bool) (*Type, error) {
   249  	type how struct {
   250  		k NomsKind
   251  		n string
   252  	}
   253  	out := make(typeSlice, 0, len(ts))
   254  	groups := map[how]typeset{}
   255  	for t := range ts {
   256  		var h how
   257  		switch t.TargetKind() {
   258  		case RefKind, SetKind, ListKind, MapKind, TupleKind, JSONKind:
   259  			h = how{k: t.TargetKind()}
   260  		case StructKind:
   261  			h = how{k: t.TargetKind(), n: t.Desc.(StructDesc).Name}
   262  		default:
   263  			out = append(out, t)
   264  			continue
   265  		}
   266  		g := groups[h]
   267  		if g == nil {
   268  			g = typeset{}
   269  			groups[h] = g
   270  		}
   271  		g.add(t)
   272  	}
   273  
   274  	for h, ts := range groups {
   275  		if len(ts) == 1 {
   276  			for t := range ts {
   277  				out = append(out, t)
   278  			}
   279  			continue
   280  		}
   281  
   282  		var r *Type
   283  		var err error
   284  		switch h.k {
   285  		case ListKind, RefKind, SetKind, TupleKind, JSONKind:
   286  			r, err = foldCompoundTypesForUnion(h.k, ts, seenStructs, intersectStructs)
   287  		case MapKind:
   288  			r, err = foldMapTypesForUnion(ts, seenStructs, intersectStructs)
   289  		case StructKind:
   290  			r, err = foldStructTypes(h.n, ts, seenStructs, intersectStructs)
   291  		}
   292  
   293  		if err != nil {
   294  			return nil, err
   295  		}
   296  
   297  		out = append(out, r)
   298  	}
   299  
   300  	for i, t := range out {
   301  		var err error
   302  		out[i], err = foldUnions(t, seenStructs, intersectStructs)
   303  
   304  		if err != nil {
   305  			return nil, err
   306  		}
   307  	}
   308  
   309  	if len(out) == 1 {
   310  		return out[0], nil
   311  	}
   312  
   313  	sort.Sort(out)
   314  
   315  	return newType(CompoundDesc{UnionKind, out}), nil
   316  }
   317  
   318  func foldCompoundTypesForUnion(k NomsKind, ts, seenStructs typeset, intersectStructs bool) (*Type, error) {
   319  	elemTypes := make(typeset, len(ts))
   320  	for t := range ts {
   321  		d.PanicIfFalse(t.TargetKind() == k)
   322  		elemTypes.add(t.Desc.(CompoundDesc).ElemTypes[0])
   323  	}
   324  
   325  	elemType, err := foldUnionImpl(elemTypes, seenStructs, intersectStructs)
   326  
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  
   331  	return makeCompoundType(k, elemType)
   332  }
   333  
   334  func foldMapTypesForUnion(ts, seenStructs typeset, intersectStructs bool) (*Type, error) {
   335  	keyTypes := make(typeset, len(ts))
   336  	valTypes := make(typeset, len(ts))
   337  	for t := range ts {
   338  		d.PanicIfFalse(t.TargetKind() == MapKind)
   339  		elemTypes := t.Desc.(CompoundDesc).ElemTypes
   340  		keyTypes.add(elemTypes[0])
   341  		valTypes.add(elemTypes[1])
   342  	}
   343  
   344  	kt, err := foldUnionImpl(keyTypes, seenStructs, intersectStructs)
   345  
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	vt, err := foldUnionImpl(valTypes, seenStructs, intersectStructs)
   351  
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	return makeCompoundType(MapKind, kt, vt)
   357  }
   358  
   359  func foldStructTypesFieldsOnly(name string, ts, seenStructs typeset, intersectStructs bool) (structTypeFields, error) {
   360  	fieldset := make([]structTypeFields, len(ts))
   361  	i := 0
   362  	for t := range ts {
   363  		desc := t.Desc.(StructDesc)
   364  		d.PanicIfFalse(desc.Name == name)
   365  		fieldset[i] = desc.fields
   366  		i++
   367  	}
   368  
   369  	return simplifyStructFields(fieldset, seenStructs, intersectStructs)
   370  }
   371  
   372  func foldStructTypes(name string, ts, seenStructs typeset, intersectStructs bool) (*Type, error) {
   373  	fields, err := foldStructTypesFieldsOnly(name, ts, seenStructs, intersectStructs)
   374  
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  
   379  	return newType(StructDesc{name, fields}), nil
   380  }
   381  
   382  func simplifyStructFields(in []structTypeFields, seenStructs typeset, intersectStructs bool) (structTypeFields, error) {
   383  	// We gather all the fields/types into allFields. If the number of
   384  	// times a field name is present is less that then number of types we
   385  	// are simplifying then the field must be optional.
   386  	// If we see an optional field we do not increment the count for it and
   387  	// it will be treated as optional in the end.
   388  
   389  	// If intersectStructs is true we need to pick the more restrictive version (n: T over n?: T).
   390  	type fieldTypeInfo struct {
   391  		anyNonOptional bool
   392  		count          int
   393  		ts             typeSlice
   394  	}
   395  	allFields := map[string]fieldTypeInfo{}
   396  
   397  	for _, ff := range in {
   398  		for _, f := range ff {
   399  			fti, ok := allFields[f.Name]
   400  			if !ok {
   401  				fti = fieldTypeInfo{
   402  					ts: make(typeSlice, 0, len(in)),
   403  				}
   404  			}
   405  			fti.ts = append(fti.ts, f.Type)
   406  			if !f.Optional {
   407  				fti.count++
   408  				fti.anyNonOptional = true
   409  			}
   410  			allFields[f.Name] = fti
   411  		}
   412  	}
   413  
   414  	count := len(in)
   415  	fields := make(structTypeFields, len(allFields))
   416  	i := 0
   417  	for name, fti := range allFields {
   418  		nt, err := makeUnionType(fti.ts...)
   419  
   420  		if err != nil {
   421  			return nil, err
   422  		}
   423  
   424  		t, err := foldUnions(nt, seenStructs, intersectStructs)
   425  
   426  		if err != nil {
   427  			return nil, err
   428  		}
   429  
   430  		fields[i] = StructField{
   431  			Name:     name,
   432  			Type:     t,
   433  			Optional: !(intersectStructs && fti.anyNonOptional) && fti.count < count,
   434  		}
   435  		i++
   436  	}
   437  
   438  	sort.Sort(fields)
   439  
   440  	return fields, nil
   441  }