github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/workload/workloadsql/dataload.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package workloadsql
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	gosql "database/sql"
    17  	"fmt"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/util/log"
    22  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    23  	"github.com/cockroachdb/cockroach/pkg/workload"
    24  	"github.com/cockroachdb/errors"
    25  	"golang.org/x/sync/errgroup"
    26  )
    27  
// InsertsDataLoader is an InitialDataLoader implementation that loads data with
// batched INSERTs. The zero-value gets some sane defaults for the tunable
// settings.
type InsertsDataLoader struct {
	// BatchSize is the maximum number of rows sent in a single INSERT
	// statement. Values <= 0 are replaced with 1000 by InitialDataLoad.
	BatchSize   int
	// Concurrency is the number of goroutines used to insert into one table
	// at a time. Values < 1 are replaced with 1 by InitialDataLoad.
	Concurrency int
}
    35  
    36  // InitialDataLoad implements the InitialDataLoader interface.
    37  func (l InsertsDataLoader) InitialDataLoad(
    38  	ctx context.Context, db *gosql.DB, gen workload.Generator,
    39  ) (int64, error) {
    40  	if gen.Meta().Name == `tpch` {
    41  		return 0, errors.New(
    42  			`tpch currently doesn't work with the inserts data loader. try --data-loader=import`)
    43  	}
    44  
    45  	if l.BatchSize <= 0 {
    46  		l.BatchSize = 1000
    47  	}
    48  	if l.Concurrency < 1 {
    49  		l.Concurrency = 1
    50  	}
    51  
    52  	tables := gen.Tables()
    53  	var hooks workload.Hooks
    54  	if h, ok := gen.(workload.Hookser); ok {
    55  		hooks = h.Hooks()
    56  	}
    57  
    58  	for _, table := range tables {
    59  		createStmt := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "%s" %s`, table.Name, table.Schema)
    60  		if _, err := db.ExecContext(ctx, createStmt); err != nil {
    61  			return 0, errors.Wrapf(err, "could not create table: %s", table.Name)
    62  		}
    63  	}
    64  
    65  	if hooks.PreLoad != nil {
    66  		if err := hooks.PreLoad(db); err != nil {
    67  			return 0, errors.Wrapf(err, "Could not preload")
    68  		}
    69  	}
    70  
    71  	var bytesAtomic int64
    72  	for _, table := range tables {
    73  		if table.InitialRows.NumBatches == 0 {
    74  			continue
    75  		} else if table.InitialRows.FillBatch == nil {
    76  			return 0, errors.Errorf(
    77  				`initial data is not supported for workload %s`, gen.Meta().Name)
    78  		}
    79  		tableStart := timeutil.Now()
    80  		var tableRowsAtomic int64
    81  
    82  		batchesPerWorker := table.InitialRows.NumBatches / l.Concurrency
    83  		g, gCtx := errgroup.WithContext(ctx)
    84  		for i := 0; i < l.Concurrency; i++ {
    85  			startIdx := i * batchesPerWorker
    86  			endIdx := startIdx + batchesPerWorker
    87  			if i == l.Concurrency-1 {
    88  				// Account for any rounding error in batchesPerWorker.
    89  				endIdx = table.InitialRows.NumBatches
    90  			}
    91  			g.Go(func() error {
    92  				var insertStmtBuf bytes.Buffer
    93  				var params []interface{}
    94  				var numRows int
    95  				flush := func() error {
    96  					if len(params) > 0 {
    97  						insertStmt := insertStmtBuf.String()
    98  						if _, err := db.ExecContext(gCtx, insertStmt, params...); err != nil {
    99  							return errors.Wrapf(err, "failed insert into %s", table.Name)
   100  						}
   101  					}
   102  					insertStmtBuf.Reset()
   103  					fmt.Fprintf(&insertStmtBuf, `INSERT INTO "%s" VALUES `, table.Name)
   104  					params = params[:0]
   105  					numRows = 0
   106  					return nil
   107  				}
   108  				_ = flush()
   109  
   110  				for batchIdx := startIdx; batchIdx < endIdx; batchIdx++ {
   111  					for _, row := range table.InitialRows.BatchRows(batchIdx) {
   112  						atomic.AddInt64(&tableRowsAtomic, 1)
   113  						if len(params) != 0 {
   114  							insertStmtBuf.WriteString(`,`)
   115  						}
   116  						insertStmtBuf.WriteString(`(`)
   117  						for i, datum := range row {
   118  							atomic.AddInt64(&bytesAtomic, workload.ApproxDatumSize(datum))
   119  							if i != 0 {
   120  								insertStmtBuf.WriteString(`,`)
   121  							}
   122  							fmt.Fprintf(&insertStmtBuf, `$%d`, len(params)+i+1)
   123  						}
   124  						params = append(params, row...)
   125  						insertStmtBuf.WriteString(`)`)
   126  						if numRows++; numRows >= l.BatchSize {
   127  							if err := flush(); err != nil {
   128  								return err
   129  							}
   130  						}
   131  					}
   132  				}
   133  				return flush()
   134  			})
   135  		}
   136  		if err := g.Wait(); err != nil {
   137  			return 0, err
   138  		}
   139  		tableRows := int(atomic.LoadInt64(&tableRowsAtomic))
   140  		log.Infof(ctx, `imported %s (%s, %d rows)`,
   141  			table.Name, timeutil.Since(tableStart).Round(time.Second), tableRows,
   142  		)
   143  	}
   144  	return atomic.LoadInt64(&bytesAtomic), nil
   145  }