github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/env/actions/infer_schema.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package actions
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"errors"
    21  	"io"
    22  	"math"
    23  	"strconv"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/google/uuid"
    28  
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/table"
    34  	"github.com/dolthub/dolt/go/libraries/utils/set"
    35  	"github.com/dolthub/dolt/go/store/types"
    36  )
    37  
    38  type typeInfoSet map[typeinfo.TypeInfo]struct{}
    39  
    40  const (
    41  	maxUint24 = 1<<24 - 1
    42  	minInt24  = -1 << 23
    43  )
    44  
    45  // InferenceArgs are arguments that can be passed to the schema inferrer to modify it's inference behavior.
    46  type InferenceArgs interface {
    47  	// ColNameMapper allows columns named X in the schema to be named Y in the inferred schema.
    48  	ColNameMapper() rowconv.NameMapper
    49  	// FloatThreshold is the threshold at which a string representing a floating point number should be interpreted as
    50  	// a float versus an int.  If FloatThreshold is 0.0 then any number with a decimal point will be interpreted as a
    51  	// float (such as 0.0, 1.0, etc).  If FloatThreshold is 1.0 then any number with a decimal point will be converted
    52  	// to an int (0.5 will be the int 0, 1.99 will be the int 1, etc.  If the FloatThreshold is 0.001 then numbers with
    53  	// a fractional component greater than or equal to 0.001 will be treated as a float (1.0 would be an int, 1.0009 would
    54  	// be an int, 1.001 would be a float, 1.1 would be a float, etc)
    55  	FloatThreshold() float64
    56  }
    57  
    58  // InferColumnTypesFromTableReader will infer a data types from a table reader.
    59  func InferColumnTypesFromTableReader(ctx context.Context, rd table.ReadCloser, args InferenceArgs) (*schema.ColCollection, error) {
    60  	// for large imports, we want to sample a subset of the rows.
    61  	// skip through the file in an exponential manner
    62  	const exp = 1.02
    63  
    64  	var curr, prev row.Row
    65  	i := newInferrer(rd.GetSchema(), args)
    66  OUTER:
    67  	for j := 0; true; j++ {
    68  		var err error
    69  
    70  		next := int(math.Pow(exp, float64(j)))
    71  		for n := 0; n < next; n++ {
    72  			curr, err = rd.ReadRow(ctx)
    73  			if err == io.EOF {
    74  				break OUTER
    75  			} else if err != nil {
    76  				return nil, err
    77  			}
    78  			prev = curr
    79  		}
    80  		if err = i.processRow(curr); err != nil {
    81  			return nil, err
    82  		}
    83  	}
    84  
    85  	// always process last row
    86  	if prev != nil {
    87  		if err := i.processRow(prev); err != nil {
    88  			return nil, err
    89  		}
    90  	}
    91  
    92  	return i.inferColumnTypes()
    93  }
    94  
    95  type inferrer struct {
    96  	readerSch      schema.Schema
    97  	inferSets      map[uint64]typeInfoSet
    98  	nullable       *set.Uint64Set
    99  	mapper         rowconv.NameMapper
   100  	floatThreshold float64
   101  }
   102  
   103  func newInferrer(readerSch schema.Schema, args InferenceArgs) *inferrer {
   104  	inferSets := make(map[uint64]typeInfoSet, readerSch.GetAllCols().Size())
   105  	_ = readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   106  		inferSets[tag] = make(typeInfoSet)
   107  		return false, nil
   108  	})
   109  
   110  	return &inferrer{
   111  		readerSch:      readerSch,
   112  		inferSets:      inferSets,
   113  		nullable:       set.NewUint64Set(nil),
   114  		mapper:         args.ColNameMapper(),
   115  		floatThreshold: args.FloatThreshold(),
   116  	}
   117  }
   118  
   119  // inferColumnTypes returns TableReader's columns with updated TypeInfo and columns names
   120  func (inf *inferrer) inferColumnTypes() (*schema.ColCollection, error) {
   121  
   122  	inferredTypes := make(map[uint64]typeinfo.TypeInfo)
   123  	for tag, ts := range inf.inferSets {
   124  		inferredTypes[tag] = findCommonType(ts)
   125  	}
   126  
   127  	var cols []schema.Column
   128  	_ = inf.readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   129  		col.Name = inf.mapper.Map(col.Name)
   130  		col.Kind = inferredTypes[tag].NomsKind()
   131  		col.TypeInfo = inferredTypes[tag]
   132  		col.Tag = schema.ReservedTagMin + tag
   133  
   134  		// for large imports, it is possible to miss all the null values, so we cannot accurately add not null constraint
   135  		col.Constraints = []schema.ColConstraint(nil)
   136  
   137  		cols = append(cols, col)
   138  		return false, nil
   139  	})
   140  
   141  	return schema.NewColCollection(cols...), nil
   142  }
   143  
   144  func (inf *inferrer) processRow(r row.Row) error {
   145  	_, err := r.IterSchema(inf.readerSch, func(tag uint64, val types.Value) (stop bool, err error) {
   146  		if val == nil {
   147  			inf.nullable.Add(tag)
   148  			return false, nil
   149  		}
   150  		strVal := string(val.(types.String))
   151  		typeInfo := leastPermissiveType(strVal, inf.floatThreshold)
   152  		inf.inferSets[tag][typeInfo] = struct{}{}
   153  		return false, nil
   154  	})
   155  
   156  	return err
   157  }
   158  
   159  func leastPermissiveType(strVal string, floatThreshold float64) typeinfo.TypeInfo {
   160  	if len(strVal) == 0 {
   161  		return typeinfo.UnknownType
   162  	}
   163  	strVal = strings.TrimSpace(strVal)
   164  
   165  	numType := leastPermissiveNumericType(strVal, floatThreshold)
   166  	if numType != typeinfo.UnknownType {
   167  		return numType
   168  	}
   169  
   170  	_, err := uuid.Parse(strVal)
   171  	if err == nil {
   172  		return typeinfo.UuidType
   173  	}
   174  
   175  	chronoType := leastPermissiveChronoType(strVal)
   176  	if chronoType != typeinfo.UnknownType {
   177  		return chronoType
   178  	}
   179  
   180  	strVal = strings.ToLower(strVal)
   181  	if strVal == "true" || strVal == "false" {
   182  		return typeinfo.BoolType
   183  	}
   184  
   185  	if strings.Contains(strVal, "{") || strings.Contains(strVal, "[") {
   186  		var j interface{}
   187  		err := json.Unmarshal([]byte(strVal), &j)
   188  		if err == nil {
   189  			return typeinfo.JSONType
   190  		}
   191  	}
   192  
   193  	if int64(len(strVal)) > typeinfo.MaxVarcharLength {
   194  		return typeinfo.TextType
   195  	} else {
   196  		return typeinfo.StringDefaultType
   197  	}
   198  }
   199  
   200  func leastPermissiveNumericType(strVal string, floatThreshold float64) (ti typeinfo.TypeInfo) {
   201  	if strings.Contains(strVal, ".") {
   202  		f, err := strconv.ParseFloat(strVal, 64)
   203  		if err != nil {
   204  			return typeinfo.UnknownType
   205  		}
   206  
   207  		if math.Abs(f) < math.MaxFloat32 {
   208  			ti = typeinfo.Float32Type
   209  		} else {
   210  			ti = typeinfo.Float64Type
   211  		}
   212  
   213  		if floatThreshold != 0.0 {
   214  			floatParts := strings.Split(strVal, ".")
   215  			decimalPart, err := strconv.ParseFloat("0."+floatParts[1], 64)
   216  
   217  			if err != nil {
   218  				panic(err)
   219  			}
   220  
   221  			if decimalPart < floatThreshold {
   222  				if ti == typeinfo.Float32Type {
   223  					ti = typeinfo.Int32Type
   224  				} else {
   225  					ti = typeinfo.Int64Type
   226  				}
   227  			}
   228  		}
   229  		return ti
   230  	}
   231  
   232  	// always parse as signed int
   233  	i, err := strconv.ParseInt(strVal, 10, 64)
   234  
   235  	// use string for out of range
   236  	if errors.Is(err, strconv.ErrRange) {
   237  		return typeinfo.StringDefaultType
   238  	}
   239  
   240  	if err != nil {
   241  		return typeinfo.UnknownType
   242  	}
   243  
   244  	// handle leading zero case
   245  	if len(strVal) > 1 && strVal[0] == '0' {
   246  		return typeinfo.StringDefaultType
   247  	}
   248  
   249  	if i >= math.MinInt32 && i <= math.MaxInt32 {
   250  		return typeinfo.Int32Type
   251  	} else {
   252  		return typeinfo.Int64Type
   253  	}
   254  }
   255  
   256  func leastPermissiveChronoType(strVal string) typeinfo.TypeInfo {
   257  	if strVal == "" {
   258  		return typeinfo.UnknownType
   259  	}
   260  
   261  	dt, err := typeinfo.StringDefaultType.ConvertToType(context.Background(), nil, typeinfo.DatetimeType, types.String(strVal))
   262  	if err == nil {
   263  		t := time.Time(dt.(types.Timestamp))
   264  		if t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 {
   265  			return typeinfo.DateType
   266  		}
   267  
   268  		return typeinfo.DatetimeType
   269  	}
   270  
   271  	_, err = typeinfo.StringDefaultType.ConvertToType(context.Background(), nil, typeinfo.TimeType, types.String(strVal))
   272  	if err == nil {
   273  		return typeinfo.TimeType
   274  	}
   275  
   276  	return typeinfo.UnknownType
   277  }
   278  
   279  func chronoTypes() []typeinfo.TypeInfo {
   280  	return []typeinfo.TypeInfo{
   281  		// chrono types YEAR, DATE, and TIME can also be parsed as DATETIME
   282  		// we prefer less permissive types if possible
   283  		typeinfo.YearType,
   284  		typeinfo.DateType,
   285  		typeinfo.TimeType,
   286  		typeinfo.TimestampType,
   287  		typeinfo.DatetimeType,
   288  	}
   289  }
   290  
   291  // ordered from least to most permissive
   292  func numericTypes() []typeinfo.TypeInfo {
   293  	// prefer:
   294  	//   ints over floats
   295  	//   smaller over larger
   296  	return []typeinfo.TypeInfo{
   297  		//typeinfo.Uint8Type,
   298  		//typeinfo.Uint16Type,
   299  		//typeinfo.Uint24Type,
   300  		//typeinfo.Uint32Type,
   301  		//typeinfo.Uint64Type,
   302  
   303  		//typeinfo.Int8Type,
   304  		//typeinfo.Int16Type,
   305  		//typeinfo.Int24Type,
   306  		typeinfo.Int32Type,
   307  		typeinfo.Int64Type,
   308  
   309  		typeinfo.Float32Type,
   310  		typeinfo.Float64Type,
   311  	}
   312  }
   313  
   314  func setHasType(ts typeInfoSet, t typeinfo.TypeInfo) bool {
   315  	_, found := ts[t]
   316  	return found
   317  }
   318  
   319  // findCommonType takes a set of types and finds the least permissive
   320  // (ie most specific) common type between all types in the set
   321  func findCommonType(ts typeInfoSet) typeinfo.TypeInfo {
   322  
   323  	// empty values were inferred as UnknownType
   324  	delete(ts, typeinfo.UnknownType)
   325  
   326  	if len(ts) == 0 {
   327  		// use strings if all values were empty
   328  		return typeinfo.StringDefaultType
   329  	}
   330  
   331  	if len(ts) == 1 {
   332  		for ti := range ts {
   333  			return ti
   334  		}
   335  	}
   336  
   337  	// len(ts) > 1
   338  
   339  	if setHasType(ts, typeinfo.TextType) {
   340  		return typeinfo.TextType
   341  	} else if setHasType(ts, typeinfo.StringDefaultType) {
   342  		return typeinfo.StringDefaultType
   343  	} else if setHasType(ts, typeinfo.StringDefaultType) {
   344  		return typeinfo.StringDefaultType
   345  	}
   346  
   347  	hasNumeric := false
   348  	for _, nt := range numericTypes() {
   349  		if setHasType(ts, nt) {
   350  			hasNumeric = true
   351  			break
   352  		}
   353  	}
   354  
   355  	hasNonNumeric := false
   356  	for _, nnt := range chronoTypes() {
   357  		if setHasType(ts, nnt) {
   358  			hasNonNumeric = true
   359  			break
   360  		}
   361  	}
   362  	if setHasType(ts, typeinfo.BoolType) || setHasType(ts, typeinfo.UuidType) {
   363  		hasNonNumeric = true
   364  	}
   365  
   366  	if hasNumeric && hasNonNumeric {
   367  		return typeinfo.StringDefaultType
   368  	}
   369  
   370  	if hasNumeric {
   371  		return findCommonNumericType(ts)
   372  	}
   373  
   374  	// find a common nonNumeric type
   375  
   376  	nonChronoTypes := []typeinfo.TypeInfo{
   377  		// todo: BIT implementation parses all uint8
   378  		//typeinfo.PseudoBoolType,
   379  		typeinfo.BoolType,
   380  		typeinfo.UuidType,
   381  	}
   382  	for _, nct := range nonChronoTypes {
   383  		if setHasType(ts, nct) {
   384  			// types in nonChronoTypes have only string
   385  			// as a common type with any other type
   386  			return typeinfo.StringDefaultType
   387  		}
   388  	}
   389  
   390  	return findCommonChronoType(ts)
   391  }
   392  
   393  func findCommonNumericType(nums typeInfoSet) typeinfo.TypeInfo {
   394  	// find a common numeric type
   395  	// iterate through types from most to least permissive
   396  	// return the most permissive type found
   397  	//   ints are a subset of floats
   398  	//   uints are a subset of ints
   399  	//   smaller widths are a subset of larger widths
   400  	mostToLeast := []typeinfo.TypeInfo{
   401  		typeinfo.Float64Type,
   402  		typeinfo.Float32Type,
   403  
   404  		// todo: can all Int64 fit in Float64?
   405  		typeinfo.Int64Type,
   406  		typeinfo.Int32Type,
   407  		typeinfo.Int24Type,
   408  		typeinfo.Int16Type,
   409  		typeinfo.Int8Type,
   410  	}
   411  	for _, numType := range mostToLeast {
   412  		if setHasType(nums, numType) {
   413  			return numType
   414  		}
   415  	}
   416  
   417  	panic("unreachable")
   418  }
   419  
   420  func findCommonChronoType(chronos typeInfoSet) typeinfo.TypeInfo {
   421  	if len(chronos) == 1 {
   422  		for ct := range chronos {
   423  			return ct
   424  		}
   425  	}
   426  
   427  	if setHasType(chronos, typeinfo.DatetimeType) {
   428  		return typeinfo.DatetimeType
   429  	}
   430  
   431  	hasTime := setHasType(chronos, typeinfo.TimeType) || setHasType(chronos, typeinfo.TimestampType)
   432  	hasDate := setHasType(chronos, typeinfo.DateType) || setHasType(chronos, typeinfo.YearType)
   433  
   434  	if hasTime && !hasDate {
   435  		return typeinfo.TimeType
   436  	}
   437  
   438  	if !hasTime && hasDate {
   439  		return typeinfo.DateType
   440  	}
   441  
   442  	if hasDate && hasTime {
   443  		return typeinfo.DatetimeType
   444  	}
   445  
   446  	panic("unreachable")
   447  }