github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/env/actions/infer_schema.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package actions
    16  
    17  import (
    18  	"context"
    19  	"math"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/google/uuid"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/table"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/table/pipeline"
    32  	"github.com/dolthub/dolt/go/libraries/utils/set"
    33  	"github.com/dolthub/dolt/go/store/types"
    34  )
    35  
// typeInfoSet is a set of TypeInfos, implemented as a map whose keys are the
// members and whose values carry no data.
type typeInfoSet map[typeinfo.TypeInfo]struct{}
    37  
    38  const (
    39  	maxUint24 = 1<<24 - 1
    40  	minInt24  = -1 << 23
    41  )
    42  
// InferenceArgs are arguments that can be passed to the schema inferrer to modify its inference behavior.
type InferenceArgs interface {
	// ColNameMapper allows columns named X in the schema to be named Y in the inferred schema.
	ColNameMapper() rowconv.NameMapper
	// FloatThreshold is the threshold at which a string representing a floating point number should be interpreted as
	// a float versus an int.  If FloatThreshold is 0.0 then any number with a decimal point will be interpreted as a
	// float (such as 0.0, 1.0, etc).  If FloatThreshold is 1.0 then any number with a decimal point will be converted
	// to an int (0.5 will be the int 0, 1.99 will be the int 1, etc).  If the FloatThreshold is 0.001 then numbers with
	// a fractional component greater than or equal to 0.001 will be treated as a float (1.0 would be an int, 1.0009 would
	// be an int, 1.001 would be a float, 1.1 would be a float, etc).
	FloatThreshold() float64
}
    55  
    56  // InferColumnTypesFromTableReader will infer a data types from a table reader.
    57  func InferColumnTypesFromTableReader(ctx context.Context, root *doltdb.RootValue, rd table.TableReadCloser, args InferenceArgs) (*schema.ColCollection, error) {
    58  	inferrer := newInferrer(rd.GetSchema(), args)
    59  
    60  	var rowFailure *pipeline.TransformRowFailure
    61  	badRow := func(trf *pipeline.TransformRowFailure) (quit bool) {
    62  		rowFailure = trf
    63  		return false
    64  	}
    65  
    66  	rdProcFunc := pipeline.ProcFuncForReader(ctx, rd)
    67  	p := pipeline.NewAsyncPipeline(rdProcFunc, inferrer.sinkRow, nil, badRow)
    68  	p.Start()
    69  
    70  	err := p.Wait()
    71  
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	if rowFailure != nil {
    77  		return nil, rowFailure
    78  	}
    79  
    80  	return inferrer.inferColumnTypes(ctx, root)
    81  }
    82  
// inferrer collects, per column tag, the types observed in the rows it is fed
// (see sinkRow) and turns them into inferred columns (see inferColumnTypes).
type inferrer struct {
	// readerSch is the schema of the reader whose rows are examined.
	readerSch      schema.Schema
	// inferSets maps a column tag to the set of types observed for its values.
	inferSets      map[uint64]typeInfoSet
	// nullable holds the tags of columns for which a nil value was seen.
	nullable       *set.Uint64Set
	// mapper renames reader columns for the inferred schema.
	mapper         rowconv.NameMapper
	// floatThreshold is the InferenceArgs.FloatThreshold value passed to
	// leastPermissiveType for every non-nil value.
	floatThreshold float64
}
    92  
    93  func newInferrer(readerSch schema.Schema, args InferenceArgs) *inferrer {
    94  	inferSets := make(map[uint64]typeInfoSet, readerSch.GetAllCols().Size())
    95  	_ = readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
    96  		inferSets[tag] = make(typeInfoSet)
    97  		return false, nil
    98  	})
    99  
   100  	return &inferrer{
   101  		readerSch:      readerSch,
   102  		inferSets:      inferSets,
   103  		nullable:       set.NewUint64Set(nil),
   104  		mapper:         args.ColNameMapper(),
   105  		floatThreshold: args.FloatThreshold(),
   106  	}
   107  }
   108  
   109  // inferColumnTypes returns TableReader's columns with updated TypeInfo and columns names
   110  func (inf *inferrer) inferColumnTypes(ctx context.Context, root *doltdb.RootValue) (*schema.ColCollection, error) {
   111  
   112  	inferredTypes := make(map[uint64]typeinfo.TypeInfo)
   113  	for tag, ts := range inf.inferSets {
   114  		inferredTypes[tag] = findCommonType(ts)
   115  	}
   116  
   117  	var cols []schema.Column
   118  	_ = inf.readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   119  		col.Name = inf.mapper.Map(col.Name)
   120  		col.Kind = inferredTypes[tag].NomsKind()
   121  		col.TypeInfo = inferredTypes[tag]
   122  		col.Tag = schema.ReservedTagMin + tag
   123  
   124  		col.Constraints = []schema.ColConstraint{schema.NotNullConstraint{}}
   125  		if inf.nullable.Contains(tag) {
   126  			col.Constraints = []schema.ColConstraint(nil)
   127  		}
   128  
   129  		cols = append(cols, col)
   130  		return false, nil
   131  	})
   132  
   133  	return schema.NewColCollection(cols...), nil
   134  }
   135  
   136  func (inf *inferrer) sinkRow(p *pipeline.Pipeline, ch <-chan pipeline.RowWithProps, badRowChan chan<- *pipeline.TransformRowFailure) {
   137  	for r := range ch {
   138  		_, _ = r.Row.IterSchema(inf.readerSch, func(tag uint64, val types.Value) (stop bool, err error) {
   139  			if val == nil {
   140  				inf.nullable.Add(tag)
   141  				return false, nil
   142  			}
   143  			strVal := string(val.(types.String))
   144  			typeInfo := leastPermissiveType(strVal, inf.floatThreshold)
   145  			inf.inferSets[tag][typeInfo] = struct{}{}
   146  			return false, nil
   147  		})
   148  	}
   149  }
   150  
   151  func leastPermissiveType(strVal string, floatThreshold float64) typeinfo.TypeInfo {
   152  	if len(strVal) == 0 {
   153  		return typeinfo.UnknownType
   154  	}
   155  	strVal = strings.TrimSpace(strVal)
   156  
   157  	numType := leastPermissiveNumericType(strVal, floatThreshold)
   158  	if numType != typeinfo.UnknownType {
   159  		return numType
   160  	}
   161  
   162  	chronoType := leastPermissiveChronoType(strVal)
   163  	if chronoType != typeinfo.UnknownType {
   164  		return chronoType
   165  	}
   166  
   167  	_, err := uuid.Parse(strVal)
   168  	if err == nil {
   169  		return typeinfo.UuidType
   170  	}
   171  
   172  	strVal = strings.ToLower(strVal)
   173  	if strVal == "true" || strVal == "false" {
   174  		return typeinfo.BoolType
   175  	}
   176  
   177  	return typeinfo.StringDefaultType
   178  }
   179  
   180  func leastPermissiveNumericType(strVal string, floatThreshold float64) (ti typeinfo.TypeInfo) {
   181  	if strings.Contains(strVal, ".") {
   182  		f, err := strconv.ParseFloat(strVal, 64)
   183  		if err != nil {
   184  			return typeinfo.UnknownType
   185  		}
   186  
   187  		if math.Abs(f) < math.MaxFloat32 {
   188  			ti = typeinfo.Float32Type
   189  		} else {
   190  			ti = typeinfo.Float64Type
   191  		}
   192  
   193  		if floatThreshold != 0.0 {
   194  			floatParts := strings.Split(strVal, ".")
   195  			decimalPart, err := strconv.ParseFloat("0."+floatParts[1], 64)
   196  
   197  			if err != nil {
   198  				panic(err)
   199  			}
   200  
   201  			if decimalPart < floatThreshold {
   202  				if ti == typeinfo.Float32Type {
   203  					ti = typeinfo.Int32Type
   204  				} else {
   205  					ti = typeinfo.Int64Type
   206  				}
   207  			}
   208  		}
   209  		return ti
   210  	}
   211  
   212  	if strings.Contains(strVal, "-") {
   213  		i, err := strconv.ParseInt(strVal, 10, 64)
   214  		if err != nil {
   215  			return typeinfo.UnknownType
   216  		}
   217  		if i >= math.MinInt32 && i <= math.MaxInt32 {
   218  			return typeinfo.Int32Type
   219  		} else {
   220  			return typeinfo.Int64Type
   221  		}
   222  	} else {
   223  		ui, err := strconv.ParseUint(strVal, 10, 64)
   224  		if err != nil {
   225  			return typeinfo.UnknownType
   226  		}
   227  		if ui <= math.MaxUint32 {
   228  			return typeinfo.Uint32Type
   229  		} else {
   230  			return typeinfo.Uint64Type
   231  		}
   232  	}
   233  }
   234  
   235  func leastPermissiveChronoType(strVal string) typeinfo.TypeInfo {
   236  	if strVal == "" {
   237  		return typeinfo.UnknownType
   238  	}
   239  	_, err := typeinfo.TimeType.ParseValue(context.Background(), nil, &strVal)
   240  	if err == nil {
   241  		return typeinfo.TimeType
   242  	}
   243  
   244  	dt, err := typeinfo.DatetimeType.ParseValue(context.Background(), nil, &strVal)
   245  	if err != nil {
   246  		return typeinfo.UnknownType
   247  	}
   248  
   249  	t := time.Time(dt.(types.Timestamp))
   250  	if t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 {
   251  		return typeinfo.DateType
   252  	}
   253  
   254  	return typeinfo.DatetimeType
   255  }
   256  
   257  func chronoTypes() []typeinfo.TypeInfo {
   258  	return []typeinfo.TypeInfo{
   259  		// chrono types YEAR, DATE, and TIME can also be parsed as DATETIME
   260  		// we prefer less permissive types if possible
   261  		typeinfo.YearType,
   262  		typeinfo.DateType,
   263  		typeinfo.TimeType,
   264  		typeinfo.TimestampType,
   265  		typeinfo.DatetimeType,
   266  	}
   267  }
   268  
   269  // ordered from least to most permissive
   270  func numericTypes() []typeinfo.TypeInfo {
   271  	// prefer:
   272  	//   ints over floats
   273  	//   unsigned over signed
   274  	//   smaller over larger
   275  	return []typeinfo.TypeInfo{
   276  		//typeinfo.Uint8Type,
   277  		//typeinfo.Uint16Type,
   278  		//typeinfo.Uint24Type,
   279  		typeinfo.Uint32Type,
   280  		typeinfo.Uint64Type,
   281  
   282  		//typeinfo.Int8Type,
   283  		//typeinfo.Int16Type,
   284  		//typeinfo.Int24Type,
   285  		typeinfo.Int32Type,
   286  		typeinfo.Int64Type,
   287  
   288  		typeinfo.Float32Type,
   289  		typeinfo.Float64Type,
   290  	}
   291  }
   292  
   293  func setHasType(ts typeInfoSet, t typeinfo.TypeInfo) bool {
   294  	_, found := ts[t]
   295  	return found
   296  }
   297  
   298  // findCommonType takes a set of types and finds the least permissive
   299  // (ie most specific) common type between all types in the set
   300  func findCommonType(ts typeInfoSet) typeinfo.TypeInfo {
   301  
   302  	// empty values were inferred as UnknownType
   303  	delete(ts, typeinfo.UnknownType)
   304  
   305  	if len(ts) == 0 {
   306  		// use strings if all values were empty
   307  		return typeinfo.StringDefaultType
   308  	}
   309  
   310  	if len(ts) == 1 {
   311  		for ti := range ts {
   312  			return ti
   313  		}
   314  	}
   315  
   316  	// len(ts) > 1
   317  
   318  	if setHasType(ts, typeinfo.StringDefaultType) {
   319  		return typeinfo.StringDefaultType
   320  	}
   321  
   322  	hasNumeric := false
   323  	for _, nt := range numericTypes() {
   324  		if setHasType(ts, nt) {
   325  			hasNumeric = true
   326  			break
   327  		}
   328  	}
   329  
   330  	hasNonNumeric := false
   331  	for _, nnt := range chronoTypes() {
   332  		if setHasType(ts, nnt) {
   333  			hasNonNumeric = true
   334  			break
   335  		}
   336  	}
   337  	if setHasType(ts, typeinfo.BoolType) || setHasType(ts, typeinfo.UuidType) {
   338  		hasNonNumeric = true
   339  	}
   340  
   341  	if hasNumeric && hasNonNumeric {
   342  		return typeinfo.StringDefaultType
   343  	}
   344  
   345  	if hasNumeric {
   346  		return findCommonNumericType(ts)
   347  	}
   348  
   349  	// find a common nonNumeric type
   350  
   351  	nonChronoTypes := []typeinfo.TypeInfo{
   352  		// todo: BIT implementation parses all uint8
   353  		//typeinfo.PseudoBoolType,
   354  		typeinfo.BoolType,
   355  		typeinfo.UuidType,
   356  	}
   357  	for _, nct := range nonChronoTypes {
   358  		if setHasType(ts, nct) {
   359  			// types in nonChronoTypes have only string
   360  			// as a common type with any other type
   361  			return typeinfo.StringDefaultType
   362  		}
   363  	}
   364  
   365  	return findCommonChronoType(ts)
   366  }
   367  
   368  func findCommonNumericType(nums typeInfoSet) typeinfo.TypeInfo {
   369  	// find a common numeric type
   370  	// iterate through types from most to least permissive
   371  	// return the most permissive type found
   372  	//   ints are a subset of floats
   373  	//   uints are a subset of ints
   374  	//   smaller widths are a subset of larger widths
   375  	mostToLeast := []typeinfo.TypeInfo{
   376  		typeinfo.Float64Type,
   377  		typeinfo.Float32Type,
   378  
   379  		// todo: can all Int64 fit in Float64?
   380  		typeinfo.Int64Type,
   381  		typeinfo.Int32Type,
   382  		typeinfo.Int24Type,
   383  		typeinfo.Int16Type,
   384  		typeinfo.Int8Type,
   385  
   386  		typeinfo.Uint64Type,
   387  		typeinfo.Uint32Type,
   388  		typeinfo.Uint24Type,
   389  		typeinfo.Uint16Type,
   390  		typeinfo.Uint8Type,
   391  	}
   392  	for _, numType := range mostToLeast {
   393  		if setHasType(nums, numType) {
   394  			return numType
   395  		}
   396  	}
   397  
   398  	panic("unreachable")
   399  }
   400  
   401  func findCommonChronoType(chronos typeInfoSet) typeinfo.TypeInfo {
   402  	if len(chronos) == 1 {
   403  		for ct := range chronos {
   404  			return ct
   405  		}
   406  	}
   407  
   408  	if setHasType(chronos, typeinfo.DatetimeType) {
   409  		return typeinfo.DatetimeType
   410  	}
   411  
   412  	hasTime := setHasType(chronos, typeinfo.TimeType) || setHasType(chronos, typeinfo.TimestampType)
   413  	hasDate := setHasType(chronos, typeinfo.DateType) || setHasType(chronos, typeinfo.YearType)
   414  
   415  	if hasTime && !hasDate {
   416  		return typeinfo.TimeType
   417  	}
   418  
   419  	if !hasTime && hasDate {
   420  		return typeinfo.DateType
   421  	}
   422  
   423  	if hasDate && hasTime {
   424  		return typeinfo.DatetimeType
   425  	}
   426  
   427  	panic("unreachable")
   428  }