github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/env/actions/infer_schema_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package actions
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math"
    21  	"os"
    22  	"strconv"
    23  	"testing"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/table/untyped/csv"
    33  	"github.com/dolthub/dolt/go/libraries/utils/set"
    34  	"github.com/dolthub/dolt/go/store/types"
    35  )
    36  
    37  var maxIntPlusTwo uint64 = 1<<63 + 1
    38  
    39  func TestLeastPermissiveType(t *testing.T) {
    40  	tests := []struct {
    41  		name           string
    42  		valStr         string
    43  		floatThreshold float64
    44  		expType        typeinfo.TypeInfo
    45  	}{
    46  		{"empty string", "", 0.0, typeinfo.UnknownType},
    47  		{"valid uuid", "00000000-0000-0000-0000-000000000000", 0.0, typeinfo.UuidType},
    48  		{"invalid uuid", "00000000-0000-0000-0000-00000000000z", 0.0, typeinfo.StringDefaultType},
    49  		{"lower bool", "true", 0.0, typeinfo.BoolType},
    50  		{"upper bool", "FALSE", 0.0, typeinfo.BoolType},
    51  		{"yes", "yes", 0.0, typeinfo.StringDefaultType},
    52  		{"one", "1", 0.0, typeinfo.Int32Type},
    53  		{"negative one", "-1", 0.0, typeinfo.Int32Type},
    54  		{"negative one point 0", "-1.0", 0.0, typeinfo.Float32Type},
    55  		{"negative one point 0 with FT of 0.1", "-1.0", 0.1, typeinfo.Int32Type},
    56  		{"negative one point one with FT of 0.1", "-1.1", 0.1, typeinfo.Float32Type},
    57  		{"negative one point 999 with FT of 1.0", "-1.999", 1.0, typeinfo.Int32Type},
    58  		{"zero point zero zero zero zero", "0.0000", 0.0, typeinfo.Float32Type},
    59  		{"max int", strconv.FormatUint(math.MaxInt64, 10), 0.0, typeinfo.Int64Type},
    60  		{"bigger than max int", strconv.FormatUint(math.MaxUint64, 10) + "0", 0.0, typeinfo.StringDefaultType},
    61  	}
    62  
    63  	for _, test := range tests {
    64  		t.Run(test.name, func(t *testing.T) {
    65  			actualType := leastPermissiveType(test.valStr, test.floatThreshold)
    66  			assert.Equal(t, test.expType, actualType, "val: %s, expected: %v, actual: %v", test.valStr, test.expType, actualType)
    67  		})
    68  	}
    69  }
    70  
    71  func TestLeastPermissiveNumericType(t *testing.T) {
    72  	tests := []struct {
    73  		name           string
    74  		valStr         string
    75  		floatThreshold float64
    76  		expType        typeinfo.TypeInfo
    77  	}{
    78  		{"zero", "0", 0.0, typeinfo.Int32Type},
    79  		{"zero float", "0.0", 0.0, typeinfo.Float32Type},
    80  		{"zero float with floatThreshold of 0.1", "0.0", 0.1, typeinfo.Int32Type},
    81  		{"negative float", "-1.3451234", 0.0, typeinfo.Float32Type},
    82  		{"double decimal point", "0.00.0", 0.0, typeinfo.UnknownType},
    83  		{"leading zero floats", "05.78", 0.0, typeinfo.Float32Type},
    84  		{"zero float with high precision", "0.0000", 0.0, typeinfo.Float32Type},
    85  		{"all zeroes", "0000", 0.0, typeinfo.StringDefaultType},
    86  		{"leading zeroes", "01", 0.0, typeinfo.StringDefaultType},
    87  		{"negative int", "-1234", 0.0, typeinfo.Int32Type},
    88  		{"fits in uint64 but not int64", strconv.FormatUint(math.MaxUint64, 10), 0.0, typeinfo.StringDefaultType},
    89  		{"negative less than math.MinInt64", "-" + strconv.FormatUint(math.MaxUint64, 10), 0.0, typeinfo.StringDefaultType},
    90  		{"math.MinInt64", strconv.FormatInt(math.MinInt64, 10), 0.0, typeinfo.Int64Type},
    91  	}
    92  
    93  	for _, test := range tests {
    94  		t.Run(test.name, func(t *testing.T) {
    95  			actualType := leastPermissiveNumericType(test.valStr, test.floatThreshold)
    96  			assert.Equal(t, test.expType, actualType, "val: %s, expected: %v, actual: %v", test.valStr, test.expType, actualType)
    97  		})
    98  	}
    99  }
   100  
   101  func TestLeasPermissiveChronoType(t *testing.T) {
   102  	tests := []struct {
   103  		name    string
   104  		valStr  string
   105  		expType typeinfo.TypeInfo
   106  	}{
   107  		{"empty string", "", typeinfo.UnknownType},
   108  		{"random string", "asdf", typeinfo.UnknownType},
   109  		{"time", "9:27:10.485214", typeinfo.TimeType},
   110  		{"date", "2020-02-02", typeinfo.DateType},
   111  		{"also date", "2020-02-02 00:00:00.0", typeinfo.DateType},
   112  		{"datetime", "2030-01-02 04:06:03.472382", typeinfo.DatetimeType},
   113  	}
   114  
   115  	for _, test := range tests {
   116  		t.Run(test.name, func(t *testing.T) {
   117  			actualType := leastPermissiveChronoType(test.valStr)
   118  			assert.Equal(t, test.expType, actualType, "val: %s, expected: %v, actual: %v", test.valStr, test.expType, actualType)
   119  		})
   120  	}
   121  }
   122  
   123  type commonTypeTest struct {
   124  	name     string
   125  	inferSet typeInfoSet
   126  	expType  typeinfo.TypeInfo
   127  }
   128  
   129  func TestFindCommonType(t *testing.T) {
   130  	testFindCommonType(t)
   131  	testFindCommonTypeFromSingleType(t)
   132  	testFindCommonChronologicalType(t)
   133  }
   134  
   135  func testFindCommonType(t *testing.T) {
   136  	tests := []commonTypeTest{
   137  		{
   138  			name: "all signed ints",
   139  			inferSet: typeInfoSet{
   140  				typeinfo.Int32Type: {},
   141  				typeinfo.Int64Type: {},
   142  			},
   143  			expType: typeinfo.Int64Type,
   144  		},
   145  		{
   146  			name: "all floats",
   147  			inferSet: typeInfoSet{
   148  				typeinfo.Float32Type: {},
   149  				typeinfo.Float64Type: {},
   150  			},
   151  			expType: typeinfo.Float64Type,
   152  		},
   153  		{
   154  			name: "32 bit ints",
   155  			inferSet: typeInfoSet{
   156  				typeinfo.Int32Type: {},
   157  			},
   158  			expType: typeinfo.Int32Type,
   159  		},
   160  		{
   161  			name: "64 bit ints",
   162  			inferSet: typeInfoSet{
   163  				typeinfo.Int64Type: {},
   164  			},
   165  			expType: typeinfo.Int64Type,
   166  		},
   167  		{
   168  			name: "32 bit ints and floats",
   169  			inferSet: typeInfoSet{
   170  				typeinfo.Int32Type:   {},
   171  				typeinfo.Float32Type: {},
   172  			},
   173  			expType: typeinfo.Float32Type,
   174  		},
   175  		{
   176  			name: "64 bit ints and floats",
   177  			inferSet: typeInfoSet{
   178  				typeinfo.Int64Type:   {},
   179  				typeinfo.Float64Type: {},
   180  			},
   181  			expType: typeinfo.Float64Type,
   182  		},
   183  		{
   184  			name: "ints and bools",
   185  			inferSet: typeInfoSet{
   186  				typeinfo.Int32Type: {},
   187  				typeinfo.BoolType:  {},
   188  			},
   189  			expType: typeinfo.StringDefaultType,
   190  		},
   191  		{
   192  			name: "floats and bools",
   193  			inferSet: typeInfoSet{
   194  				typeinfo.Float32Type: {},
   195  				typeinfo.BoolType:    {},
   196  			},
   197  			expType: typeinfo.StringDefaultType,
   198  		},
   199  		{
   200  			name: "floats and uuids",
   201  			inferSet: typeInfoSet{
   202  				typeinfo.Float32Type: {},
   203  				typeinfo.UuidType:    {},
   204  			},
   205  			expType: typeinfo.StringDefaultType,
   206  		},
   207  	}
   208  
   209  	for _, test := range tests {
   210  		t.Run(test.name, func(t *testing.T) {
   211  			actualType := findCommonType(test.inferSet)
   212  			assert.Equal(t, test.expType, actualType)
   213  		})
   214  	}
   215  }
   216  
   217  func testFindCommonTypeFromSingleType(t *testing.T) {
   218  	allTypes := []typeinfo.TypeInfo{
   219  		typeinfo.Int8Type,
   220  		typeinfo.Int16Type,
   221  		typeinfo.Int24Type,
   222  		typeinfo.Int32Type,
   223  		typeinfo.Int64Type,
   224  		typeinfo.Float32Type,
   225  		typeinfo.Float64Type,
   226  		typeinfo.BoolType,
   227  		typeinfo.UuidType,
   228  		typeinfo.YearType,
   229  		typeinfo.DateType,
   230  		typeinfo.TimeType,
   231  		typeinfo.TimestampType,
   232  		typeinfo.DatetimeType,
   233  		typeinfo.StringDefaultType,
   234  	}
   235  
   236  	for _, ti := range allTypes {
   237  		tests := []commonTypeTest{
   238  			{
   239  				name: fmt.Sprintf("only %s", ti.String()),
   240  				inferSet: typeInfoSet{
   241  					ti: {},
   242  				},
   243  				expType: ti,
   244  			},
   245  			{
   246  				name: fmt.Sprintf("Unknown and %s", ti.String()),
   247  				inferSet: typeInfoSet{
   248  					ti:                   {},
   249  					typeinfo.UnknownType: {},
   250  				},
   251  				expType: ti,
   252  			},
   253  		}
   254  		for _, test := range tests {
   255  			t.Run(test.name, func(t *testing.T) {
   256  				actualType := findCommonType(test.inferSet)
   257  				assert.Equal(t, test.expType, actualType)
   258  			})
   259  		}
   260  	}
   261  }
   262  
   263  func testFindCommonChronologicalType(t *testing.T) {
   264  
   265  	tests := []commonTypeTest{
   266  		{
   267  			name: "date and time",
   268  			inferSet: typeInfoSet{
   269  				typeinfo.DateType: {},
   270  				typeinfo.TimeType: {},
   271  			},
   272  			expType: typeinfo.DatetimeType,
   273  		},
   274  		{
   275  			name: "date and datetime",
   276  			inferSet: typeInfoSet{
   277  				typeinfo.DateType:     {},
   278  				typeinfo.DatetimeType: {},
   279  			},
   280  			expType: typeinfo.DatetimeType,
   281  		},
   282  		{
   283  			name: "time and datetime",
   284  			inferSet: typeInfoSet{
   285  				typeinfo.TimeType:     {},
   286  				typeinfo.DatetimeType: {},
   287  			},
   288  			expType: typeinfo.DatetimeType,
   289  		},
   290  	}
   291  
   292  	for _, test := range tests {
   293  		t.Run(test.name, func(t *testing.T) {
   294  			actualType := findCommonType(test.inferSet)
   295  			assert.Equal(t, test.expType, actualType)
   296  		})
   297  	}
   298  }
   299  
// oneOfEachKindCSVStr is a CSV fixture with one column per kind of value:
//   - uuid: UUID strings
//   - int: small signed integers
//   - uint: 9223372036854775810 in every row, which exceeds math.MaxInt64
//   - float: floats with one decimal digit
//   - bool: booleans in mixed case (true/TRUE/false/FALSE)
//   - string: free-form text, including emoji
//
// Every cell is populated.
var oneOfEachKindCSVStr = `uuid,int,uint,float,bool,string
00000000-0000-0000-0000-000000000000,-4,9223372036854775810,-4.1,true,this is
f3c600cd-74d3-4c2d-b5a5-88934e7c6018,-3,9223372036854775810,-3.2,false,a test
0b4fe00e-c690-4fbe-b014-b07f236cd7fd,-2,9223372036854775810,-2.3,TRUE,anything could
7d85f566-4b60-4cc0-a029-cfb679c3fd8b,-1,9223372036854775810,-1.4,FALSE,be written
df017641-5e4b-4ef5-9f7a-9e9f2ce044b3,0,9223372036854775810,0.0,true,in these
4b4fb1af-0701-48dc-8593-aa7e094e72cc,1,9223372036854775810,1.5,false,string
b97b7e95-c811-4327-8f75-8358595e9614,2,9223372036854775810,2.6,TRUE,columns.
68d7b8bb-9065-47a8-9302-d4d44e5a149d,3,9223372036854775810,3.7,FALSE,Even emojis
5b6b7354-f89e-44b2-9058-8aea7f17d5fb,4,9223372036854775810,4.8,true,🐈🐈🐈🐈`
   310  
// oneOfEachKindWithSomeNilsCSVStr has the same columns and value kinds as
// oneOfEachKindCSVStr. The difference is that the "int" column is empty in the
// third and seventh data rows, so inference sees missing values mixed with
// integers.
var oneOfEachKindWithSomeNilsCSVStr = `uuid,int,uint,float,bool,string
6a52205c-2a46-4b04-9567-9e8707054774,-4,9223372036854775810,-4.1,true,this is
3975299d-71cf-4faa-afea-155d64bd1a9b,-3,9223372036854775810,-3.2,false,a test
8b682886-0ad7-4b96-a300-48fc1e10032e,,9223372036854775810,-2.3,TRUE,anything could
5da4b68c-58b1-41f9-8fed-4fbe15740b79,-1,9223372036854775810,-1.4,FALSE,be written
c659a238-3516-4380-a443-e528e55eb1bd,0,9223372036854775810,0.0,true,in these
86f3bb91-7476-48a1-8d14-6249cc696b02,1,9223372036854775810,1.5,false,string
42f2b05d-c279-4058-a13a-58f81a867401,,9223372036854775810,2.6,TRUE,columns.
935afaea-af2a-4cc9-b008-505a5edbeb3e,3,9223372036854775810,3.7,FALSE,Even emojis
aab7b641-124e-4193-8837-97cdfb13e8c4,4,9223372036854775810,4.8,true,🐈🐈🐈🐈`
   321  
// mixUintsAndPositiveInts has a single "mix" column. It combines a value above
// math.MaxInt64 (9223372036854775810) with small non-negative integers.
var mixUintsAndPositiveInts = `uuid,mix
95e8ccf4-78a7-479e-a974-cf8605e6b01d,9223372036854775810
60d6b5f9-9c99-4165-a835-1159e54b82ca,0
cab7f39d-e1f2-46f8-b858-372067ec75a0,1000000`
   326  
// floatsWithZeroForFractionalPortion holds decimal values whose fractional part
// is exactly zero (0.0, -1.0, 1.0). Any non-zero float threshold puts every one
// of them under the threshold.
var floatsWithZeroForFractionalPortion = `uuid,float
61c8e0ea-f2ff-49f8-98b8-df8142fd8a94,0.0
e5a3bdca-9315-4afe-b548-6e44924ff549,-1.0
e6b09e59-d524-4417-8252-f5c3eeeb5033,1.0`
   331  
   332  var floatsWithLargeFractionalPortion = `uuid,float
   333  ed929281-b4d0-4631-b931-8c2ebfb4ba84,0.0
   334  7d7c9b56-301c-4aa0-9034-8e12315b35cf,-1.0
   335  76bbed0b-7479-45bb-9813-ff1383bccee6,1.0`
   336  
// floatsWithTinyFractionalPortion holds decimal values with tiny fractional
// parts (0.0001, 0.0005, 0.0001). A threshold such as 0.0002 falls between
// them, so some values are under it and some are not.
var floatsWithTinyFractionalPortion = `uuid,float
da95f30d-4cde-4e53-a786-e00624ff2cbe,0.0001
6fb474ca-8bec-4e21-9af2-a1ba22f39f1d,-1.0005
aee125d4-e055-42e9-af3d-0bc676436ccd,1.0001`
   341  
   342  var identityMapper = make(rowconv.NameMapper)
   343  
   344  type testInferenceArgs struct {
   345  	ColMapper      rowconv.NameMapper
   346  	floatThreshold float64
   347  }
   348  
   349  func (tia testInferenceArgs) ColNameMapper() rowconv.NameMapper {
   350  	return tia.ColMapper
   351  }
   352  
   353  func (tia testInferenceArgs) FloatThreshold() float64 {
   354  	return tia.floatThreshold
   355  }
   356  
   357  func TestInferSchema(t *testing.T) {
   358  	tests := []struct {
   359  		name         string
   360  		csvContents  string
   361  		infArgs      InferenceArgs
   362  		expTypes     map[string]typeinfo.TypeInfo
   363  		nullableCols *set.StrSet
   364  	}{
   365  		{
   366  			"one of each kind",
   367  			oneOfEachKindCSVStr,
   368  			testInferenceArgs{
   369  				ColMapper:      identityMapper,
   370  				floatThreshold: 0,
   371  			},
   372  			map[string]typeinfo.TypeInfo{
   373  				"int":    typeinfo.Int32Type,
   374  				"uint":   typeinfo.StringDefaultType,
   375  				"uuid":   typeinfo.UuidType,
   376  				"float":  typeinfo.Float32Type,
   377  				"bool":   typeinfo.BoolType,
   378  				"string": typeinfo.StringDefaultType,
   379  			},
   380  			nil,
   381  		},
   382  		{
   383  			"mix uints and positive ints",
   384  			mixUintsAndPositiveInts,
   385  			testInferenceArgs{
   386  				ColMapper:      identityMapper,
   387  				floatThreshold: 0,
   388  			},
   389  			map[string]typeinfo.TypeInfo{
   390  				"mix":  typeinfo.StringDefaultType,
   391  				"uuid": typeinfo.UuidType,
   392  			},
   393  			nil,
   394  		},
   395  		{
   396  			"floats with zero fractional and float threshold of 0",
   397  			floatsWithZeroForFractionalPortion,
   398  			testInferenceArgs{
   399  				ColMapper:      identityMapper,
   400  				floatThreshold: 0,
   401  			},
   402  			map[string]typeinfo.TypeInfo{
   403  				"float": typeinfo.Float32Type,
   404  				"uuid":  typeinfo.UuidType,
   405  			},
   406  			nil,
   407  		},
   408  		{
   409  			"floats with zero fractional and float threshold of 0.1",
   410  			floatsWithZeroForFractionalPortion,
   411  			testInferenceArgs{
   412  				ColMapper:      identityMapper,
   413  				floatThreshold: 0.1,
   414  			},
   415  			map[string]typeinfo.TypeInfo{
   416  				"float": typeinfo.Int32Type,
   417  				"uuid":  typeinfo.UuidType,
   418  			},
   419  			nil,
   420  		},
   421  		{
   422  			"floats with large fractional and float threshold of 1.0",
   423  			floatsWithLargeFractionalPortion,
   424  			testInferenceArgs{
   425  				ColMapper:      identityMapper,
   426  				floatThreshold: 1.0,
   427  			},
   428  			map[string]typeinfo.TypeInfo{
   429  				"float": typeinfo.Int32Type,
   430  				"uuid":  typeinfo.UuidType,
   431  			},
   432  			nil,
   433  		},
   434  		{
   435  			"float threshold smaller than some of the values",
   436  			floatsWithTinyFractionalPortion,
   437  			testInferenceArgs{
   438  				ColMapper:      identityMapper,
   439  				floatThreshold: 0.0002,
   440  			},
   441  			map[string]typeinfo.TypeInfo{
   442  				"float": typeinfo.Float32Type,
   443  				"uuid":  typeinfo.UuidType,
   444  			},
   445  			nil,
   446  		},
   447  	}
   448  
   449  	const importFilePath = "/Users/home/datasets/test/import_file.csv"
   450  
   451  	for _, test := range tests {
   452  		t.Run(test.name, func(t *testing.T) {
   453  			dEnv := dtestutils.CreateTestEnv()
   454  			defer dEnv.DoltDB.Close()
   455  
   456  			wrCl, err := dEnv.FS.OpenForWrite(importFilePath, os.ModePerm)
   457  			require.NoError(t, err)
   458  			_, err = wrCl.Write([]byte(test.csvContents))
   459  			require.NoError(t, err)
   460  			err = wrCl.Close()
   461  			require.NoError(t, err)
   462  
   463  			rdCl, err := dEnv.FS.OpenForRead(importFilePath)
   464  			require.NoError(t, err)
   465  
   466  			csvRd, err := csv.NewCSVReader(types.Format_Default, rdCl, csv.NewCSVInfo())
   467  			require.NoError(t, err)
   468  
   469  			allCols, err := InferColumnTypesFromTableReader(context.Background(), csvRd, test.infArgs)
   470  			require.NoError(t, err)
   471  
   472  			assert.Equal(t, len(test.expTypes), allCols.Size())
   473  			err = allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   474  				expectedType, ok := test.expTypes[col.Name]
   475  				require.True(t, ok, "column not found: %s", col.Name)
   476  				assert.Equal(t, expectedType, col.TypeInfo, "column: %s - expected: %s got: %s", col.Name, expectedType.String(), col.TypeInfo.String())
   477  				return false, nil
   478  			})
   479  			require.NoError(t, err)
   480  
   481  			if test.nullableCols == nil {
   482  				test.nullableCols = set.NewStrSet(nil)
   483  			}
   484  
   485  			err = allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   486  				idx := schema.IndexOfConstraint(col.Constraints, schema.NotNullConstraintType)
   487  				assert.True(t, idx == -1, "%s unexpected not null constraint", col.Name)
   488  				return false, nil
   489  			})
   490  			require.NoError(t, err)
   491  		})
   492  	}
   493  }