github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/env/actions/infer_schema.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package actions 16 17 import ( 18 "context" 19 "math" 20 "strconv" 21 "strings" 22 "time" 23 24 "github.com/google/uuid" 25 26 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 27 "github.com/dolthub/dolt/go/libraries/doltcore/rowconv" 28 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 29 "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" 30 "github.com/dolthub/dolt/go/libraries/doltcore/table" 31 "github.com/dolthub/dolt/go/libraries/doltcore/table/pipeline" 32 "github.com/dolthub/dolt/go/libraries/utils/set" 33 "github.com/dolthub/dolt/go/store/types" 34 ) 35 36 type typeInfoSet map[typeinfo.TypeInfo]struct{} 37 38 const ( 39 maxUint24 = 1<<24 - 1 40 minInt24 = -1 << 23 41 ) 42 43 // InferenceArgs are arguments that can be passed to the schema inferrer to modify it's inference behavior. 44 type InferenceArgs interface { 45 // ColNameMapper allows columns named X in the schema to be named Y in the inferred schema. 46 ColNameMapper() rowconv.NameMapper 47 // FloatThreshold is the threshold at which a string representing a floating point number should be interpreted as 48 // a float versus an int. If FloatThreshold is 0.0 then any number with a decimal point will be interpreted as a 49 // float (such as 0.0, 1.0, etc). If FloatThreshold is 1.0 then any number with a decimal point will be converted 50 // to an int (0.5 will be the int 0, 1.99 will be the int 1, etc. If the FloatThreshold is 0.001 then numbers with 51 // a fractional component greater than or equal to 0.001 will be treated as a float (1.0 would be an int, 1.0009 would 52 // be an int, 1.001 would be a float, 1.1 would be a float, etc) 53 FloatThreshold() float64 54 } 55 56 // InferColumnTypesFromTableReader will infer a data types from a table reader. 57 func InferColumnTypesFromTableReader(ctx context.Context, root *doltdb.RootValue, rd table.TableReadCloser, args InferenceArgs) (*schema.ColCollection, error) { 58 inferrer := newInferrer(rd.GetSchema(), args) 59 60 var rowFailure *pipeline.TransformRowFailure 61 badRow := func(trf *pipeline.TransformRowFailure) (quit bool) { 62 rowFailure = trf 63 return false 64 } 65 66 rdProcFunc := pipeline.ProcFuncForReader(ctx, rd) 67 p := pipeline.NewAsyncPipeline(rdProcFunc, inferrer.sinkRow, nil, badRow) 68 p.Start() 69 70 err := p.Wait() 71 72 if err != nil { 73 return nil, err 74 } 75 76 if rowFailure != nil { 77 return nil, rowFailure 78 } 79 80 return inferrer.inferColumnTypes(ctx, root) 81 } 82 83 type inferrer struct { 84 readerSch schema.Schema 85 inferSets map[uint64]typeInfoSet 86 nullable *set.Uint64Set 87 mapper rowconv.NameMapper 88 floatThreshold float64 89 90 //inferArgs *InferenceArgs 91 } 92 93 func newInferrer(readerSch schema.Schema, args InferenceArgs) *inferrer { 94 inferSets := make(map[uint64]typeInfoSet, readerSch.GetAllCols().Size()) 95 _ = readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 96 inferSets[tag] = make(typeInfoSet) 97 return false, nil 98 }) 99 100 return &inferrer{ 101 readerSch: readerSch, 102 inferSets: inferSets, 103 nullable: set.NewUint64Set(nil), 104 mapper: args.ColNameMapper(), 105 floatThreshold: args.FloatThreshold(), 106 } 107 } 108 109 // inferColumnTypes returns TableReader's columns with updated TypeInfo and columns names 110 func (inf *inferrer) inferColumnTypes(ctx context.Context, root *doltdb.RootValue) (*schema.ColCollection, error) { 111 112 inferredTypes := make(map[uint64]typeinfo.TypeInfo) 113 for tag, ts := range inf.inferSets { 114 inferredTypes[tag] = findCommonType(ts) 115 } 116 117 var cols []schema.Column 118 _ = inf.readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 119 col.Name = inf.mapper.Map(col.Name) 120 col.Kind = inferredTypes[tag].NomsKind() 121 col.TypeInfo = inferredTypes[tag] 122 col.Tag = schema.ReservedTagMin + tag 123 124 col.Constraints = []schema.ColConstraint{schema.NotNullConstraint{}} 125 if inf.nullable.Contains(tag) { 126 col.Constraints = []schema.ColConstraint(nil) 127 } 128 129 cols = append(cols, col) 130 return false, nil 131 }) 132 133 return schema.NewColCollection(cols...), nil 134 } 135 136 func (inf *inferrer) sinkRow(p *pipeline.Pipeline, ch <-chan pipeline.RowWithProps, badRowChan chan<- *pipeline.TransformRowFailure) { 137 for r := range ch { 138 _, _ = r.Row.IterSchema(inf.readerSch, func(tag uint64, val types.Value) (stop bool, err error) { 139 if val == nil { 140 inf.nullable.Add(tag) 141 return false, nil 142 } 143 strVal := string(val.(types.String)) 144 typeInfo := leastPermissiveType(strVal, inf.floatThreshold) 145 inf.inferSets[tag][typeInfo] = struct{}{} 146 return false, nil 147 }) 148 } 149 } 150 151 func leastPermissiveType(strVal string, floatThreshold float64) typeinfo.TypeInfo { 152 if len(strVal) == 0 { 153 return typeinfo.UnknownType 154 } 155 strVal = strings.TrimSpace(strVal) 156 157 numType := leastPermissiveNumericType(strVal, floatThreshold) 158 if numType != typeinfo.UnknownType { 159 return numType 160 } 161 162 chronoType := leastPermissiveChronoType(strVal) 163 if chronoType != typeinfo.UnknownType { 164 return chronoType 165 } 166 167 _, err := uuid.Parse(strVal) 168 if err == nil { 169 return typeinfo.UuidType 170 } 171 172 strVal = strings.ToLower(strVal) 173 if strVal == "true" || strVal == "false" { 174 return typeinfo.BoolType 175 } 176 177 return typeinfo.StringDefaultType 178 } 179 180 func leastPermissiveNumericType(strVal string, floatThreshold float64) (ti typeinfo.TypeInfo) { 181 if strings.Contains(strVal, ".") { 182 f, err := strconv.ParseFloat(strVal, 64) 183 if err != nil { 184 return typeinfo.UnknownType 185 } 186 187 if math.Abs(f) < math.MaxFloat32 { 188 ti = typeinfo.Float32Type 189 } else { 190 ti = typeinfo.Float64Type 191 } 192 193 if floatThreshold != 0.0 { 194 floatParts := strings.Split(strVal, ".") 195 decimalPart, err := strconv.ParseFloat("0."+floatParts[1], 64) 196 197 if err != nil { 198 panic(err) 199 } 200 201 if decimalPart < floatThreshold { 202 if ti == typeinfo.Float32Type { 203 ti = typeinfo.Int32Type 204 } else { 205 ti = typeinfo.Int64Type 206 } 207 } 208 } 209 return ti 210 } 211 212 if strings.Contains(strVal, "-") { 213 i, err := strconv.ParseInt(strVal, 10, 64) 214 if err != nil { 215 return typeinfo.UnknownType 216 } 217 if i >= math.MinInt32 && i <= math.MaxInt32 { 218 return typeinfo.Int32Type 219 } else { 220 return typeinfo.Int64Type 221 } 222 } else { 223 ui, err := strconv.ParseUint(strVal, 10, 64) 224 if err != nil { 225 return typeinfo.UnknownType 226 } 227 if ui <= math.MaxUint32 { 228 return typeinfo.Uint32Type 229 } else { 230 return typeinfo.Uint64Type 231 } 232 } 233 } 234 235 func leastPermissiveChronoType(strVal string) typeinfo.TypeInfo { 236 if strVal == "" { 237 return typeinfo.UnknownType 238 } 239 _, err := typeinfo.TimeType.ParseValue(context.Background(), nil, &strVal) 240 if err == nil { 241 return typeinfo.TimeType 242 } 243 244 dt, err := typeinfo.DatetimeType.ParseValue(context.Background(), nil, &strVal) 245 if err != nil { 246 return typeinfo.UnknownType 247 } 248 249 t := time.Time(dt.(types.Timestamp)) 250 if t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 { 251 return typeinfo.DateType 252 } 253 254 return typeinfo.DatetimeType 255 } 256 257 func chronoTypes() []typeinfo.TypeInfo { 258 return []typeinfo.TypeInfo{ 259 // chrono types YEAR, DATE, and TIME can also be parsed as DATETIME 260 // we prefer less permissive types if possible 261 typeinfo.YearType, 262 typeinfo.DateType, 263 typeinfo.TimeType, 264 typeinfo.TimestampType, 265 typeinfo.DatetimeType, 266 } 267 } 268 269 // ordered from least to most permissive 270 func numericTypes() []typeinfo.TypeInfo { 271 // prefer: 272 // ints over floats 273 // unsigned over signed 274 // smaller over larger 275 return []typeinfo.TypeInfo{ 276 //typeinfo.Uint8Type, 277 //typeinfo.Uint16Type, 278 //typeinfo.Uint24Type, 279 typeinfo.Uint32Type, 280 typeinfo.Uint64Type, 281 282 //typeinfo.Int8Type, 283 //typeinfo.Int16Type, 284 //typeinfo.Int24Type, 285 typeinfo.Int32Type, 286 typeinfo.Int64Type, 287 288 typeinfo.Float32Type, 289 typeinfo.Float64Type, 290 } 291 } 292 293 func setHasType(ts typeInfoSet, t typeinfo.TypeInfo) bool { 294 _, found := ts[t] 295 return found 296 } 297 298 // findCommonType takes a set of types and finds the least permissive 299 // (ie most specific) common type between all types in the set 300 func findCommonType(ts typeInfoSet) typeinfo.TypeInfo { 301 302 // empty values were inferred as UnknownType 303 delete(ts, typeinfo.UnknownType) 304 305 if len(ts) == 0 { 306 // use strings if all values were empty 307 return typeinfo.StringDefaultType 308 } 309 310 if len(ts) == 1 { 311 for ti := range ts { 312 return ti 313 } 314 } 315 316 // len(ts) > 1 317 318 if setHasType(ts, typeinfo.StringDefaultType) { 319 return typeinfo.StringDefaultType 320 } 321 322 hasNumeric := false 323 for _, nt := range numericTypes() { 324 if setHasType(ts, nt) { 325 hasNumeric = true 326 break 327 } 328 } 329 330 hasNonNumeric := false 331 for _, nnt := range chronoTypes() { 332 if setHasType(ts, nnt) { 333 hasNonNumeric = true 334 break 335 } 336 } 337 if setHasType(ts, typeinfo.BoolType) || setHasType(ts, typeinfo.UuidType) { 338 hasNonNumeric = true 339 } 340 341 if hasNumeric && hasNonNumeric { 342 return typeinfo.StringDefaultType 343 } 344 345 if hasNumeric { 346 return findCommonNumericType(ts) 347 } 348 349 // find a common nonNumeric type 350 351 nonChronoTypes := []typeinfo.TypeInfo{ 352 // todo: BIT implementation parses all uint8 353 //typeinfo.PseudoBoolType, 354 typeinfo.BoolType, 355 typeinfo.UuidType, 356 } 357 for _, nct := range nonChronoTypes { 358 if setHasType(ts, nct) { 359 // types in nonChronoTypes have only string 360 // as a common type with any other type 361 return typeinfo.StringDefaultType 362 } 363 } 364 365 return findCommonChronoType(ts) 366 } 367 368 func findCommonNumericType(nums typeInfoSet) typeinfo.TypeInfo { 369 // find a common numeric type 370 // iterate through types from most to least permissive 371 // return the most permissive type found 372 // ints are a subset of floats 373 // uints are a subset of ints 374 // smaller widths are a subset of larger widths 375 mostToLeast := []typeinfo.TypeInfo{ 376 typeinfo.Float64Type, 377 typeinfo.Float32Type, 378 379 // todo: can all Int64 fit in Float64? 380 typeinfo.Int64Type, 381 typeinfo.Int32Type, 382 typeinfo.Int24Type, 383 typeinfo.Int16Type, 384 typeinfo.Int8Type, 385 386 typeinfo.Uint64Type, 387 typeinfo.Uint32Type, 388 typeinfo.Uint24Type, 389 typeinfo.Uint16Type, 390 typeinfo.Uint8Type, 391 } 392 for _, numType := range mostToLeast { 393 if setHasType(nums, numType) { 394 return numType 395 } 396 } 397 398 panic("unreachable") 399 } 400 401 func findCommonChronoType(chronos typeInfoSet) typeinfo.TypeInfo { 402 if len(chronos) == 1 { 403 for ct := range chronos { 404 return ct 405 } 406 } 407 408 if setHasType(chronos, typeinfo.DatetimeType) { 409 return typeinfo.DatetimeType 410 } 411 412 hasTime := setHasType(chronos, typeinfo.TimeType) || setHasType(chronos, typeinfo.TimestampType) 413 hasDate := setHasType(chronos, typeinfo.DateType) || setHasType(chronos, typeinfo.YearType) 414 415 if hasTime && !hasDate { 416 return typeinfo.TimeType 417 } 418 419 if !hasTime && hasDate { 420 return typeinfo.DateType 421 } 422 423 if hasDate && hasTime { 424 return typeinfo.DatetimeType 425 } 426 427 panic("unreachable") 428 }