github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/workload/rand/rand.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package rand
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	gosql "database/sql"
    17  	"database/sql/driver"
    18  	"fmt"
    19  	"math/rand"
    20  	"reflect"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    27  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    28  	"github.com/cockroachdb/cockroach/pkg/workload"
    29  	"github.com/cockroachdb/cockroach/pkg/workload/histogram"
    30  	"github.com/cockroachdb/errors"
    31  	"github.com/lib/pq"
    32  	"github.com/lib/pq/oid"
    33  	"github.com/spf13/pflag"
    34  )
    35  
    36  type random struct {
    37  	flags     workload.Flags
    38  	connFlags *workload.ConnFlags
    39  
    40  	batchSize int
    41  
    42  	seed int64
    43  
    44  	tableName string
    45  
    46  	tables     int
    47  	method     string
    48  	primaryKey string
    49  	nullPct    int
    50  }
    51  
    52  func init() {
    53  	workload.Register(randMeta)
    54  }
    55  
    56  var randMeta = workload.Meta{
    57  	Name:        `rand`,
    58  	Description: `random writes to table`,
    59  	Version:     `1.0.0`,
    60  	New: func() workload.Generator {
    61  		g := &random{}
    62  		g.flags.FlagSet = pflag.NewFlagSet(`rand`, pflag.ContinueOnError)
    63  		g.flags.Meta = map[string]workload.FlagMeta{
    64  			`batch`: {RuntimeOnly: true},
    65  		}
    66  		g.flags.IntVar(&g.tables, `tables`, 1, `Number of tables to create`)
    67  		g.flags.StringVar(&g.tableName, `table`, ``, `Table to write to`)
    68  		g.flags.IntVar(&g.batchSize, `batch`, 1, `Number of rows to insert in a single SQL statement`)
    69  		g.flags.StringVar(&g.method, `method`, `upsert`, `Choice of DML name: insert, upsert, ioc-update (insert on conflict update), ioc-nothing (insert on conflict no nothing)`)
    70  		g.flags.Int64Var(&g.seed, `seed`, 1, `Key hash seed.`)
    71  		g.flags.StringVar(&g.primaryKey, `primary-key`, ``, `ioc-update and ioc-nothing require primary key`)
    72  		g.flags.IntVar(&g.nullPct, `null-percent`, 5, `Percent random nulls`)
    73  		g.connFlags = workload.NewConnFlags(&g.flags)
    74  		return g
    75  	},
    76  }
    77  
    78  // Meta implements the Generator interface.
    79  func (*random) Meta() workload.Meta { return randMeta }
    80  
    81  // Flags implements the Flagser interface.
    82  func (w *random) Flags() workload.Flags { return w.flags }
    83  
    84  // Hooks implements the Hookser interface.
    85  func (w *random) Hooks() workload.Hooks {
    86  	return workload.Hooks{}
    87  }
    88  
    89  // Tables implements the Generator interface.
    90  func (w *random) Tables() []workload.Table {
    91  	tables := make([]workload.Table, w.tables)
    92  	rng := rand.New(rand.NewSource(w.seed))
    93  	for i := 0; i < w.tables; i++ {
    94  		createTable := sqlbase.RandCreateTable(rng, "table", rng.Int())
    95  		ctx := tree.NewFmtCtx(tree.FmtParsable)
    96  		createTable.FormatBody(ctx)
    97  		tables[i] = workload.Table{
    98  			Name:   createTable.Table.String(),
    99  			Schema: ctx.CloseAndGetString(),
   100  		}
   101  	}
   102  	return tables
   103  }
   104  
   105  type col struct {
   106  	name          string
   107  	dataType      *types.T
   108  	dataPrecision int
   109  	dataScale     int
   110  	cdefault      gosql.NullString
   111  	isNullable    bool
   112  }
   113  
   114  // Ops implements the Opser interface.
   115  func (w *random) Ops(urls []string, reg *histogram.Registry) (workload.QueryLoad, error) {
   116  	sqlDatabase, err := workload.SanitizeUrls(w, w.connFlags.DBOverride, urls)
   117  	if err != nil {
   118  		return workload.QueryLoad{}, err
   119  	}
   120  	db, err := gosql.Open(`cockroach`, strings.Join(urls, ` `))
   121  	if err != nil {
   122  		return workload.QueryLoad{}, err
   123  	}
   124  	// Allow a maximum of concurrency+1 connections to the database.
   125  	db.SetMaxOpenConns(w.connFlags.Concurrency + 1)
   126  	db.SetMaxIdleConns(w.connFlags.Concurrency + 1)
   127  
   128  	tableName := w.tableName
   129  	if tableName == "" {
   130  		tableName = w.Tables()[0].Name
   131  	}
   132  
   133  	var relid int
   134  	if err := db.QueryRow(fmt.Sprintf("SELECT '%s'::REGCLASS::OID", tableName)).Scan(&relid); err != nil {
   135  		return workload.QueryLoad{}, err
   136  	}
   137  
   138  	rows, err := db.Query(
   139  		`
   140  SELECT attname, atttypid, adsrc, NOT attnotnull
   141  FROM pg_catalog.pg_attribute
   142  LEFT JOIN pg_catalog.pg_attrdef
   143  ON attrelid=adrelid AND attnum=adnum
   144  WHERE attrelid=$1`, relid)
   145  	if err != nil {
   146  		return workload.QueryLoad{}, err
   147  	}
   148  
   149  	var cols []col
   150  	var numCols = 0
   151  
   152  	defer rows.Close()
   153  	for rows.Next() {
   154  		var c col
   155  		c.dataPrecision = 0
   156  		c.dataScale = 0
   157  
   158  		var typOid int
   159  		if err := rows.Scan(&c.name, &typOid, &c.cdefault, &c.isNullable); err != nil {
   160  			return workload.QueryLoad{}, err
   161  		}
   162  		datumType := types.OidToType[oid.Oid(typOid)]
   163  		c.dataType = datumType
   164  		if c.cdefault.String == "unique_rowid()" { // skip
   165  			continue
   166  		}
   167  		if strings.HasPrefix(c.cdefault.String, "uuid_v4()") { // skip
   168  			continue
   169  		}
   170  		cols = append(cols, c)
   171  		numCols++
   172  	}
   173  
   174  	if numCols == 0 {
   175  		return workload.QueryLoad{}, errors.New("no columns detected")
   176  	}
   177  
   178  	// insert on conflict requires the primary key. check information_schema if not specified on the command line
   179  	if strings.HasPrefix(w.method, "ioc") && w.primaryKey == "" {
   180  		rows, err := db.Query(
   181  			`
   182  SELECT a.attname
   183  FROM   pg_index i
   184  JOIN   pg_attribute a ON a.attrelid = i.indrelid
   185                        AND a.attnum = ANY(i.indkey)
   186  WHERE  i.indrelid = $1
   187  AND    i.indisprimary`, relid)
   188  		if err != nil {
   189  			return workload.QueryLoad{}, err
   190  		}
   191  		defer rows.Close()
   192  		for rows.Next() {
   193  			var colname string
   194  
   195  			if err := rows.Scan(&colname); err != nil {
   196  				return workload.QueryLoad{}, err
   197  			}
   198  			if w.primaryKey != "" {
   199  				w.primaryKey += "," + colname
   200  			} else {
   201  				w.primaryKey += colname
   202  			}
   203  		}
   204  	}
   205  
   206  	if strings.HasPrefix(w.method, "ioc") && w.primaryKey == "" {
   207  		err := errors.New(
   208  			"insert on conflict requires primary key to be specified via -primary if the table does " +
   209  				"not have primary key")
   210  		return workload.QueryLoad{}, err
   211  	}
   212  
   213  	var dmlMethod string
   214  	var dmlSuffix bytes.Buffer
   215  	var buf bytes.Buffer
   216  	switch w.method {
   217  	case "insert":
   218  		dmlMethod = "insert"
   219  		dmlSuffix.WriteString("")
   220  	case "upsert":
   221  		dmlMethod = "upsert"
   222  		dmlSuffix.WriteString("")
   223  	case "ioc-nothing":
   224  		dmlMethod = "insert"
   225  		dmlSuffix.WriteString(fmt.Sprintf(" on conflict (%s) do nothing", w.primaryKey))
   226  	case "ioc-update":
   227  		dmlMethod = "insert"
   228  		dmlSuffix.WriteString(fmt.Sprintf(" on conflict (%s) do update set ", w.primaryKey))
   229  		for i, c := range cols {
   230  			if i > 0 {
   231  				dmlSuffix.WriteString(",")
   232  			}
   233  			dmlSuffix.WriteString(fmt.Sprintf("%s=EXCLUDED.%s", c.name, c.name))
   234  		}
   235  	default:
   236  		return workload.QueryLoad{}, errors.Errorf("%s DML method not valid", w.primaryKey)
   237  	}
   238  
   239  	fmt.Fprintf(&buf, `%s INTO %s.%s (`, dmlMethod, sqlDatabase, tableName)
   240  	for i, c := range cols {
   241  		if i > 0 {
   242  			buf.WriteString(",")
   243  		}
   244  		buf.WriteString(c.name)
   245  	}
   246  	buf.WriteString(`) VALUES `)
   247  
   248  	nCols := len(cols)
   249  	for i := 0; i < w.batchSize; i++ {
   250  		if i > 0 {
   251  			buf.WriteString(", ")
   252  		}
   253  		buf.WriteString("(")
   254  		for j := range cols {
   255  			if j > 0 {
   256  				buf.WriteString(", ")
   257  			}
   258  			fmt.Fprintf(&buf, `$%d`, 1+j+(nCols*i))
   259  		}
   260  		buf.WriteString(")")
   261  	}
   262  
   263  	buf.WriteString(dmlSuffix.String())
   264  
   265  	if testing.Verbose() {
   266  		fmt.Println(buf.String())
   267  	}
   268  
   269  	writeStmt, err := db.Prepare(buf.String())
   270  	if err != nil {
   271  		return workload.QueryLoad{}, err
   272  	}
   273  
   274  	ql := workload.QueryLoad{SQLDatabase: sqlDatabase}
   275  
   276  	for i := 0; i < w.connFlags.Concurrency; i++ {
   277  		op := randOp{
   278  			config:    w,
   279  			hists:     reg.GetHandle(),
   280  			db:        db,
   281  			cols:      cols,
   282  			rng:       rand.New(rand.NewSource(w.seed + int64(i))),
   283  			writeStmt: writeStmt,
   284  		}
   285  		ql.WorkerFns = append(ql.WorkerFns, op.run)
   286  	}
   287  	return ql, nil
   288  }
   289  
   290  type randOp struct {
   291  	config    *random
   292  	hists     *histogram.Histograms
   293  	db        *gosql.DB
   294  	cols      []col
   295  	rng       *rand.Rand
   296  	writeStmt *gosql.Stmt
   297  }
   298  
   299  // DatumToGoSQL converts a datum to a Go type.
   300  func DatumToGoSQL(d tree.Datum) (interface{}, error) {
   301  	d = tree.UnwrapDatum(nil, d)
   302  	if d == tree.DNull {
   303  		return nil, nil
   304  	}
   305  	switch d := d.(type) {
   306  	case *tree.DBool:
   307  		return bool(*d), nil
   308  	case *tree.DString:
   309  		return string(*d), nil
   310  	case *tree.DBytes:
   311  		return string(*d), nil
   312  	case *tree.DDate, *tree.DTime:
   313  		return tree.AsStringWithFlags(d, tree.FmtBareStrings), nil
   314  	case *tree.DTimestamp:
   315  		return d.Time, nil
   316  	case *tree.DTimestampTZ:
   317  		return d.Time, nil
   318  	case *tree.DInterval:
   319  		return d.Duration.String(), nil
   320  	case *tree.DBitArray:
   321  		return tree.AsStringWithFlags(d, tree.FmtBareStrings), nil
   322  	case *tree.DInt:
   323  		return int64(*d), nil
   324  	case *tree.DOid:
   325  		return int(d.DInt), nil
   326  	case *tree.DFloat:
   327  		return float64(*d), nil
   328  	case *tree.DDecimal:
   329  		return d.Float64()
   330  	case *tree.DArray:
   331  		arr := make([]interface{}, len(d.Array))
   332  		for i := range d.Array {
   333  			elt, err := DatumToGoSQL(d.Array[i])
   334  			if err != nil {
   335  				return nil, err
   336  			}
   337  			if elt == nil {
   338  				elt = nullVal{}
   339  			}
   340  			arr[i] = elt
   341  		}
   342  		return pq.Array(arr), nil
   343  	case *tree.DUuid:
   344  		return d.UUID, nil
   345  	case *tree.DIPAddr:
   346  		return d.IPAddr.String(), nil
   347  	}
   348  	return nil, errors.Errorf("unhandled datum type: %s", reflect.TypeOf(d))
   349  }
   350  
   351  type nullVal struct {
   352  }
   353  
   354  func (nullVal) Value() (driver.Value, error) {
   355  	return nil, nil
   356  }
   357  
   358  func (o *randOp) run(ctx context.Context) (err error) {
   359  	params := make([]interface{}, len(o.cols)*o.config.batchSize)
   360  	k := 0 // index into params
   361  	for j := 0; j < o.config.batchSize; j++ {
   362  		for _, c := range o.cols {
   363  			nullPct := 0
   364  			if c.isNullable && o.config.nullPct > 0 {
   365  				nullPct = 100 / o.config.nullPct
   366  			}
   367  			d := sqlbase.RandDatumWithNullChance(o.rng, c.dataType, nullPct)
   368  			params[k], err = DatumToGoSQL(d)
   369  			if err != nil {
   370  				return err
   371  			}
   372  			k++
   373  		}
   374  	}
   375  	start := timeutil.Now()
   376  	_, err = o.writeStmt.ExecContext(ctx, params...)
   377  	o.hists.Get(`write`).Record(timeutil.Since(start))
   378  	return err
   379  }