code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/batcher_list.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  
    22  	"code.vegaprotocol.io/vega/datanode/metrics"
    23  
    24  	"github.com/jackc/pgx/v4"
    25  )
    26  
    27  type ListBatcher[T simpleEntity] struct {
    28  	pending     []T
    29  	tableName   string
    30  	columnNames []string
    31  }
    32  
    33  func NewListBatcher[T simpleEntity](tableName string, columnNames []string) ListBatcher[T] {
    34  	return ListBatcher[T]{
    35  		tableName:   tableName,
    36  		columnNames: columnNames,
    37  		pending:     make([]T, 0, 1000),
    38  	}
    39  }
    40  
    41  type simpleEntity interface {
    42  	ToRow() []interface{}
    43  }
    44  
    45  func (b *ListBatcher[T]) Add(entity T) {
    46  	metrics.IncrementBatcherAddedEntities(b.tableName)
    47  	b.pending = append(b.pending, entity)
    48  }
    49  
    50  func (b *ListBatcher[T]) Flush(ctx context.Context, connection Connection) ([]T, error) {
    51  	rows := make([][]interface{}, len(b.pending))
    52  	for i := 0; i < len(b.pending); i++ {
    53  		rows[i] = b.pending[i].ToRow()
    54  	}
    55  
    56  	copyCount, err := connection.CopyFrom(
    57  		ctx,
    58  		pgx.Identifier{b.tableName},
    59  		b.columnNames,
    60  		pgx.CopyFromRows(rows),
    61  	)
    62  	if err != nil {
    63  		return nil, fmt.Errorf("failed to copy %q entries into database: %w", b.tableName, err)
    64  	}
    65  
    66  	if copyCount != int64(len(b.pending)) {
    67  		return nil, fmt.Errorf("copied %d %s rows into the database, expected to copy %d",
    68  			copyCount,
    69  			b.tableName,
    70  			len(b.pending))
    71  	}
    72  
    73  	flushed := b.pending
    74  	b.pending = b.pending[:0]
    75  
    76  	metrics.BatcherFlushedEntitiesAdd(b.tableName, len(rows))
    77  	return flushed, nil
    78  }