github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/schema/encoding/schema_marshaling.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package encoding
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"sync"
    22  
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    25  	"github.com/dolthub/dolt/go/store/hash"
    26  	"github.com/dolthub/dolt/go/store/marshal"
    27  	"github.com/dolthub/dolt/go/store/types"
    28  )
    29  
// Correct Marshalling & Unmarshalling is essential to compatibility across Dolt versions:
// any changes to the fields of Schema or other persisted objects must be append-only; no
// fields can ever be removed without breaking compatibility.
//
// The marshalling annotations of new fields must have the "omitempty" option to allow newer
// versions of Dolt to read objects serialized by older Dolt versions where the field did not
// yet exist. However, all fields must always be written.
//
// encodedColumn is the persisted (noms and JSON) form of a schema.Column. encodeColumn
// produces it and decodeColumn converts it back into a schema.Column.
type encodedColumn struct {
	// Tag is the column's tag.
	Tag uint64 `noms:"tag" json:"tag"`

	// Name is the name of the field
	Name string `noms:"name" json:"name"`

	// Kind is the type of the field.  See types/noms_kind.go in the liquidata fork for valid values.
	// decodeColumn only consults Kind when TypeInfo.Type is empty.
	Kind string `noms:"kind" json:"kind"`

	// IsPartOfPK reports whether the column is part of the primary key.
	IsPartOfPK bool `noms:"is_part_of_pk" json:"is_part_of_pk"`

	// TypeInfo is the full type description. When TypeInfo.Type is non-empty it takes
	// precedence over Kind during decoding.
	TypeInfo encodedTypeInfo `noms:"typeinfo,omitempty" json:"typeinfo,omitempty"`

	// Default is the column's default value, copied verbatim from schema.Column.Default.
	Default string `noms:"default,omitempty" json:"default,omitempty"`

	// AutoIncrement mirrors schema.Column.AutoIncrement.
	AutoIncrement bool `noms:"auto_increment,omitempty" json:"auto_increment,omitempty"`

	// Comment is the column comment.
	Comment string `noms:"comment,omitempty" json:"comment,omitempty"`

	// Constraints are the column-level constraints. When decoding, only the first
	// NOT NULL constraint is kept; later duplicates are dropped.
	Constraints []encodedConstraint `noms:"col_constraints" json:"col_constraints"`

	// NB: all new fields must have the 'omitempty' annotation. See comment above
}
    60  
    61  func encodeAllColConstraints(constraints []schema.ColConstraint) []encodedConstraint {
    62  	nomsConstraints := make([]encodedConstraint, len(constraints))
    63  	for i, c := range constraints {
    64  		nomsConstraints[i] = encodeColConstraint(c)
    65  	}
    66  
    67  	return nomsConstraints
    68  }
    69  
    70  func decodeAllColConstraint(encConstraints []encodedConstraint) []schema.ColConstraint {
    71  	if len(encConstraints) == 0 {
    72  		return nil
    73  	}
    74  
    75  	var constraints []schema.ColConstraint
    76  	seenNotNull := false
    77  	for _, nc := range encConstraints {
    78  		c := nc.decodeColConstraint()
    79  		// Prevent duplicate NOT NULL constraints
    80  		if c.GetConstraintType() == schema.NotNullConstraintType {
    81  			if seenNotNull {
    82  				continue
    83  			}
    84  			seenNotNull = true
    85  		}
    86  		constraints = append(constraints, c)
    87  	}
    88  
    89  	return constraints
    90  }
    91  
    92  func encodeColumn(col schema.Column) encodedColumn {
    93  	return encodedColumn{
    94  		Tag:           col.Tag,
    95  		Name:          col.Name,
    96  		Kind:          col.KindString(),
    97  		IsPartOfPK:    col.IsPartOfPK,
    98  		TypeInfo:      encodeTypeInfo(col.TypeInfo),
    99  		Default:       col.Default,
   100  		AutoIncrement: col.AutoIncrement,
   101  		Comment:       col.Comment,
   102  		Constraints:   encodeAllColConstraints(col.Constraints),
   103  	}
   104  }
   105  
   106  func (nfd encodedColumn) decodeColumn() (schema.Column, error) {
   107  	var typeInfo typeinfo.TypeInfo
   108  	var err error
   109  	if nfd.TypeInfo.Type != "" {
   110  		typeInfo, err = nfd.TypeInfo.decodeTypeInfo()
   111  		if err != nil {
   112  			return schema.Column{}, err
   113  		}
   114  	} else if nfd.Kind != "" {
   115  		typeInfo = typeinfo.FromKind(schema.LwrStrToKind[nfd.Kind])
   116  	} else {
   117  		return schema.Column{}, errors.New("cannot decode column due to unknown schema format")
   118  	}
   119  	colConstraints := decodeAllColConstraint(nfd.Constraints)
   120  	return schema.NewColumnWithTypeInfo(nfd.Name, nfd.Tag, typeInfo, nfd.IsPartOfPK, nfd.Default, nfd.AutoIncrement, nfd.Comment, colConstraints...)
   121  }
   122  
   123  type encodedConstraint struct {
   124  	Type   string            `noms:"constraint_type" json:"constraint_type"`
   125  	Params map[string]string `noms:"params" json:"params"`
   126  }
   127  
   128  func encodeColConstraint(constraint schema.ColConstraint) encodedConstraint {
   129  	return encodedConstraint{constraint.GetConstraintType(), constraint.GetConstraintParams()}
   130  }
   131  
   132  func (encCnst encodedConstraint) decodeColConstraint() schema.ColConstraint {
   133  	return schema.ColConstraintFromTypeAndParams(encCnst.Type, encCnst.Params)
   134  }
   135  
   136  type encodedTypeInfo struct {
   137  	Type   string            `noms:"type" json:"type"`
   138  	Params map[string]string `noms:"params" json:"params"`
   139  }
   140  
   141  func encodeTypeInfo(ti typeinfo.TypeInfo) encodedTypeInfo {
   142  	return encodedTypeInfo{ti.GetTypeIdentifier().String(), ti.GetTypeParams()}
   143  }
   144  
   145  func (enc encodedTypeInfo) decodeTypeInfo() (typeinfo.TypeInfo, error) {
   146  	id := typeinfo.ParseIdentifier(enc.Type)
   147  	return typeinfo.FromTypeParams(id, enc.Params)
   148  }
   149  
// encodedIndex is the persisted form of one index in a schema's index collection.
// toSchemaData fills it from a schema.Index, and decodeSchema passes it back to
// UnsafeAddIndexByColTags.
//
// Tags holds the tags of the indexed columns. IsSystemDefined stores the negation
// of the index's IsUserDefined property. FullTextInfo is written for every index,
// not only full-text ones, from the index's FullTextProperties().
type encodedIndex struct {
	Name            string              `noms:"name" json:"name"`
	Tags            []uint64            `noms:"tags" json:"tags"`
	Comment         string              `noms:"comment" json:"comment"`
	Unique          bool                `noms:"unique" json:"unique"`
	Spatial         bool                `noms:"spatial,omitempty" json:"spatial,omitempty"`
	FullText        bool                `noms:"fulltext,omitempty" json:"fulltext,omitempty"`
	IsSystemDefined bool                `noms:"hidden,omitempty" json:"hidden,omitempty"` // Was previously named Hidden, do not change noms name
	PrefixLengths   []uint16            `noms:"prefixLengths,omitempty" json:"prefixLengths,omitempty"`
	FullTextInfo    encodedFullTextInfo `noms:"fulltext_info,omitempty" json:"fulltext_info,omitempty"`
}

// encodedFullTextInfo is the persisted form of schema.FullTextProperties. Each
// field is copied to and from the schema.FullTextProperties field of the same name.
type encodedFullTextInfo struct {
	ConfigTable      string   `noms:"config_table" json:"config_table"`
	PositionTable    string   `noms:"position_table" json:"position_table"`
	DocCountTable    string   `noms:"doc_count_table" json:"doc_count_table"`
	GlobalCountTable string   `noms:"global_count_table" json:"global_count_table"`
	RowCountTable    string   `noms:"row_count_table" json:"row_count_table"`
	KeyType          uint8    `noms:"key_type" json:"key_type"`
	KeyName          string   `noms:"key_name" json:"key_name"`
	KeyPositions     []uint16 `noms:"key_positions" json:"key_positions"`
}

// encodedCheck is the persisted form of a check constraint: its name, its
// expression text, and whether it is enforced.
type encodedCheck struct {
	Name       string `noms:"name" json:"name"`
	Expression string `noms:"expression" json:"expression"`
	Enforced   bool   `noms:"enforced" json:"enforced"`
}
   178  
   179  type encodedSchemaData struct {
   180  	Columns          []encodedColumn  `noms:"columns" json:"columns"`
   181  	IndexCollection  []encodedIndex   `noms:"idxColl,omitempty" json:"idxColl,omitempty"`
   182  	CheckConstraints []encodedCheck   `noms:"checks,omitempty" json:"checks,omitempty"`
   183  	PkOrdinals       []int            `noms:"pkOrdinals,omitempty" json:"pkOrdinals,omitEmpty"`
   184  	Collation        schema.Collation `noms:"collation,omitempty" json:"collation,omitempty"`
   185  }
   186  
   187  func (sd *encodedSchemaData) Copy() *encodedSchemaData {
   188  	var columns []encodedColumn
   189  	if sd.Columns != nil {
   190  		columns = make([]encodedColumn, len(sd.Columns))
   191  		for i, column := range sd.Columns {
   192  			columns[i] = column
   193  		}
   194  	}
   195  
   196  	var idxCol []encodedIndex
   197  	if sd.IndexCollection != nil {
   198  		idxCol = make([]encodedIndex, len(sd.IndexCollection))
   199  		for i, idx := range sd.IndexCollection {
   200  			idxCol[i] = idx
   201  			idxCol[i].Tags = make([]uint64, len(idx.Tags))
   202  			for j, tag := range idx.Tags {
   203  				idxCol[i].Tags[j] = tag
   204  			}
   205  			if len(idx.PrefixLengths) > 0 {
   206  				idxCol[i].PrefixLengths = make([]uint16, len(idx.PrefixLengths))
   207  				for j, prefixLength := range idx.PrefixLengths {
   208  					idxCol[i].PrefixLengths[j] = prefixLength
   209  				}
   210  			}
   211  		}
   212  	}
   213  
   214  	var checks []encodedCheck
   215  	if sd.CheckConstraints != nil {
   216  		checks = make([]encodedCheck, len(sd.CheckConstraints))
   217  		for i, check := range sd.CheckConstraints {
   218  			checks[i] = check
   219  		}
   220  	}
   221  
   222  	var pkOrdinals []int
   223  	if sd.PkOrdinals != nil {
   224  		pkOrdinals = make([]int, len(sd.PkOrdinals))
   225  		for i, j := range sd.PkOrdinals {
   226  			pkOrdinals[i] = j
   227  		}
   228  	}
   229  
   230  	return &encodedSchemaData{
   231  		Columns:          columns,
   232  		IndexCollection:  idxCol,
   233  		CheckConstraints: checks,
   234  		PkOrdinals:       pkOrdinals,
   235  		Collation:        sd.Collation,
   236  	}
   237  }
   238  
   239  func toSchemaData(sch schema.Schema) (encodedSchemaData, error) {
   240  	allCols := sch.GetAllCols()
   241  	encCols := make([]encodedColumn, allCols.Size())
   242  
   243  	i := 0
   244  	err := allCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   245  		encCols[i] = encodeColumn(col)
   246  		i++
   247  
   248  		return false, nil
   249  	})
   250  
   251  	if err != nil {
   252  		return encodedSchemaData{}, err
   253  	}
   254  
   255  	encodedIndexes := make([]encodedIndex, sch.Indexes().Count())
   256  	for i, index := range sch.Indexes().AllIndexes() {
   257  		props := index.FullTextProperties()
   258  		encodedIndexes[i] = encodedIndex{
   259  			Name:            index.Name(),
   260  			Tags:            index.IndexedColumnTags(),
   261  			Comment:         index.Comment(),
   262  			Unique:          index.IsUnique(),
   263  			Spatial:         index.IsSpatial(),
   264  			FullText:        index.IsFullText(),
   265  			IsSystemDefined: !index.IsUserDefined(),
   266  			PrefixLengths:   index.PrefixLengths(),
   267  			FullTextInfo: encodedFullTextInfo{
   268  				ConfigTable:      props.ConfigTable,
   269  				PositionTable:    props.PositionTable,
   270  				DocCountTable:    props.DocCountTable,
   271  				GlobalCountTable: props.GlobalCountTable,
   272  				RowCountTable:    props.RowCountTable,
   273  				KeyType:          props.KeyType,
   274  				KeyName:          props.KeyName,
   275  				KeyPositions:     props.KeyPositions,
   276  			},
   277  		}
   278  	}
   279  
   280  	encodedChecks := make([]encodedCheck, sch.Checks().Count())
   281  	checks := sch.Checks()
   282  	for i, check := range checks.AllChecks() {
   283  		encodedChecks[i] = encodedCheck{
   284  			Name:       check.Name(),
   285  			Expression: check.Expression(),
   286  			Enforced:   check.Enforced(),
   287  		}
   288  	}
   289  
   290  	return encodedSchemaData{
   291  		Columns:          encCols,
   292  		IndexCollection:  encodedIndexes,
   293  		CheckConstraints: encodedChecks,
   294  		PkOrdinals:       sch.GetPkOrdinals(),
   295  		Collation:        sch.GetCollation(),
   296  	}, nil
   297  }
   298  
   299  func (sd encodedSchemaData) decodeSchema() (schema.Schema, error) {
   300  	numCols := len(sd.Columns)
   301  	cols := make([]schema.Column, numCols)
   302  
   303  	var err error
   304  	for i, col := range sd.Columns {
   305  		cols[i], err = col.decodeColumn()
   306  		if err != nil {
   307  			return nil, err
   308  		}
   309  	}
   310  
   311  	colColl := schema.NewColCollection(cols...)
   312  
   313  	sch, err := schema.SchemaFromCols(colColl)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	sch.SetCollation(sd.Collation)
   318  
   319  	if sd.PkOrdinals != nil {
   320  		err := sch.SetPkOrdinals(sd.PkOrdinals)
   321  		if err != nil {
   322  			return nil, err
   323  		}
   324  	}
   325  
   326  	for _, encodedIndex := range sd.IndexCollection {
   327  		_, err := sch.Indexes().UnsafeAddIndexByColTags(
   328  			encodedIndex.Name,
   329  			encodedIndex.Tags,
   330  			encodedIndex.PrefixLengths,
   331  			schema.IndexProperties{
   332  				IsUnique:      encodedIndex.Unique,
   333  				IsSpatial:     encodedIndex.Spatial,
   334  				IsFullText:    encodedIndex.FullText,
   335  				IsUserDefined: !encodedIndex.IsSystemDefined,
   336  				Comment:       encodedIndex.Comment,
   337  				FullTextProperties: schema.FullTextProperties{
   338  					ConfigTable:      encodedIndex.FullTextInfo.ConfigTable,
   339  					PositionTable:    encodedIndex.FullTextInfo.PositionTable,
   340  					DocCountTable:    encodedIndex.FullTextInfo.DocCountTable,
   341  					GlobalCountTable: encodedIndex.FullTextInfo.GlobalCountTable,
   342  					RowCountTable:    encodedIndex.FullTextInfo.RowCountTable,
   343  					KeyType:          encodedIndex.FullTextInfo.KeyType,
   344  					KeyName:          encodedIndex.FullTextInfo.KeyName,
   345  					KeyPositions:     encodedIndex.FullTextInfo.KeyPositions,
   346  				},
   347  			},
   348  		)
   349  		if err != nil {
   350  			return nil, err
   351  		}
   352  	}
   353  
   354  	for _, encodedCheck := range sd.CheckConstraints {
   355  		_, err := sch.Checks().AddCheck(
   356  			encodedCheck.Name,
   357  			encodedCheck.Expression,
   358  			encodedCheck.Enforced,
   359  		)
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  	}
   364  
   365  	return sch, nil
   366  }
   367  
   368  // MarshalSchema takes a Schema and converts it to a types.Value
   369  func MarshalSchema(ctx context.Context, vrw types.ValueReadWriter, sch schema.Schema) (types.Value, error) {
   370  	// Anyone calling this is going to serialize this to disk, so it's our last line of defense against defective schemas.
   371  	// Business logic should catch errors before this point, but this is a failsafe.
   372  	err := schema.ValidateColumnConstraints(sch.GetAllCols())
   373  	if err != nil {
   374  		return nil, err
   375  	}
   376  
   377  	err = schema.ValidateForInsert(sch.GetAllCols())
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	if vrw.Format().VersionString() != types.Format_DOLT.VersionString() {
   383  		for _, idx := range sch.Indexes().AllIndexes() {
   384  			if idx.IsSpatial() {
   385  				return nil, fmt.Errorf("spatial indexes are only supported in storage format __DOLT__")
   386  			}
   387  		}
   388  	}
   389  
   390  	if vrw.Format().UsesFlatbuffers() {
   391  		return SerializeSchema(ctx, vrw, sch)
   392  	}
   393  
   394  	sd, err := toSchemaData(sch)
   395  
   396  	if err != nil {
   397  		return types.EmptyStruct(vrw.Format()), err
   398  	}
   399  
   400  	val, err := marshal.Marshal(ctx, vrw, sd)
   401  
   402  	if err != nil {
   403  		return types.EmptyStruct(vrw.Format()), err
   404  	}
   405  
   406  	if _, ok := val.(types.Struct); ok {
   407  		return val, nil
   408  	}
   409  
   410  	return types.EmptyStruct(vrw.Format()), errors.New("Table Schema could not be converted to types.Struct")
   411  }
   412  
   413  type schCacheData struct {
   414  	schema schema.Schema
   415  }
   416  
   417  var schemaCacheMu *sync.Mutex = &sync.Mutex{}
   418  var unmarshalledSchemaCache = map[hash.Hash]schCacheData{}
   419  
   420  // UnmarshalSchema takes a types.Value representing a Schema and Unmarshalls it into a schema.Schema.
   421  func UnmarshalSchema(ctx context.Context, nbf *types.NomsBinFormat, schemaVal types.Value) (schema.Schema, error) {
   422  	var sch schema.Schema
   423  	var err error
   424  	if nbf.UsesFlatbuffers() {
   425  		sch, err = DeserializeSchema(ctx, nbf, schemaVal)
   426  		if err != nil {
   427  			return nil, err
   428  		}
   429  	} else {
   430  		var sd encodedSchemaData
   431  		err = marshal.Unmarshal(ctx, nbf, schemaVal, &sd)
   432  		if err != nil {
   433  			return nil, err
   434  		}
   435  
   436  		sch, err = sd.decodeSchema()
   437  		if err != nil {
   438  			return nil, err
   439  		}
   440  	}
   441  
   442  	if nbf.VersionString() != types.Format_DOLT.VersionString() {
   443  		for _, idx := range sch.Indexes().AllIndexes() {
   444  			if idx.IsSpatial() {
   445  				return nil, fmt.Errorf("spatial indexes are only supported in storage format __DOLT__")
   446  			}
   447  		}
   448  	}
   449  
   450  	return sch, nil
   451  }
   452  
   453  // UnmarshalSchemaAtAddr returns the schema at the given address, using the schema cache if possible.
   454  func UnmarshalSchemaAtAddr(ctx context.Context, vr types.ValueReader, addr hash.Hash) (schema.Schema, error) {
   455  	schemaCacheMu.Lock()
   456  	cachedData, ok := unmarshalledSchemaCache[addr]
   457  	schemaCacheMu.Unlock()
   458  
   459  	if ok {
   460  		cachedSch := cachedData.schema
   461  		return cachedSch.Copy(), nil
   462  	}
   463  
   464  	schemaVal, err := vr.ReadValue(ctx, addr)
   465  	if err != nil {
   466  		return nil, err
   467  	}
   468  
   469  	sch, err := UnmarshalSchema(ctx, vr.Format(), schemaVal)
   470  	if err != nil {
   471  		return nil, err
   472  	}
   473  
   474  	d := schCacheData{
   475  		schema: sch,
   476  	}
   477  
   478  	schemaCacheMu.Lock()
   479  	unmarshalledSchemaCache[addr] = d
   480  	schemaCacheMu.Unlock()
   481  
   482  	return sch, nil
   483  }