github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/store/nbs/mem_table.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file incorporates work covered by the following copyright and
    16  // permission notice:
    17  //
    18  // Copyright 2016 Attic Labs, Inc. All rights reserved.
    19  // Licensed under the Apache License, version 2.0:
    20  // http://www.apache.org/licenses/LICENSE-2.0
    21  
    22  package nbs
    23  
    24  import (
    25  	"context"
    26  	"errors"
    27  	"fmt"
    28  	"sort"
    29  
    30  	"golang.org/x/sync/errgroup"
    31  
    32  	"github.com/dolthub/dolt/go/store/chunks"
    33  	"github.com/dolthub/dolt/go/store/hash"
    34  )
    35  
    // addChunkResult describes what happened when a chunk was offered to
    // memTable.addChunk.
    type addChunkResult int

    const (
    	// chunkExists means a chunk with the same hash was already present, so
    	// nothing was stored.
    	chunkExists = addChunkResult(iota)
    	// chunkAdded means the chunk was stored.
    	chunkAdded
    	// chunkNotAdded means storing the chunk would have exceeded the table's
    	// data limit, so nothing was stored.
    	chunkNotAdded
    )
    43  
    44  func WriteChunks(chunks []chunks.Chunk) (string, []byte, error) {
    45  	var size uint64
    46  	for _, chunk := range chunks {
    47  		size += uint64(len(chunk.Data()))
    48  	}
    49  
    50  	mt := newMemTable(size)
    51  
    52  	return writeChunksToMT(mt, chunks)
    53  }
    54  
    55  func writeChunksToMT(mt *memTable, chunks []chunks.Chunk) (string, []byte, error) {
    56  	for _, chunk := range chunks {
    57  		res := mt.addChunk(chunk.Hash(), chunk.Data())
    58  		if res == chunkNotAdded {
    59  			return "", nil, errors.New("didn't create this memory table with enough space to add all the chunks")
    60  		}
    61  	}
    62  
    63  	var stats Stats
    64  	name, data, count, err := mt.write(nil, &stats)
    65  
    66  	if err != nil {
    67  		return "", nil, err
    68  	}
    69  
    70  	if count != uint32(len(chunks)) {
    71  		return "", nil, errors.New("didn't write everything")
    72  	}
    73  
    74  	return name.String(), data, nil
    75  }
    76  
    77  type memTable struct {
    78  	chunks             map[hash.Hash][]byte
    79  	order              []hasRecord // Must maintain the invariant that these are sorted by rec.order
    80  	pendingRefs        []hasRecord
    81  	getChildAddrs      []chunks.GetAddrsCb
    82  	maxData, totalData uint64
    83  
    84  	snapper snappyEncoder
    85  }
    86  
    87  func newMemTable(memTableSize uint64) *memTable {
    88  	return &memTable{chunks: map[hash.Hash][]byte{}, maxData: memTableSize}
    89  }
    90  
    91  func (mt *memTable) addChunk(h hash.Hash, data []byte) addChunkResult {
    92  	if len(data) == 0 {
    93  		panic("NBS blocks cannot be zero length")
    94  	}
    95  	if _, ok := mt.chunks[h]; ok {
    96  		return chunkExists
    97  	}
    98  
    99  	dataLen := uint64(len(data))
   100  	if mt.totalData+dataLen > mt.maxData {
   101  		return chunkNotAdded
   102  	}
   103  
   104  	mt.totalData += dataLen
   105  	mt.chunks[h] = data
   106  	mt.order = append(mt.order, hasRecord{
   107  		&h,
   108  		h.Prefix(),
   109  		len(mt.order),
   110  		false,
   111  	})
   112  	return chunkAdded
   113  }
   114  
   115  func (mt *memTable) addGetChildRefs(getAddrs chunks.GetAddrsCb) {
   116  	mt.getChildAddrs = append(mt.getChildAddrs, getAddrs)
   117  }
   118  
   119  func (mt *memTable) addChildRefs(addrs hash.HashSet) {
   120  	for h := range addrs {
   121  		h := h
   122  		mt.pendingRefs = append(mt.pendingRefs, hasRecord{
   123  			a:      &h,
   124  			prefix: h.Prefix(),
   125  			order:  len(mt.pendingRefs),
   126  		})
   127  	}
   128  }
   129  
   130  func (mt *memTable) count() (uint32, error) {
   131  	return uint32(len(mt.order)), nil
   132  }
   133  
   134  func (mt *memTable) uncompressedLen() (uint64, error) {
   135  	return mt.totalData, nil
   136  }
   137  
   138  func (mt *memTable) has(h hash.Hash) (bool, error) {
   139  	_, has := mt.chunks[h]
   140  	return has, nil
   141  }
   142  
   143  func (mt *memTable) hasMany(addrs []hasRecord) (bool, error) {
   144  	var remaining bool
   145  	for i, addr := range addrs {
   146  		if addr.has {
   147  			continue
   148  		}
   149  
   150  		ok, err := mt.has(*addr.a)
   151  
   152  		if err != nil {
   153  			return false, err
   154  		}
   155  
   156  		if ok {
   157  			addrs[i].has = true
   158  		} else {
   159  			remaining = true
   160  		}
   161  	}
   162  	return remaining, nil
   163  }
   164  
   165  func (mt *memTable) get(ctx context.Context, h hash.Hash, stats *Stats) ([]byte, error) {
   166  	return mt.chunks[h], nil
   167  }
   168  
   169  func (mt *memTable) getMany(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, *chunks.Chunk), stats *Stats) (bool, error) {
   170  	var remaining bool
   171  	for i, r := range reqs {
   172  		data := mt.chunks[*r.a]
   173  		if data != nil {
   174  			c := chunks.NewChunkWithHash(hash.Hash(*r.a), data)
   175  			reqs[i].found = true
   176  			found(ctx, &c)
   177  		} else {
   178  			remaining = true
   179  		}
   180  	}
   181  	return remaining, nil
   182  }
   183  
   184  func (mt *memTable) getManyCompressed(ctx context.Context, eg *errgroup.Group, reqs []getRecord, found func(context.Context, CompressedChunk), stats *Stats) (bool, error) {
   185  	var remaining bool
   186  	for i, r := range reqs {
   187  		data := mt.chunks[*r.a]
   188  		if data != nil {
   189  			c := chunks.NewChunkWithHash(hash.Hash(*r.a), data)
   190  			reqs[i].found = true
   191  			found(ctx, ChunkToCompressedChunk(c))
   192  		} else {
   193  			remaining = true
   194  		}
   195  	}
   196  
   197  	return remaining, nil
   198  }
   199  
   200  func (mt *memTable) extract(ctx context.Context, chunks chan<- extractRecord) error {
   201  	for _, hrec := range mt.order {
   202  		chunks <- extractRecord{a: *hrec.a, data: mt.chunks[*hrec.a], err: nil}
   203  	}
   204  
   205  	return nil
   206  }
   207  
   208  func (mt *memTable) write(haver chunkReader, stats *Stats) (name hash.Hash, data []byte, count uint32, err error) {
   209  	numChunks := uint64(len(mt.order))
   210  	if numChunks == 0 {
   211  		return hash.Hash{}, nil, 0, fmt.Errorf("mem table cannot write with zero chunks")
   212  	}
   213  	maxSize := maxTableSize(uint64(len(mt.order)), mt.totalData)
   214  	// todo: memory quota
   215  	buff := make([]byte, maxSize)
   216  	tw := newTableWriter(buff, mt.snapper)
   217  
   218  	if haver != nil {
   219  		sort.Sort(hasRecordByPrefix(mt.order)) // hasMany() requires addresses to be sorted.
   220  		_, err := haver.hasMany(mt.order)
   221  
   222  		if err != nil {
   223  			return hash.Hash{}, nil, 0, err
   224  		}
   225  
   226  		sort.Sort(hasRecordByOrder(mt.order)) // restore "insertion" order for write
   227  	}
   228  
   229  	for _, addr := range mt.order {
   230  		if !addr.has {
   231  			h := addr.a
   232  			tw.addChunk(*h, mt.chunks[*h])
   233  			count++
   234  		}
   235  	}
   236  	tableSize, name, err := tw.finish()
   237  
   238  	if err != nil {
   239  		return hash.Hash{}, nil, 0, err
   240  	}
   241  
   242  	if count > 0 {
   243  		stats.BytesPerPersist.Sample(uint64(tableSize))
   244  		stats.CompressedChunkBytesPerPersist.Sample(uint64(tw.totalCompressedData))
   245  		stats.UncompressedChunkBytesPerPersist.Sample(uint64(tw.totalUncompressedData))
   246  		stats.ChunksPerPersist.Sample(uint64(count))
   247  	}
   248  
   249  	return name, buff[:tableSize], count, nil
   250  }
   251  
   252  func (mt *memTable) close() error {
   253  	return nil
   254  }