github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/doltdb/root_val_storage.go (about)

     1  // Copyright 2024 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package doltdb
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  
    22  	flatbuffers "github.com/dolthub/flatbuffers/v23/go"
    23  
    24  	"github.com/dolthub/dolt/go/gen/fb/serial"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/store/hash"
    27  	"github.com/dolthub/dolt/go/store/prolly"
    28  	"github.com/dolthub/dolt/go/store/prolly/shim"
    29  	"github.com/dolthub/dolt/go/store/prolly/tree"
    30  	"github.com/dolthub/dolt/go/store/types"
    31  )
    32  
// rootValueStorage is the storage abstraction for a root value. Two
// implementations appear in this file: nomsRvStorage, backed by a types.Struct,
// and fbRvStorage, backed by a serial.RootValue flatbuffer message.
//
// Every Set*/Edit* method returns a rootValueStorage holding the result
// instead of modifying the receiver's own fields.
type rootValueStorage interface {
	// GetFeatureVersion returns the stored feature version. The bool reports
	// whether a version was present.
	GetFeatureVersion() (FeatureVersion, bool, error)

	// GetTablesMap returns the mapping from table name to table address.
	// databaseSchema selects which schema's tables are visible; nomsRvStorage
	// ignores it.
	GetTablesMap(ctx context.Context, vr types.ValueReadWriter, ns tree.NodeStore, databaseSchema string) (tableMap, error)
	// GetForeignKeys returns the stored foreign key value. The bool reports
	// whether one was present.
	GetForeignKeys(ctx context.Context, vr types.ValueReader) (types.Value, bool, error)
	// GetCollation returns the stored collation, or schema.Collation_Default
	// when none has been recorded.
	GetCollation(ctx context.Context) (schema.Collation, error)
	// GetSchemas returns the database schemas recorded in the root value.
	// nomsRvStorage panics instead of implementing this.
	GetSchemas(ctx context.Context) ([]schema.DatabaseSchema, error)

	// SetForeignKeyMap returns storage whose foreign key value is m.
	SetForeignKeyMap(ctx context.Context, vrw types.ValueReadWriter, m types.Value) (rootValueStorage, error)
	// SetFeatureVersion returns storage whose feature version is v.
	SetFeatureVersion(v FeatureVersion) (rootValueStorage, error)
	// SetCollation returns storage whose collation is the one given.
	SetCollation(ctx context.Context, collation schema.Collation) (rootValueStorage, error)
	// SetSchemas returns storage whose database schemas are dbSchemas.
	// nomsRvStorage panics instead of implementing this.
	SetSchemas(ctx context.Context, dbSchemas []schema.DatabaseSchema) (rootValueStorage, error)

	// EditTablesMap applies edits (renames, removals and address updates) to
	// the tables map and returns the resulting storage.
	EditTablesMap(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, edits []tableEdit) (rootValueStorage, error)

	// DebugString returns a human-readable rendering of the storage.
	DebugString(ctx context.Context) string
	// nomsValue returns the underlying value as a types.Value.
	nomsValue() types.Value
}
    51  
// nomsRvStorage is the rootValueStorage implementation that keeps every root
// value field inside a single types.Struct.
type nomsRvStorage struct {
	// valueSt holds the fields, looked up by key (tablesKey, foreignKeyKey,
	// featureVersKey, rootCollationKey).
	valueSt types.Struct
}
    55  
// tableMap is a read-only view of the table-name to table-address mapping of a
// root value.
type tableMap interface {
	// Get returns the address stored for name. nomsTableMap returns the zero
	// hash and a nil error when the name is absent.
	Get(ctx context.Context, name string) (hash.Hash, error)
	// Iter calls cb for the entries of the map. Returning true from cb asks
	// for no further entries to be delivered; a non-nil error from cb is
	// returned to the caller.
	Iter(ctx context.Context, cb func(name string, addr hash.Hash) (bool, error)) error
}
    60  
    61  func tmIterAll(ctx context.Context, tm tableMap, cb func(name string, addr hash.Hash)) error {
    62  	return tm.Iter(ctx, func(name string, addr hash.Hash) (bool, error) {
    63  		cb(name, addr)
    64  		return false, nil
    65  	})
    66  }
    67  
    68  func (r nomsRvStorage) GetFeatureVersion() (FeatureVersion, bool, error) {
    69  	v, ok, err := r.valueSt.MaybeGet(featureVersKey)
    70  	if err != nil {
    71  		return 0, false, err
    72  	}
    73  	if ok {
    74  		return FeatureVersion(v.(types.Int)), true, nil
    75  	} else {
    76  		return 0, false, nil
    77  	}
    78  }
    79  
    80  func (r nomsRvStorage) GetTablesMap(context.Context, types.ValueReadWriter, tree.NodeStore, string) (tableMap, error) {
    81  	v, found, err := r.valueSt.MaybeGet(tablesKey)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if !found {
    86  		return nomsTableMap{types.EmptyMap}, nil
    87  	}
    88  	return nomsTableMap{v.(types.Map)}, nil
    89  }
    90  
    91  func (ntm nomsTableMap) Get(ctx context.Context, name string) (hash.Hash, error) {
    92  	v, f, err := ntm.MaybeGet(ctx, types.String(name))
    93  	if err != nil {
    94  		return hash.Hash{}, err
    95  	}
    96  	if !f {
    97  		return hash.Hash{}, nil
    98  	}
    99  	return v.(types.Ref).TargetHash(), nil
   100  }
   101  
   102  func (ntm nomsTableMap) Iter(ctx context.Context, cb func(name string, addr hash.Hash) (bool, error)) error {
   103  	return ntm.Map.Iter(ctx, func(k, v types.Value) (bool, error) {
   104  		name := string(k.(types.String))
   105  		addr := v.(types.Ref).TargetHash()
   106  		return cb(name, addr)
   107  	})
   108  }
   109  
   110  func (r nomsRvStorage) GetForeignKeys(context.Context, types.ValueReader) (types.Value, bool, error) {
   111  	v, found, err := r.valueSt.MaybeGet(foreignKeyKey)
   112  	if err != nil {
   113  		return types.Map{}, false, err
   114  	}
   115  	if !found {
   116  		return types.Map{}, false, err
   117  	}
   118  	return v.(types.Map), true, nil
   119  }
   120  
   121  func (r nomsRvStorage) GetCollation(ctx context.Context) (schema.Collation, error) {
   122  	v, found, err := r.valueSt.MaybeGet(rootCollationKey)
   123  	if err != nil {
   124  		return schema.Collation_Unspecified, err
   125  	}
   126  	if !found {
   127  		return schema.Collation_Default, nil
   128  	}
   129  	return schema.Collation(v.(types.Uint)), nil
   130  }
   131  
   132  func (r nomsRvStorage) EditTablesMap(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, edits []tableEdit) (rootValueStorage, error) {
   133  	m, err := r.GetTablesMap(ctx, vrw, ns, "")
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  	nm := m.(nomsTableMap).Map
   138  
   139  	me := nm.Edit()
   140  	for _, e := range edits {
   141  		if e.old_name != "" {
   142  			old, f, err := nm.MaybeGet(ctx, types.String(e.old_name))
   143  			if err != nil {
   144  				return nil, err
   145  			}
   146  			if !f {
   147  				return nil, ErrTableNotFound
   148  			}
   149  			_, f, err = nm.MaybeGet(ctx, types.String(e.name.Name))
   150  			if err != nil {
   151  				return nil, err
   152  			}
   153  			if f {
   154  				return nil, ErrTableExists
   155  			}
   156  			me = me.Remove(types.String(e.old_name)).Set(types.String(e.name.Name), old)
   157  		} else {
   158  			if e.ref == nil {
   159  				me = me.Remove(types.String(e.name.Name))
   160  			} else {
   161  				me = me.Set(types.String(e.name.Name), *e.ref)
   162  			}
   163  		}
   164  	}
   165  
   166  	nm, err = me.Map(ctx)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	st, err := r.valueSt.Set(tablesKey, nm)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	return nomsRvStorage{st}, nil
   176  }
   177  
   178  func (r nomsRvStorage) SetForeignKeyMap(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (rootValueStorage, error) {
   179  	st, err := r.valueSt.Set(foreignKeyKey, v)
   180  	if err != nil {
   181  		return nomsRvStorage{}, err
   182  	}
   183  	return nomsRvStorage{st}, nil
   184  }
   185  
   186  func (r nomsRvStorage) SetFeatureVersion(v FeatureVersion) (rootValueStorage, error) {
   187  	st, err := r.valueSt.Set(featureVersKey, types.Int(v))
   188  	if err != nil {
   189  		return nomsRvStorage{}, err
   190  	}
   191  	return nomsRvStorage{st}, nil
   192  }
   193  
   194  func (r nomsRvStorage) SetCollation(ctx context.Context, collation schema.Collation) (rootValueStorage, error) {
   195  	st, err := r.valueSt.Set(rootCollationKey, types.Uint(collation))
   196  	if err != nil {
   197  		return nomsRvStorage{}, err
   198  	}
   199  	return nomsRvStorage{st}, nil
   200  }
   201  
// GetSchemas is not supported by nomsRvStorage: it always panics and never
// returns.
func (r nomsRvStorage) GetSchemas(ctx context.Context) ([]schema.DatabaseSchema, error) {
	panic("schemas not implemented for nomsRvStorage")
}
   205  
// SetSchemas is not supported by nomsRvStorage: it always panics and never
// returns.
func (r nomsRvStorage) SetSchemas(ctx context.Context, dbSchemas []schema.DatabaseSchema) (rootValueStorage, error) {
	panic("schemas not implemented for nomsRvStorage")
}
   209  
   210  func (r nomsRvStorage) DebugString(ctx context.Context) string {
   211  	var buf bytes.Buffer
   212  	err := types.WriteEncodedValue(ctx, &buf, r.valueSt)
   213  	if err != nil {
   214  		panic(err)
   215  	}
   216  	return buf.String()
   217  }
   218  
// nomsValue returns the underlying types.Struct as a types.Value.
func (r nomsRvStorage) nomsValue() types.Value {
	return r.valueSt
}
   222  
// nomsTableMap adapts a types.Map to the tableMap interface. Its Get and Iter
// methods treat keys as types.String table names and values as types.Ref.
type nomsTableMap struct {
	types.Map
}
   226  
// fbRvStorage is the rootValueStorage implementation backed by a
// serial.RootValue flatbuffer message.
type fbRvStorage struct {
	// srv is the flatbuffer message. The setters in this file work on a copy
	// (see clone) or on a freshly serialized message rather than on this one.
	srv *serial.RootValue
}
   230  
   231  func (r fbRvStorage) SetForeignKeyMap(ctx context.Context, vrw types.ValueReadWriter, v types.Value) (rootValueStorage, error) {
   232  	var h hash.Hash
   233  	isempty, err := EmptyForeignKeyCollection(v.(types.SerialMessage))
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	if !isempty {
   238  		ref, err := vrw.WriteValue(ctx, v)
   239  		if err != nil {
   240  			return nil, err
   241  		}
   242  		h = ref.TargetHash()
   243  	}
   244  	ret := r.clone()
   245  	copy(ret.srv.ForeignKeyAddrBytes(), h[:])
   246  	return ret, nil
   247  }
   248  
   249  func (r fbRvStorage) SetFeatureVersion(v FeatureVersion) (rootValueStorage, error) {
   250  	ret := r.clone()
   251  	ret.srv.MutateFeatureVersion(int64(v))
   252  	return ret, nil
   253  }
   254  
   255  func (r fbRvStorage) SetCollation(ctx context.Context, collation schema.Collation) (rootValueStorage, error) {
   256  	ret := r.clone()
   257  	ret.srv.MutateCollation(serial.Collation(collation))
   258  	return ret, nil
   259  }
   260  
   261  func (r fbRvStorage) GetSchemas(ctx context.Context) ([]schema.DatabaseSchema, error) {
   262  	numSchemas := r.srv.SchemasLength()
   263  	schemas := make([]schema.DatabaseSchema, numSchemas)
   264  	for i := 0; i < numSchemas; i++ {
   265  		dbSchema := new(serial.DatabaseSchema)
   266  		_, err := r.srv.TrySchemas(dbSchema, i)
   267  		if err != nil {
   268  			return nil, err
   269  		}
   270  
   271  		schemas[i] = schema.DatabaseSchema{
   272  			Name: string(dbSchema.Name()),
   273  		}
   274  	}
   275  
   276  	return schemas, nil
   277  }
   278  
   279  func (r fbRvStorage) SetSchemas(ctx context.Context, dbSchemas []schema.DatabaseSchema) (rootValueStorage, error) {
   280  	msg, err := r.serializeRootValue(r.srv.TablesBytes(), dbSchemas)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	return fbRvStorage{msg}, nil
   285  }
   286  
   287  func (r fbRvStorage) clone() fbRvStorage {
   288  	bs := make([]byte, len(r.srv.Table().Bytes))
   289  	copy(bs, r.srv.Table().Bytes)
   290  	var ret serial.RootValue
   291  	ret.Init(bs, r.srv.Table().Pos)
   292  	return fbRvStorage{&ret}
   293  }
   294  
   295  func (r fbRvStorage) DebugString(ctx context.Context) string {
   296  	return fmt.Sprintf("fbRvStorage[%d, %s, %s]",
   297  		r.srv.FeatureVersion(),
   298  		"...", // TODO: Print out tables map
   299  		hash.New(r.srv.ForeignKeyAddrBytes()).String())
   300  }
   301  
   302  func (r fbRvStorage) nomsValue() types.Value {
   303  	return types.SerialMessage(r.srv.Table().Bytes)
   304  }
   305  
   306  func (r fbRvStorage) GetFeatureVersion() (FeatureVersion, bool, error) {
   307  	return FeatureVersion(r.srv.FeatureVersion()), true, nil
   308  }
   309  
   310  func (r fbRvStorage) getAddressMap(vrw types.ValueReadWriter, ns tree.NodeStore) (prolly.AddressMap, error) {
   311  	tbytes := r.srv.TablesBytes()
   312  	node, err := shim.NodeFromValue(types.SerialMessage(tbytes))
   313  	if err != nil {
   314  		return prolly.AddressMap{}, err
   315  	}
   316  	return prolly.NewAddressMap(node, ns)
   317  }
   318  
   319  func (r fbRvStorage) GetTablesMap(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, databaseSchema string) (tableMap, error) {
   320  	am, err := r.getAddressMap(vrw, ns)
   321  	if err != nil {
   322  		return nil, err
   323  	}
   324  	return fbTableMap{AddressMap: am, schemaName: databaseSchema}, nil
   325  }
   326  
   327  func (r fbRvStorage) GetForeignKeys(ctx context.Context, vr types.ValueReader) (types.Value, bool, error) {
   328  	addr := hash.New(r.srv.ForeignKeyAddrBytes())
   329  	if addr.IsEmpty() {
   330  		return types.SerialMessage{}, false, nil
   331  	}
   332  	v, err := vr.ReadValue(ctx, addr)
   333  	if err != nil {
   334  		return types.SerialMessage{}, false, err
   335  	}
   336  	return v.(types.SerialMessage), true, nil
   337  }
   338  
   339  func (r fbRvStorage) GetCollation(ctx context.Context) (schema.Collation, error) {
   340  	collation := r.srv.Collation()
   341  	// Pre-existing repositories will return invalid here
   342  	if collation == serial.Collationinvalid {
   343  		return schema.Collation_Default, nil
   344  	}
   345  	return schema.Collation(collation), nil
   346  }
   347  
   348  func (r fbRvStorage) EditTablesMap(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, edits []tableEdit) (rootValueStorage, error) {
   349  	am, err := r.getAddressMap(vrw, ns)
   350  	if err != nil {
   351  		return nil, err
   352  	}
   353  	ae := am.Editor()
   354  	for _, e := range edits {
   355  		if e.old_name != "" {
   356  			oldaddr, err := am.Get(ctx, e.old_name)
   357  			if err != nil {
   358  				return nil, err
   359  			}
   360  			newaddr, err := am.Get(ctx, encodeTableNameForAddressMap(e.name))
   361  			if err != nil {
   362  				return nil, err
   363  			}
   364  			if oldaddr.IsEmpty() {
   365  				return nil, ErrTableNotFound
   366  			}
   367  			if !newaddr.IsEmpty() {
   368  				return nil, ErrTableExists
   369  			}
   370  			err = ae.Delete(ctx, e.old_name)
   371  			if err != nil {
   372  				return nil, err
   373  			}
   374  			err = ae.Update(ctx, encodeTableNameForAddressMap(e.name), oldaddr)
   375  			if err != nil {
   376  				return nil, err
   377  			}
   378  		} else {
   379  			if e.ref == nil {
   380  				err := ae.Delete(ctx, encodeTableNameForAddressMap(e.name))
   381  				if err != nil {
   382  					return nil, err
   383  				}
   384  			} else {
   385  				err := ae.Update(ctx, encodeTableNameForAddressMap(e.name), e.ref.TargetHash())
   386  				if err != nil {
   387  					return nil, err
   388  				}
   389  			}
   390  		}
   391  	}
   392  	am, err = ae.Flush(ctx)
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  
   397  	ambytes := []byte(tree.ValueFromNode(am.Node()).(types.SerialMessage))
   398  	dbSchemas, err := r.GetSchemas(ctx)
   399  	if err != nil {
   400  		return nil, err
   401  	}
   402  
   403  	msg, err := r.serializeRootValue(ambytes, dbSchemas)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  	return fbRvStorage{msg}, nil
   408  }
   409  
// serializeRootValue builds a new serial.RootValue message. The tables vector
// is taken from addressMapBytes and the schemas vector from dbSchemas; the
// feature version, collation and foreign key address are copied from the
// receiver's current message. The finished bytes are parsed back into a
// *serial.RootValue, and any parse error is returned.
func (r fbRvStorage) serializeRootValue(addressMapBytes []byte, dbSchemas []schema.DatabaseSchema) (*serial.RootValue, error) {
	builder := flatbuffers.NewBuilder(80)
	// The vectors are written first; their offsets are attached to the table
	// once RootValueStart has been called.
	tablesoff := builder.CreateByteVector(addressMapBytes)
	schemasOff := serializeDatabaseSchemas(builder, dbSchemas)

	fkoff := builder.CreateByteVector(r.srv.ForeignKeyAddrBytes())
	serial.RootValueStart(builder)
	serial.RootValueAddFeatureVersion(builder, r.srv.FeatureVersion())
	serial.RootValueAddCollation(builder, r.srv.Collation())
	serial.RootValueAddTables(builder, tablesoff)
	serial.RootValueAddForeignKeyAddr(builder, fkoff)
	// serializeDatabaseSchemas returns 0 when there are no schemas; in that
	// case the schemas field is left out of the message entirely.
	if schemasOff > 0 {
		serial.RootValueAddSchemas(builder, schemasOff)
	}

	bs := serial.FinishMessage(builder, serial.RootValueEnd(builder), []byte(serial.RootValueFileID))
	msg, err := serial.TryGetRootAsRootValue(bs, serial.MessagePrefixSz)
	if err != nil {
		return nil, err
	}
	return msg, nil
}
   432  
// serializeDatabaseSchemas writes one serial.DatabaseSchema table per entry of
// dbSchemas into b, followed by a vector of their offsets, and returns the
// offset of that vector. It returns 0 and writes nothing when dbSchemas is
// empty.
func serializeDatabaseSchemas(b *flatbuffers.Builder, dbSchemas []schema.DatabaseSchema) flatbuffers.UOffsetT {
	// if we have no schemas, do not serialize an empty vector
	if len(dbSchemas) == 0 {
		return 0
	}

	// Build each schema table, walking the input from last to first, and
	// remember where each one landed.
	offsets := make([]flatbuffers.UOffsetT, len(dbSchemas))
	for i := len(dbSchemas) - 1; i >= 0; i-- {
		dbSchema := dbSchemas[i]

		nameOff := b.CreateString(dbSchema.Name)
		serial.DatabaseSchemaStart(b)
		serial.DatabaseSchemaAddName(b, nameOff)
		offsets[i] = serial.DatabaseSchemaEnd(b)
	}

	// Offsets are prepended last-to-first.
	// NOTE(review): presumably so the finished vector reads back in dbSchemas
	// order — confirm against a round trip through GetSchemas.
	serial.RootValueStartSchemasVector(b, len(offsets))
	for i := len(offsets) - 1; i >= 0; i-- {
		b.PrependUOffsetT(offsets[i])
	}
	return b.EndVector(len(offsets))
}
   455  
   456  func encodeTableNameForAddressMap(name TableName) string {
   457  	if name.Schema == "" {
   458  		return name.Name
   459  	}
   460  	return fmt.Sprintf("\000%s\000%s", name.Schema, name.Name)
   461  }
   462  
   463  func decodeTableNameForAddressMap(encodedName, schemaName string) (string, bool) {
   464  	if schemaName == "" && encodedName[0] != 0 {
   465  		return encodedName, true
   466  	} else if schemaName != "" && encodedName[0] == 0 &&
   467  		len(encodedName) > len(schemaName)+2 &&
   468  		encodedName[1:len(schemaName)+1] == schemaName {
   469  		return encodedName[len(schemaName)+2:], true
   470  	}
   471  	return "", false
   472  }
   473  
// fbTableMap adapts a prolly.AddressMap to the tableMap interface for the
// tables of a single schema.
type fbTableMap struct {
	prolly.AddressMap
	// schemaName is the schema this view is scoped to; the empty string
	// selects keys stored without a schema prefix.
	schemaName string
}
   478  
   479  func (m fbTableMap) Get(ctx context.Context, name string) (hash.Hash, error) {
   480  	return m.AddressMap.Get(ctx, encodeTableNameForAddressMap(TableName{Name: name, Schema: m.schemaName}))
   481  }
   482  
   483  func (m fbTableMap) Iter(ctx context.Context, cb func(string, hash.Hash) (bool, error)) error {
   484  	var stop bool
   485  	return m.AddressMap.IterAll(ctx, func(n string, a hash.Hash) error {
   486  		n, ok := decodeTableNameForAddressMap(n, m.schemaName)
   487  		if !stop && ok {
   488  			var err error
   489  			stop, err = cb(n, a)
   490  			return err
   491  		}
   492  		return nil
   493  	})
   494  }