github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/env/actions/infer_schema.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package actions 16 17 import ( 18 "context" 19 "encoding/json" 20 "errors" 21 "io" 22 "math" 23 "strconv" 24 "strings" 25 "time" 26 27 "github.com/google/uuid" 28 29 "github.com/dolthub/dolt/go/libraries/doltcore/row" 30 "github.com/dolthub/dolt/go/libraries/doltcore/rowconv" 31 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 32 "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" 33 "github.com/dolthub/dolt/go/libraries/doltcore/table" 34 "github.com/dolthub/dolt/go/libraries/utils/set" 35 "github.com/dolthub/dolt/go/store/types" 36 ) 37 38 type typeInfoSet map[typeinfo.TypeInfo]struct{} 39 40 const ( 41 maxUint24 = 1<<24 - 1 42 minInt24 = -1 << 23 43 ) 44 45 // InferenceArgs are arguments that can be passed to the schema inferrer to modify it's inference behavior. 46 type InferenceArgs interface { 47 // ColNameMapper allows columns named X in the schema to be named Y in the inferred schema. 48 ColNameMapper() rowconv.NameMapper 49 // FloatThreshold is the threshold at which a string representing a floating point number should be interpreted as 50 // a float versus an int. If FloatThreshold is 0.0 then any number with a decimal point will be interpreted as a 51 // float (such as 0.0, 1.0, etc). If FloatThreshold is 1.0 then any number with a decimal point will be converted 52 // to an int (0.5 will be the int 0, 1.99 will be the int 1, etc. If the FloatThreshold is 0.001 then numbers with 53 // a fractional component greater than or equal to 0.001 will be treated as a float (1.0 would be an int, 1.0009 would 54 // be an int, 1.001 would be a float, 1.1 would be a float, etc) 55 FloatThreshold() float64 56 } 57 58 // InferColumnTypesFromTableReader will infer a data types from a table reader. 59 func InferColumnTypesFromTableReader(ctx context.Context, rd table.ReadCloser, args InferenceArgs) (*schema.ColCollection, error) { 60 // for large imports, we want to sample a subset of the rows. 61 // skip through the file in an exponential manner 62 const exp = 1.02 63 64 var curr, prev row.Row 65 i := newInferrer(rd.GetSchema(), args) 66 OUTER: 67 for j := 0; true; j++ { 68 var err error 69 70 next := int(math.Pow(exp, float64(j))) 71 for n := 0; n < next; n++ { 72 curr, err = rd.ReadRow(ctx) 73 if err == io.EOF { 74 break OUTER 75 } else if err != nil { 76 return nil, err 77 } 78 prev = curr 79 } 80 if err = i.processRow(curr); err != nil { 81 return nil, err 82 } 83 } 84 85 // always process last row 86 if prev != nil { 87 if err := i.processRow(prev); err != nil { 88 return nil, err 89 } 90 } 91 92 return i.inferColumnTypes() 93 } 94 95 type inferrer struct { 96 readerSch schema.Schema 97 inferSets map[uint64]typeInfoSet 98 nullable *set.Uint64Set 99 mapper rowconv.NameMapper 100 floatThreshold float64 101 } 102 103 func newInferrer(readerSch schema.Schema, args InferenceArgs) *inferrer { 104 inferSets := make(map[uint64]typeInfoSet, readerSch.GetAllCols().Size()) 105 _ = readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 106 inferSets[tag] = make(typeInfoSet) 107 return false, nil 108 }) 109 110 return &inferrer{ 111 readerSch: readerSch, 112 inferSets: inferSets, 113 nullable: set.NewUint64Set(nil), 114 mapper: args.ColNameMapper(), 115 floatThreshold: args.FloatThreshold(), 116 } 117 } 118 119 // inferColumnTypes returns TableReader's columns with updated TypeInfo and columns names 120 func (inf *inferrer) inferColumnTypes() (*schema.ColCollection, error) { 121 122 inferredTypes := make(map[uint64]typeinfo.TypeInfo) 123 for tag, ts := range inf.inferSets { 124 inferredTypes[tag] = findCommonType(ts) 125 } 126 127 var cols []schema.Column 128 _ = inf.readerSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 129 col.Name = inf.mapper.Map(col.Name) 130 col.Kind = inferredTypes[tag].NomsKind() 131 col.TypeInfo = inferredTypes[tag] 132 col.Tag = schema.ReservedTagMin + tag 133 134 // for large imports, it is possible to miss all the null values, so we cannot accurately add not null constraint 135 col.Constraints = []schema.ColConstraint(nil) 136 137 cols = append(cols, col) 138 return false, nil 139 }) 140 141 return schema.NewColCollection(cols...), nil 142 } 143 144 func (inf *inferrer) processRow(r row.Row) error { 145 _, err := r.IterSchema(inf.readerSch, func(tag uint64, val types.Value) (stop bool, err error) { 146 if val == nil { 147 inf.nullable.Add(tag) 148 return false, nil 149 } 150 strVal := string(val.(types.String)) 151 typeInfo := leastPermissiveType(strVal, inf.floatThreshold) 152 inf.inferSets[tag][typeInfo] = struct{}{} 153 return false, nil 154 }) 155 156 return err 157 } 158 159 func leastPermissiveType(strVal string, floatThreshold float64) typeinfo.TypeInfo { 160 if len(strVal) == 0 { 161 return typeinfo.UnknownType 162 } 163 strVal = strings.TrimSpace(strVal) 164 165 numType := leastPermissiveNumericType(strVal, floatThreshold) 166 if numType != typeinfo.UnknownType { 167 return numType 168 } 169 170 _, err := uuid.Parse(strVal) 171 if err == nil { 172 return typeinfo.UuidType 173 } 174 175 chronoType := leastPermissiveChronoType(strVal) 176 if chronoType != typeinfo.UnknownType { 177 return chronoType 178 } 179 180 strVal = strings.ToLower(strVal) 181 if strVal == "true" || strVal == "false" { 182 return typeinfo.BoolType 183 } 184 185 if strings.Contains(strVal, "{") || strings.Contains(strVal, "[") { 186 var j interface{} 187 err := json.Unmarshal([]byte(strVal), &j) 188 if err == nil { 189 return typeinfo.JSONType 190 } 191 } 192 193 if int64(len(strVal)) > typeinfo.MaxVarcharLength { 194 return typeinfo.TextType 195 } else { 196 return typeinfo.StringDefaultType 197 } 198 } 199 200 func leastPermissiveNumericType(strVal string, floatThreshold float64) (ti typeinfo.TypeInfo) { 201 if strings.Contains(strVal, ".") { 202 f, err := strconv.ParseFloat(strVal, 64) 203 if err != nil { 204 return typeinfo.UnknownType 205 } 206 207 if math.Abs(f) < math.MaxFloat32 { 208 ti = typeinfo.Float32Type 209 } else { 210 ti = typeinfo.Float64Type 211 } 212 213 if floatThreshold != 0.0 { 214 floatParts := strings.Split(strVal, ".") 215 decimalPart, err := strconv.ParseFloat("0."+floatParts[1], 64) 216 217 if err != nil { 218 panic(err) 219 } 220 221 if decimalPart < floatThreshold { 222 if ti == typeinfo.Float32Type { 223 ti = typeinfo.Int32Type 224 } else { 225 ti = typeinfo.Int64Type 226 } 227 } 228 } 229 return ti 230 } 231 232 // always parse as signed int 233 i, err := strconv.ParseInt(strVal, 10, 64) 234 235 // use string for out of range 236 if errors.Is(err, strconv.ErrRange) { 237 return typeinfo.StringDefaultType 238 } 239 240 if err != nil { 241 return typeinfo.UnknownType 242 } 243 244 // handle leading zero case 245 if len(strVal) > 1 && strVal[0] == '0' { 246 return typeinfo.StringDefaultType 247 } 248 249 if i >= math.MinInt32 && i <= math.MaxInt32 { 250 return typeinfo.Int32Type 251 } else { 252 return typeinfo.Int64Type 253 } 254 } 255 256 func leastPermissiveChronoType(strVal string) typeinfo.TypeInfo { 257 if strVal == "" { 258 return typeinfo.UnknownType 259 } 260 261 dt, err := typeinfo.StringDefaultType.ConvertToType(context.Background(), nil, typeinfo.DatetimeType, types.String(strVal)) 262 if err == nil { 263 t := time.Time(dt.(types.Timestamp)) 264 if t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 { 265 return typeinfo.DateType 266 } 267 268 return typeinfo.DatetimeType 269 } 270 271 _, err = typeinfo.StringDefaultType.ConvertToType(context.Background(), nil, typeinfo.TimeType, types.String(strVal)) 272 if err == nil { 273 return typeinfo.TimeType 274 } 275 276 return typeinfo.UnknownType 277 } 278 279 func chronoTypes() []typeinfo.TypeInfo { 280 return []typeinfo.TypeInfo{ 281 // chrono types YEAR, DATE, and TIME can also be parsed as DATETIME 282 // we prefer less permissive types if possible 283 typeinfo.YearType, 284 typeinfo.DateType, 285 typeinfo.TimeType, 286 typeinfo.TimestampType, 287 typeinfo.DatetimeType, 288 } 289 } 290 291 // ordered from least to most permissive 292 func numericTypes() []typeinfo.TypeInfo { 293 // prefer: 294 // ints over floats 295 // smaller over larger 296 return []typeinfo.TypeInfo{ 297 //typeinfo.Uint8Type, 298 //typeinfo.Uint16Type, 299 //typeinfo.Uint24Type, 300 //typeinfo.Uint32Type, 301 //typeinfo.Uint64Type, 302 303 //typeinfo.Int8Type, 304 //typeinfo.Int16Type, 305 //typeinfo.Int24Type, 306 typeinfo.Int32Type, 307 typeinfo.Int64Type, 308 309 typeinfo.Float32Type, 310 typeinfo.Float64Type, 311 } 312 } 313 314 func setHasType(ts typeInfoSet, t typeinfo.TypeInfo) bool { 315 _, found := ts[t] 316 return found 317 } 318 319 // findCommonType takes a set of types and finds the least permissive 320 // (ie most specific) common type between all types in the set 321 func findCommonType(ts typeInfoSet) typeinfo.TypeInfo { 322 323 // empty values were inferred as UnknownType 324 delete(ts, typeinfo.UnknownType) 325 326 if len(ts) == 0 { 327 // use strings if all values were empty 328 return typeinfo.StringDefaultType 329 } 330 331 if len(ts) == 1 { 332 for ti := range ts { 333 return ti 334 } 335 } 336 337 // len(ts) > 1 338 339 if setHasType(ts, typeinfo.TextType) { 340 return typeinfo.TextType 341 } else if setHasType(ts, typeinfo.StringDefaultType) { 342 return typeinfo.StringDefaultType 343 } else if setHasType(ts, typeinfo.StringDefaultType) { 344 return typeinfo.StringDefaultType 345 } 346 347 hasNumeric := false 348 for _, nt := range numericTypes() { 349 if setHasType(ts, nt) { 350 hasNumeric = true 351 break 352 } 353 } 354 355 hasNonNumeric := false 356 for _, nnt := range chronoTypes() { 357 if setHasType(ts, nnt) { 358 hasNonNumeric = true 359 break 360 } 361 } 362 if setHasType(ts, typeinfo.BoolType) || setHasType(ts, typeinfo.UuidType) { 363 hasNonNumeric = true 364 } 365 366 if hasNumeric && hasNonNumeric { 367 return typeinfo.StringDefaultType 368 } 369 370 if hasNumeric { 371 return findCommonNumericType(ts) 372 } 373 374 // find a common nonNumeric type 375 376 nonChronoTypes := []typeinfo.TypeInfo{ 377 // todo: BIT implementation parses all uint8 378 //typeinfo.PseudoBoolType, 379 typeinfo.BoolType, 380 typeinfo.UuidType, 381 } 382 for _, nct := range nonChronoTypes { 383 if setHasType(ts, nct) { 384 // types in nonChronoTypes have only string 385 // as a common type with any other type 386 return typeinfo.StringDefaultType 387 } 388 } 389 390 return findCommonChronoType(ts) 391 } 392 393 func findCommonNumericType(nums typeInfoSet) typeinfo.TypeInfo { 394 // find a common numeric type 395 // iterate through types from most to least permissive 396 // return the most permissive type found 397 // ints are a subset of floats 398 // uints are a subset of ints 399 // smaller widths are a subset of larger widths 400 mostToLeast := []typeinfo.TypeInfo{ 401 typeinfo.Float64Type, 402 typeinfo.Float32Type, 403 404 // todo: can all Int64 fit in Float64? 405 typeinfo.Int64Type, 406 typeinfo.Int32Type, 407 typeinfo.Int24Type, 408 typeinfo.Int16Type, 409 typeinfo.Int8Type, 410 } 411 for _, numType := range mostToLeast { 412 if setHasType(nums, numType) { 413 return numType 414 } 415 } 416 417 panic("unreachable") 418 } 419 420 func findCommonChronoType(chronos typeInfoSet) typeinfo.TypeInfo { 421 if len(chronos) == 1 { 422 for ct := range chronos { 423 return ct 424 } 425 } 426 427 if setHasType(chronos, typeinfo.DatetimeType) { 428 return typeinfo.DatetimeType 429 } 430 431 hasTime := setHasType(chronos, typeinfo.TimeType) || setHasType(chronos, typeinfo.TimestampType) 432 hasDate := setHasType(chronos, typeinfo.DateType) || setHasType(chronos, typeinfo.YearType) 433 434 if hasTime && !hasDate { 435 return typeinfo.TimeType 436 } 437 438 if !hasTime && hasDate { 439 return typeinfo.DateType 440 } 441 442 if hasDate && hasTime { 443 return typeinfo.DatetimeType 444 } 445 446 panic("unreachable") 447 }