github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/schema/encoding/schema_marshaling_test.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package encoding
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"math/rand"
    21  	"reflect"
    22  	"strconv"
    23  	"testing"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    27  	"github.com/dolthub/vitess/go/sqltypes"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    33  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    34  	"github.com/dolthub/dolt/go/store/chunks"
    35  	"github.com/dolthub/dolt/go/store/constants"
    36  	"github.com/dolthub/dolt/go/store/marshal"
    37  	"github.com/dolthub/dolt/go/store/types"
    38  )
    39  
    40  func createTestSchema() schema.Schema {
    41  	columns := []schema.Column{
    42  		schema.NewColumn("id", 4, types.InlineBlobKind, true, schema.NotNullConstraint{}),
    43  		schema.NewColumn("first", 1, types.StringKind, false),
    44  		schema.NewColumn("last", 2, types.StringKind, false, schema.NotNullConstraint{}),
    45  		schema.NewColumn("age", 3, types.UintKind, false),
    46  	}
    47  
    48  	colColl := schema.NewColCollection(columns...)
    49  	sch := schema.MustSchemaFromCols(colColl)
    50  	_, _ = sch.Indexes().AddIndexByColTags("idx_age", []uint64{3}, nil, schema.IndexProperties{IsUnique: false, Comment: ""})
    51  	return sch
    52  }
    53  
    54  func TestNomsMarshalling(t *testing.T) {
    55  	tSchema := createTestSchema()
    56  	_, vrw, _, err := dbfactory.MemFactory{}.CreateDB(context.Background(), types.Format_Default, nil, nil)
    57  
    58  	if err != nil {
    59  		t.Fatal("Could not create in mem noms db.")
    60  	}
    61  
    62  	val, err := MarshalSchema(context.Background(), vrw, tSchema)
    63  
    64  	if err != nil {
    65  		t.Fatal("Failed to marshal Schema as a types.Value.")
    66  	}
    67  
    68  	unMarshalled, err := UnmarshalSchema(context.Background(), types.Format_Default, val)
    69  
    70  	if err != nil {
    71  		t.Fatal("Failed to unmarshal types.Value as Schema")
    72  	}
    73  
    74  	if !assert.Equal(t, tSchema, unMarshalled) {
    75  		t.Error("Value different after marshalling and unmarshalling.")
    76  	}
    77  
    78  	// this validation step only makes sense for Noms encoded schemas
    79  	if !types.Format_Default.UsesFlatbuffers() {
    80  		validated, err := validateUnmarshaledNomsValue(context.Background(), types.Format_Default, val)
    81  		assert.NoError(t, err,
    82  			"Failed compatibility test. Schema could not be unmarshalled with mirror type, error: %v", err)
    83  		if !assert.Equal(t, tSchema, validated) {
    84  			t.Error("Value different after marshalling and unmarshalling.")
    85  		}
    86  	}
    87  }
    88  
// getSqlTypes returns the SQL types exercised by the marshalling tests in this
// file. Binary and blob types are left commented out until the storage-format
// TODOs below are resolved.
//
// The order of the slice matters to callers: getColumns derives each column's
// name and tag from the element's index.
func getSqlTypes() []sql.Type {
	//TODO: determine the storage format for BINARY
	//TODO: determine the storage format for BLOB
	//TODO: determine the storage format for LONGBLOB
	//TODO: determine the storage format for MEDIUMBLOB
	//TODO: determine the storage format for TINYBLOB
	//TODO: determine the storage format for VARBINARY
	return []sql.Type{
		gmstypes.Int64,  //BIGINT
		gmstypes.Uint64, //BIGINT UNSIGNED
		//sql.MustCreateBinary(sqltypes.Binary, 10), //BINARY(10)
		gmstypes.MustCreateBitType(10), //BIT(10)
		//sql.Blob, //BLOB
		gmstypes.Boolean, //BOOLEAN
		gmstypes.MustCreateStringWithDefaults(sqltypes.Char, 10), //CHAR(10)
		gmstypes.Date,     //DATE
		gmstypes.Datetime, //DATETIME
		gmstypes.MustCreateColumnDecimalType(9, 5), //DECIMAL(9, 5)
		gmstypes.Float64, //DOUBLE
		gmstypes.MustCreateEnumType([]string{"a", "b", "c"}, sql.Collation_Default), //ENUM('a','b','c')
		gmstypes.Float32, //FLOAT
		gmstypes.Int32,   //INT
		gmstypes.Uint32,  //INT UNSIGNED
		//sql.LongBlob, //LONGBLOB
		gmstypes.LongText, //LONGTEXT
		//sql.MediumBlob, //MEDIUMBLOB
		gmstypes.Int24,      //MEDIUMINT
		gmstypes.Uint24,     //MEDIUMINT UNSIGNED
		gmstypes.MediumText, //MEDIUMTEXT
		gmstypes.MustCreateSetType([]string{"a", "b", "c"}, sql.Collation_Default), //SET('a','b','c')
		gmstypes.Int16,     //SMALLINT
		gmstypes.Uint16,    //SMALLINT UNSIGNED
		gmstypes.Text,      //TEXT
		gmstypes.Time,      //TIME
		gmstypes.Timestamp, //TIMESTAMP
		//sql.TinyBlob, //TINYBLOB
		gmstypes.Int8,     //TINYINT
		gmstypes.Uint8,    //TINYINT UNSIGNED
		gmstypes.TinyText, //TINYTEXT
		//sql.MustCreateBinary(sqltypes.VarBinary, 10), //VARBINARY(10)
		gmstypes.MustCreateStringWithDefaults(sqltypes.VarChar, 10),                //VARCHAR(10)
		gmstypes.MustCreateString(sqltypes.VarChar, 10, sql.Collation_utf8mb3_bin), //VARCHAR(10) CHARACTER SET utf8mb3 COLLATE utf8mb3_bin
		gmstypes.Year, //YEAR
	}
}
   134  
   135  func TestTypeInfoMarshalling(t *testing.T) {
   136  	for _, sqlType := range getSqlTypes() {
   137  		t.Run(sqlType.String(), func(t *testing.T) {
   138  			ti, err := typeinfo.FromSqlType(sqlType)
   139  			require.NoError(t, err)
   140  			col, err := schema.NewColumnWithTypeInfo("pk", 1, ti, true, "", false, "", schema.NotNullConstraint{})
   141  			require.NoError(t, err)
   142  			colColl := schema.NewColCollection(col)
   143  			originalSch, err := schema.SchemaFromCols(colColl)
   144  			require.NoError(t, err)
   145  
   146  			nbf, err := types.GetFormatForVersionString(constants.FormatDefaultString)
   147  			require.NoError(t, err)
   148  			_, vrw, _, err := dbfactory.MemFactory{}.CreateDB(context.Background(), nbf, nil, nil)
   149  			require.NoError(t, err)
   150  			val, err := MarshalSchema(context.Background(), vrw, originalSch)
   151  			require.NoError(t, err)
   152  			unmarshalledSch, err := UnmarshalSchema(context.Background(), nbf, val)
   153  			require.NoError(t, err)
   154  			ok := schema.SchemasAreEqual(originalSch, unmarshalledSch)
   155  			assert.True(t, ok)
   156  		})
   157  	}
   158  }
   159  
   160  func validateUnmarshaledNomsValue(ctx context.Context, nbf *types.NomsBinFormat, schemaVal types.Value) (schema.Schema, error) {
   161  	var sd testSchemaData
   162  	err := marshal.Unmarshal(ctx, nbf, schemaVal, &sd)
   163  
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	return sd.decodeSchema()
   169  }
   170  
   171  func TestMirroredTypes(t *testing.T) {
   172  	realType := reflect.ValueOf(&encodedColumn{}).Elem()
   173  	mirrorType := reflect.ValueOf(&testEncodedColumn{}).Elem()
   174  	require.Equal(t, mirrorType.NumField(), realType.NumField())
   175  
   176  	// TODO: create reflection tests to ensure that:
   177  	// - no fields in testEncodeColumn have the 'omitempty' annotation
   178  	// - no legacy fields in encodeColumn have the 'omitempty' annotation (with whitelist)
   179  	// - all new fields in encodeColumn have the 'omitempty' annotation
   180  }
   181  
// testEncodedColumn is a mirror type of encodedColumn that helps ensure compatibility between Dolt versions.
//
// If a field in encodedColumn is not optional, it should be added WITHOUT the "omitempty" annotation here
// in order to guarantee that all fields in encodedColumn are always being written when encodedColumn is serialized.
// See the comment above type encodedColumn.
//
// TestMirroredTypes requires this struct to have the same number of fields as encodedColumn.
type testEncodedColumn struct {
	// Tag is passed to schema.NewColumnWithTypeInfo as the column tag.
	Tag uint64 `noms:"tag" json:"tag"`

	// Name is passed to schema.NewColumnWithTypeInfo as the column name.
	Name string `noms:"name" json:"name"`

	// Kind is the lookup key into schema.LwrStrToKind; decodeColumn only uses it when TypeInfo.Type is empty.
	Kind string `noms:"kind" json:"kind"`

	// IsPartOfPK is forwarded unchanged to schema.NewColumnWithTypeInfo.
	IsPartOfPK bool `noms:"is_part_of_pk" json:"is_part_of_pk"`

	// TypeInfo takes precedence over Kind in decodeColumn when its Type is non-empty.
	TypeInfo encodedTypeInfo `noms:"typeinfo" json:"typeinfo"`

	// Default is forwarded unchanged to schema.NewColumnWithTypeInfo.
	Default string `noms:"default,omitempty" json:"default,omitempty"`

	// AutoIncrement is forwarded unchanged to schema.NewColumnWithTypeInfo.
	AutoIncrement bool `noms:"auto_increment,omitempty" json:"auto_increment,omitempty"`

	// Comment is forwarded unchanged to schema.NewColumnWithTypeInfo.
	Comment string `noms:"comment,omitempty" json:"comment,omitempty"`

	// Constraints are converted with decodeAllColConstraint before the column is built.
	Constraints []encodedConstraint `noms:"col_constraints" json:"col_constraints"`
}
   206  
// testEncodedIndex is the index entry decoded as part of testSchemaData.
// NOTE(review): presumably a mirror of the production encoded index type, in
// the same way testEncodedColumn mirrors encodedColumn — confirm.
//
// decodeSchema passes Name, Tags, Unique and Comment to AddIndexByColTags;
// Hidden is decoded but not consumed there.
type testEncodedIndex struct {
	Name    string   `noms:"name" json:"name"`
	Tags    []uint64 `noms:"tags" json:"tags"`
	Comment string   `noms:"comment" json:"comment"`
	Unique  bool     `noms:"unique" json:"unique"`
	Hidden  bool     `noms:"hidden,omitempty" json:"hidden,omitempty"`
}
   214  
// testSchemaData is the top-level mirror type that validateUnmarshaledNomsValue
// unmarshals a marshalled schema value into. decodeSchema turns it back into a
// schema.Schema from its columns, collation and indexes.
type testSchemaData struct {
	Columns         []testEncodedColumn `noms:"columns" json:"columns"`
	IndexCollection []testEncodedIndex  `noms:"idxColl,omitempty" json:"idxColl,omitempty"`
	Collation       schema.Collation    `noms:"collation,omitempty" json:"collation,omitempty"`
}
   220  
   221  func (tec testEncodedColumn) decodeColumn() (schema.Column, error) {
   222  	var typeInfo typeinfo.TypeInfo
   223  	var err error
   224  	if tec.TypeInfo.Type != "" {
   225  		typeInfo, err = tec.TypeInfo.decodeTypeInfo()
   226  		if err != nil {
   227  			return schema.Column{}, err
   228  		}
   229  	} else if tec.Kind != "" {
   230  		typeInfo = typeinfo.FromKind(schema.LwrStrToKind[tec.Kind])
   231  	} else {
   232  		return schema.Column{}, errors.New("cannot decode column due to unknown schema format")
   233  	}
   234  	colConstraints := decodeAllColConstraint(tec.Constraints)
   235  	return schema.NewColumnWithTypeInfo(tec.Name, tec.Tag, typeInfo, tec.IsPartOfPK, tec.Default, tec.AutoIncrement, tec.Comment, colConstraints...)
   236  }
   237  
   238  func (tsd testSchemaData) decodeSchema() (schema.Schema, error) {
   239  	numCols := len(tsd.Columns)
   240  	cols := make([]schema.Column, numCols)
   241  
   242  	var err error
   243  	for i, col := range tsd.Columns {
   244  		cols[i], err = col.decodeColumn()
   245  		if err != nil {
   246  			return nil, err
   247  		}
   248  	}
   249  
   250  	colColl := schema.NewColCollection(cols...)
   251  
   252  	sch, err := schema.SchemaFromCols(colColl)
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  	sch.SetCollation(tsd.Collation)
   257  
   258  	for _, encodedIndex := range tsd.IndexCollection {
   259  		_, err = sch.Indexes().AddIndexByColTags(encodedIndex.Name, encodedIndex.Tags, nil, schema.IndexProperties{IsUnique: encodedIndex.Unique, Comment: encodedIndex.Comment})
   260  		if err != nil {
   261  			return nil, err
   262  		}
   263  	}
   264  
   265  	return sch, nil
   266  }
   267  
   268  func TestSchemaMarshalling(t *testing.T) {
   269  	ctx := context.Background()
   270  	nbf := types.Format_Default
   271  	vrw := getTestVRW(nbf)
   272  	schemas := getSchemas(t, 1000)
   273  	for _, sch := range schemas {
   274  		v, err := MarshalSchema(ctx, vrw, sch)
   275  		require.NoError(t, err)
   276  		s, err := UnmarshalSchema(ctx, nbf, v)
   277  		require.NoError(t, err)
   278  		assert.Equal(t, sch, s)
   279  	}
   280  }
   281  
   282  func getTypeinfo(t *testing.T) (ti []typeinfo.TypeInfo) {
   283  	st := getSqlTypes()
   284  	ti = make([]typeinfo.TypeInfo, len(st))
   285  	for i := range st {
   286  		var err error
   287  		ti[i], err = typeinfo.FromSqlType(st[i])
   288  		require.NoError(t, err)
   289  	}
   290  	return
   291  }
   292  
   293  func getColumns(t *testing.T) (cols []schema.Column) {
   294  	ti := getTypeinfo(t)
   295  	cols = make([]schema.Column, len(ti))
   296  	var err error
   297  	for i := range cols {
   298  		name := "col" + strconv.Itoa(i)
   299  		tag := uint64(i)
   300  		cols[i], err = schema.NewColumnWithTypeInfo(name, tag, ti[i], false, "", false, "")
   301  		require.NoError(t, err)
   302  	}
   303  	return
   304  }
   305  
   306  func getSchemas(t *testing.T, n int) (schemas []schema.Schema) {
   307  	cols := getColumns(t)
   308  	schemas = make([]schema.Schema, n)
   309  	var err error
   310  	for i := range schemas {
   311  		rand.Shuffle(len(cols), func(i, j int) {
   312  			cols[i], cols[j] = cols[j], cols[i]
   313  		})
   314  		k := rand.Intn(len(cols)-1) + 1
   315  		cc := make([]schema.Column, k)
   316  		copy(cc, cols)
   317  		cc[0].IsPartOfPK = true
   318  		cc[0].Constraints = []schema.ColConstraint{schema.NotNullConstraint{}}
   319  		schemas[i], err = schema.SchemaFromCols(
   320  			schema.NewColCollection(cc...))
   321  		require.NoError(t, err)
   322  	}
   323  	return
   324  }
   325  
   326  func getTestVRW(nbf *types.NomsBinFormat) types.ValueReadWriter {
   327  	ts := &chunks.TestStorage{}
   328  	cs := ts.NewViewWithFormat(nbf.VersionString())
   329  	return types.NewValueStore(cs)
   330  }