code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/batcher_map.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "context" 20 "fmt" 21 22 "code.vegaprotocol.io/vega/datanode/metrics" 23 24 "github.com/jackc/pgx/v4" 25 orderedmap "github.com/wk8/go-ordered-map/v2" 26 ) 27 28 type MapBatcher[K entityKey, V entity[K]] struct { 29 pending *orderedmap.OrderedMap[K, V] 30 tableName string 31 columnNames []string 32 } 33 34 func NewMapBatcher[K entityKey, V entity[K]](tableName string, columnNames []string) MapBatcher[K, V] { 35 return MapBatcher[K, V]{ 36 tableName: tableName, 37 columnNames: columnNames, 38 pending: orderedmap.New[K, V](), 39 } 40 } 41 42 type entityKey interface { 43 comparable 44 } 45 46 type entity[K entityKey] interface { 47 ToRow() []interface{} 48 Key() K 49 } 50 51 func (b *MapBatcher[K, V]) Add(e V) { 52 metrics.IncrementBatcherAddedEntities(b.tableName) 53 key := e.Key() 54 _, present := b.pending.Set(key, e) 55 if present { 56 b.pending.MoveToBack(key) 57 } 58 } 59 60 func (b *MapBatcher[K, V]) Flush(ctx context.Context, connection Connection) ([]V, error) { 61 nPending := b.pending.Len() 62 if nPending == 0 { 63 return nil, nil 64 } 65 66 rows := make([][]interface{}, 0, nPending) 67 values := make([]V, 0, nPending) 68 for kv := b.pending.Oldest(); kv != nil; kv = kv.Next() { 69 rows = append(rows, kv.Value.ToRow()) 70 values = append(values, kv.Value) 71 } 72 73 copyCount, err := connection.CopyFrom( 74 ctx, 75 pgx.Identifier{b.tableName}, 76 b.columnNames, 77 pgx.CopyFromRows(rows), 78 ) 79 if err != nil { 80 return nil, fmt.Errorf("failed to copy %s entries into database:%w", b.tableName, err) 81 } 82 83 if copyCount != int64(nPending) { 84 return nil, fmt.Errorf("copied %d %s rows into the database, expected to copy %d", 85 copyCount, 86 b.tableName, 87 nPending) 88 } 89 90 b.pending = orderedmap.New[K, V]() 91 92 metrics.BatcherFlushedEntitiesAdd(b.tableName, len(rows)) 93 94 return values, nil 95 }