github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/workload/querylog/querylog.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package querylog
    12  
    13  import (
    14  	"archive/zip"
    15  	"bufio"
    16  	"context"
    17  	gosql "database/sql"
    18  	"fmt"
    19  	"io"
    20  	"io/ioutil"
    21  	"math/rand"
    22  	"os"
    23  	"path/filepath"
    24  	"regexp"
    25  	"strings"
    26  	"time"
    27  
    28  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    29  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    30  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    31  	"github.com/cockroachdb/cockroach/pkg/util/log"
    32  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    33  	"github.com/cockroachdb/cockroach/pkg/util/uuid"
    34  	"github.com/cockroachdb/cockroach/pkg/workload"
    35  	"github.com/cockroachdb/cockroach/pkg/workload/histogram"
    36  	workloadrand "github.com/cockroachdb/cockroach/pkg/workload/rand"
    37  	"github.com/cockroachdb/errors"
    38  	"github.com/jackc/pgx"
    39  	"github.com/lib/pq/oid"
    40  	"github.com/spf13/pflag"
    41  )
    42  
// querylog is a workload generator that replays the queries found in a query
// log against tables that already exist in the database. Its configuration
// fields are populated from the command-line flags registered in
// querylogMeta.New.
type querylog struct {
	flags     workload.Flags
	connFlags *workload.ConnFlags

	// state holds what is learned about the query log and the tables during
	// the setup performed in Ops.
	state querylogState

	// dirPath is the directory of the query log files (--dir). When --zip is
	// used, processQueryLog overwrites it with the extraction directory.
	dirPath string
	// zipPath is the path to a zipped query log (--zip).
	zipPath string
	// filesToParse is the maximum number of files in the query log to process
	// (--files-to-parse).
	filesToParse int
	// minSamplingProb is the minimum chance that a value will be chosen as a
	// sample (--min-sampling-prob).
	minSamplingProb float64
	// nullPct is the percent of random NULLs generated for nullable columns
	// (--null-percent).
	nullPct int
	// numSamples is the number of samples to be taken from the tables
	// (--num-samples).
	numSamples int
	// omitDeleteQueries and omitWriteQueries control whether DELETE and
	// INSERT/UPSERT queries from the log are skipped (--omit-delete-queries,
	// --omit-write-queries).
	omitDeleteQueries bool
	omitWriteQueries  bool
	// probOfUsingSamples is the probability of using a sample (rather than a
	// randomly generated value) for a placeholder (--prob-of-using-samples).
	probOfUsingSamples float64
	// querybenchPath (if set) tells the querylog where to write the generated
	// queries to be later executed by querybench in round-robin fashion, i.e.
	// querylog itself won't execute the queries against the database and only
	// will prepare the file.
	querybenchPath string
	// count is the number of queries to be written for querybench (--count);
	// it is used only if querybenchPath is set.
	count uint
	// seed is the random number generator seed (--seed).
	seed int64
	// stmtTimeoutSeconds, when non-zero, is applied as the session's
	// statement_timeout on every worker connection (--statement-timeout).
	// TODO(yuzefovich): this is a great variable to move to the main generator.
	stmtTimeoutSeconds int64
	// verbose indicates whether error messages should be printed out
	// (--verbose).
	verbose bool
}
    69  
// querylogState holds the information gathered from the query log and from
// the database schema that is used to choose queries and to fill in their
// placeholders.
type querylogState struct {
	// tableNames are the names of all tables in the database, ordered by
	// name (populated by getTableNames).
	tableNames []string
	// totalQueryCount is the number of (non-skipped) query occurrences in the
	// parsed log that were assigned to some table.
	totalQueryCount int
	// queryCountPerTable[i] is the number of query occurrences assigned to
	// tableNames[i].
	queryCountPerTable []int
	// seenQueriesByTableName[i] maps each query assigned to tableNames[i] to
	// the number of its occurrences. Note that, despite its name, it is
	// indexed in parallel with tableNames rather than keyed by table name.
	seenQueriesByTableName []map[string]int
	// tableUsed records the tables whose names occur in at least one parsed
	// query.
	tableUsed map[string]bool
	// columnsByTableName holds the column information of the used tables
	// (populated by getColumnsInfo).
	columnsByTableName map[string][]columnInfo
}
    78  
// init registers the querylog workload (described by querylogMeta) with the
// workload package's registry.
func init() {
	workload.Register(querylogMeta)
}
    82  
    83  var querylogMeta = workload.Meta{
    84  	Name:        `querylog`,
    85  	Description: `Querylog is a tool that produces a workload based on the provided query log.`,
    86  	Version:     `1.0.0`,
    87  	New: func() workload.Generator {
    88  		g := &querylog{}
    89  		g.flags.FlagSet = pflag.NewFlagSet(`querylog`, pflag.ContinueOnError)
    90  		g.flags.Meta = map[string]workload.FlagMeta{
    91  			`dir`:                   {RuntimeOnly: true},
    92  			`files-to-parse`:        {RuntimeOnly: true},
    93  			`min-sampling-prob`:     {RuntimeOnly: true},
    94  			`null-percent`:          {RuntimeOnly: true},
    95  			`num-samples`:           {RuntimeOnly: true},
    96  			`omit-delete-queries`:   {RuntimeOnly: true},
    97  			`omit-write-queries`:    {RuntimeOnly: true},
    98  			`prob-of-using-samples`: {RuntimeOnly: true},
    99  			`seed`:                  {RuntimeOnly: true},
   100  			`statement-timeout`:     {RuntimeOnly: true},
   101  			`verbose`:               {RuntimeOnly: true},
   102  			`zip`:                   {RuntimeOnly: true},
   103  		}
   104  		g.flags.UintVar(&g.count, `count`, 100, `Number of queries to be written for querybench (used only if --querybench-path is specified).`)
   105  		g.flags.StringVar(&g.dirPath, `dir`, ``, `Directory of the querylog files.`)
   106  		g.flags.IntVar(&g.filesToParse, `files-to-parse`, 5, `Maximum number of files in the query log to process.`)
   107  		g.flags.Float64Var(&g.minSamplingProb, `min-sampling-prob`, 0.01, `Minimum sampling probability defines the minimum chance `+
   108  			`that a value will be chosen as a sample. The smaller the number is the more diverse samples will be (given that the table has enough values). `+
   109  			`However, at the same time, the sampling process will be slower.`)
   110  		g.flags.IntVar(&g.nullPct, `null-percent`, 5, `Percent random nulls.`)
   111  		g.flags.IntVar(&g.numSamples, `num-samples`, 1000, `Number of samples to be taken from the tables. The bigger this number, `+
   112  			`the more diverse values will be chosen for the generated queries.`)
   113  		g.flags.BoolVar(&g.omitDeleteQueries, `omit-delete-queries`, true, `Indicates whether delete queries should be omitted.`)
   114  		g.flags.BoolVar(&g.omitWriteQueries, `omit-write-queries`, true, `Indicates whether write queries (INSERTs and UPSERTs) should be omitted.`)
   115  		g.flags.Float64Var(&g.probOfUsingSamples, `prob-of-using-samples`, 0.9, `Probability of using samples to generate values for `+
   116  			`the placeholders. Say it is 0.9, then with 0.1 probability the values will be generated randomly.`)
   117  		// TODO(yuzefovich): improve termination of querylog when used for
   118  		//  querybench file generation.
   119  		g.flags.StringVar(&g.querybenchPath, `querybench-path`, ``, `Path to write the generated queries to for querybench tool. `+
   120  			`NOTE: at the moment --max-ops=1 is the best way to terminate the generator to produce the desired count of queries.`)
   121  		g.flags.Int64Var(&g.seed, `seed`, 1, `Random number generator seed.`)
   122  		g.flags.Int64Var(&g.stmtTimeoutSeconds, `statement-timeout`, 0, `Sets session's statement_timeout setting (in seconds).'`)
   123  		g.flags.BoolVar(&g.verbose, `verbose`, false, `Indicates whether error messages should be printed out.`)
   124  		g.flags.StringVar(&g.zipPath, `zip`, ``, `Path to the zip with the query log. Note: this zip will be extracted into a temporary `+
   125  			`directory at the same path as zip just without '.zip' extension which will be removed after parsing is complete.`)
   126  
   127  		g.connFlags = workload.NewConnFlags(&g.flags)
   128  		return g
   129  	},
   130  }
   131  
   132  // Meta implements the Generator interface.
   133  func (*querylog) Meta() workload.Meta { return querylogMeta }
   134  
   135  // Flags implements the Flagser interface.
   136  func (w *querylog) Flags() workload.Flags { return w.flags }
   137  
   138  // Tables implements the Generator interface.
   139  func (*querylog) Tables() []workload.Table {
   140  	// Assume the necessary tables are already present.
   141  	return []workload.Table{}
   142  }
   143  
   144  // Hooks implements the Hookser interface.
   145  func (w *querylog) Hooks() workload.Hooks {
   146  	return workload.Hooks{
   147  		Validate: func() error {
   148  			if w.zipPath == "" && w.dirPath == "" {
   149  				return errors.Errorf("Missing required argument: either `--zip` or `--dir` have to be specified.")
   150  			}
   151  			if w.zipPath != "" {
   152  				if w.zipPath[len(w.zipPath)-4:] != ".zip" {
   153  					return errors.Errorf("Illegal argument: `--zip` is expected to end with '.zip'.")
   154  				}
   155  			}
   156  			if w.minSamplingProb < 0.00000001 || w.minSamplingProb > 1.0 {
   157  				return errors.Errorf("Illegal argument: `--min-sampling-prob` must be in [0.00000001, 1.0] range.")
   158  			}
   159  			if w.nullPct < 0 || w.nullPct > 100 {
   160  				return errors.Errorf("Illegal argument: `--null-pct` must be in [0, 100] range.")
   161  			}
   162  			if w.probOfUsingSamples < 0 || w.probOfUsingSamples > 1.0 {
   163  				return errors.Errorf("Illegal argument: `--prob-of-using-samples` must be in [0.0, 1.0] range.")
   164  			}
   165  			if w.stmtTimeoutSeconds < 0 {
   166  				return errors.Errorf("Illegal argument: `--statement-timeout` must be a non-negative integer.")
   167  			}
   168  			return nil
   169  		},
   170  	}
   171  }
   172  
   173  // Ops implements the Opser interface.
   174  func (w *querylog) Ops(urls []string, reg *histogram.Registry) (workload.QueryLoad, error) {
   175  	ctx := context.Background()
   176  
   177  	sqlDatabase, err := workload.SanitizeUrls(w, w.connFlags.DBOverride, urls)
   178  	if err != nil {
   179  		return workload.QueryLoad{}, err
   180  	}
   181  	db, err := gosql.Open(`cockroach`, strings.Join(urls, ` `))
   182  	if err != nil {
   183  		return workload.QueryLoad{}, err
   184  	}
   185  	// Allow a maximum of concurrency+1 connections to the database.
   186  	db.SetMaxOpenConns(w.connFlags.Concurrency + 1)
   187  	db.SetMaxIdleConns(w.connFlags.Concurrency + 1)
   188  
   189  	if err = w.getTableNames(db); err != nil {
   190  		return workload.QueryLoad{}, err
   191  	}
   192  
   193  	if err = w.processQueryLog(ctx); err != nil {
   194  		return workload.QueryLoad{}, err
   195  	}
   196  
   197  	if err = w.getColumnsInfo(db); err != nil {
   198  		return workload.QueryLoad{}, err
   199  	}
   200  
   201  	if err = w.populateSamples(ctx, db); err != nil {
   202  		return workload.QueryLoad{}, err
   203  	}
   204  
   205  	// TODO(yuzefovich): implement round-robin similar to the one in
   206  	// workload/driver.go.
   207  	if len(urls) != 1 {
   208  		return workload.QueryLoad{}, errors.Errorf(
   209  			"Exactly one connection string is supported at the moment.")
   210  	}
   211  	connCfg, err := pgx.ParseConnectionString(urls[0])
   212  	if err != nil {
   213  		return workload.QueryLoad{}, err
   214  	}
   215  	ql := workload.QueryLoad{SQLDatabase: sqlDatabase}
   216  	if w.querybenchPath != `` {
   217  		conn, err := pgx.Connect(connCfg)
   218  		if err != nil {
   219  			return workload.QueryLoad{}, err
   220  		}
   221  		worker := newQuerybenchWorker(w, reg, conn, 0 /* id */)
   222  		ql.WorkerFns = append(ql.WorkerFns, worker.querybenchRun)
   223  		return ql, nil
   224  	}
   225  	for i := 0; i < w.connFlags.Concurrency; i++ {
   226  		conn, err := pgx.Connect(connCfg)
   227  		if err != nil {
   228  			return workload.QueryLoad{}, err
   229  		}
   230  		worker := newWorker(w, reg, conn, i)
   231  		ql.WorkerFns = append(ql.WorkerFns, worker.run)
   232  	}
   233  	return ql, nil
   234  }
   235  
// worker repeatedly chooses queries from the parsed query log and fills in
// their placeholders. Regular workers execute the queries; workers created by
// newQuerybenchWorker carry a querybenchPath (presumably consumed by
// querybenchRun, which is not shown here).
type worker struct {
	// config is the generator whose flags and parsed state the worker uses.
	config *querylog
	// hists records the latency of successfully executed queries.
	hists *histogram.Histograms
	// We're using pgx.Conn to make sure that statement_timeout variable (when
	// applicable) is properly set on every connection.
	conn *pgx.Conn
	// id identifies the worker; newWorker also uses it to offset the seed of
	// rng.
	id int
	// rng is this worker's own random number generator.
	rng *rand.Rand

	// reWriteQuery is compiled from regexWriteQueryPattern and is used to
	// extract the column list and VALUES tuples of INSERT/UPSERT queries.
	reWriteQuery *regexp.Regexp

	// querybenchPath is set only for workers created by newQuerybenchWorker.
	querybenchPath string
}
   249  
   250  func newWorker(q *querylog, reg *histogram.Registry, conn *pgx.Conn, id int) *worker {
   251  	return &worker{
   252  		config:       q,
   253  		hists:        reg.GetHandle(),
   254  		conn:         conn,
   255  		id:           id,
   256  		rng:          rand.New(rand.NewSource(q.seed + int64(id))),
   257  		reWriteQuery: regexp.MustCompile(regexWriteQueryPattern),
   258  	}
   259  }
   260  
   261  func newQuerybenchWorker(q *querylog, reg *histogram.Registry, conn *pgx.Conn, id int) *worker {
   262  	w := newWorker(q, reg, conn, id)
   263  	w.querybenchPath = q.querybenchPath
   264  	return w
   265  }
   266  
   267  // run is the main function of the worker in which it chooses a query, attempts
   268  // to generate values for the placeholders, and executes the query. Most errors
   269  // (if any occur) are "swallowed" and do not stop the worker. The worker will
   270  // only stop if the workload will tell it so (upon reaching the desired
   271  // duration).
   272  func (w *worker) run(ctx context.Context) error {
   273  	if w.config.stmtTimeoutSeconds != 0 {
   274  		if _, err := w.conn.Exec(fmt.Sprintf("SET statement_timeout='%ds'", w.config.stmtTimeoutSeconds)); err != nil {
   275  			return err
   276  		}
   277  	}
   278  
   279  	var start time.Time
   280  	for {
   281  		chosenQuery, tableName := w.chooseQuery()
   282  		pholdersColumnNames, numRepeats, err := w.deduceColumnNamesForPlaceholders(ctx, chosenQuery)
   283  		if err != nil {
   284  			if w.config.verbose {
   285  				log.Infof(ctx, "Encountered an error %s while deducing column names corresponding to the placeholders", err.Error())
   286  				printQueryShortened(ctx, chosenQuery)
   287  			}
   288  			continue
   289  		}
   290  
   291  		placeholders, err := w.generatePlaceholders(ctx, chosenQuery, pholdersColumnNames, numRepeats, tableName)
   292  		if err != nil {
   293  			if w.config.verbose {
   294  				log.Infof(ctx, "Encountered an error %s while generating values for the placeholders", err.Error())
   295  				printQueryShortened(ctx, chosenQuery)
   296  			}
   297  			continue
   298  		}
   299  
   300  		start = timeutil.Now()
   301  		rows, err := w.conn.Query(chosenQuery, placeholders...)
   302  		if err != nil {
   303  			if w.config.verbose {
   304  				log.Infof(ctx, "Encountered an error %s while executing the query", err.Error())
   305  				printQueryShortened(ctx, chosenQuery)
   306  			}
   307  			continue
   308  		}
   309  		// TODO(yuzefovich): do we care about the returned rows?
   310  		// Iterate over all rows to simulate regular behavior.
   311  		for rows.Next() {
   312  		}
   313  		rows.Close()
   314  		elapsed := timeutil.Since(start)
   315  		// TODO(yuzefovich): is there a better way to display the results?
   316  		w.hists.Get("").Record(elapsed)
   317  	}
   318  }
   319  
   320  // chooseQuery chooses a random query found in the query log. The queries are
   321  // chosen proportionally likely to the frequency of their occurrence in the
   322  // query log.
   323  func (w *worker) chooseQuery() (chosenQuery string, tableName string) {
   324  	prob := w.rng.Float64()
   325  	count := 0
   326  	for tableIdx, tableName := range w.config.state.tableNames {
   327  		if probInInterval(count, w.config.state.queryCountPerTable[tableIdx], w.config.state.totalQueryCount, prob) {
   328  			countForThisTable := 0
   329  			totalForThisTable := w.config.state.queryCountPerTable[tableIdx]
   330  			for query, frequency := range w.config.state.seenQueriesByTableName[tableIdx] {
   331  				if probInInterval(countForThisTable, frequency, totalForThisTable, prob) {
   332  					return query, tableName
   333  				}
   334  				countForThisTable += frequency
   335  			}
   336  		}
   337  		count += w.config.state.queryCountPerTable[tableIdx]
   338  	}
   339  	// We should've chosen a query in the loop. Just in case we encountered some
   340  	// very unlikely rounding errors, let's return some query here.
   341  	for tableIdx, tableName := range w.config.state.tableNames {
   342  		if w.config.state.queryCountPerTable[tableIdx] > 0 {
   343  			for query := range w.config.state.seenQueriesByTableName[tableIdx] {
   344  				return query, tableName
   345  			}
   346  		}
   347  	}
   348  	panic("no queries were accumulated from the log")
   349  }
   350  
   351  // deduceColumnNamesForPlaceholders attempts to deduce the names of
   352  // corresponding to placeholders columns. numRepeats indicates how many times
   353  // that the same "schema" of columns should be repeated when generating the
   354  // values for placeholders - it is 1 for read queries and can be more than 1
   355  // for write queries.
   356  func (w *worker) deduceColumnNamesForPlaceholders(
   357  	ctx context.Context, query string,
   358  ) (pholdersColumnNames []string, numRepeats int, err error) {
   359  	if isInsertOrUpsert(query) {
   360  		// Write query is chosen. We assume that the format is:
   361  		// INSERT/UPSERT INTO table(col1, col2, ..., coln) VALUES ($1, $2, ..., $n), ($n+1, $n+2, ..., $n+n), ... ($kn+1, $kn+2, ..., $kn+n)
   362  		if !w.reWriteQuery.Match([]byte(query)) {
   363  			return nil, 0, errors.Errorf("Chosen write query didn't match the pattern.")
   364  		}
   365  		submatch := w.reWriteQuery.FindSubmatch([]byte(query))
   366  		columnsIdx := 3
   367  		columnsNames := string(submatch[columnsIdx])
   368  		pholdersColumnNames = strings.FieldsFunc(columnsNames, func(r rune) bool {
   369  			return r == ' ' || r == ','
   370  		})
   371  		lastPholderIdx := strings.LastIndex(query, "$")
   372  		if lastPholderIdx == -1 {
   373  			return nil, 0, errors.Errorf("Unexpected: no placeholders in the write query.")
   374  		}
   375  		multipleValuesGroupIdx := 4
   376  		// numRepeats should equal to the number of value tuples in the query, so
   377  		// we calculate it as the number of '(' seen in the multiple values group
   378  		// in the regular expression (which can be any non-negative integer) plus 1
   379  		// (for the last tuple that must be present).
   380  		numRepeats = 1 + strings.Count(string(submatch[multipleValuesGroupIdx]), "(")
   381  		return pholdersColumnNames, numRepeats, nil
   382  	}
   383  
   384  	// Read query is chosen. We're making best effort to deduce the column name,
   385  	// namely we assume that all placeholders are used in an expression as follows
   386  	// `col <> $1`, i.e. a column name followed by a sign followed by the
   387  	// placeholder.
   388  	pholdersColumnNames = make([]string, 0)
   389  	for i := 1; ; i++ {
   390  		pholder := fmt.Sprintf("$%d", i)
   391  		pholderIdx := strings.Index(query, pholder)
   392  		if pholderIdx == -1 {
   393  			// We have gone over all the placeholders present in the query.
   394  			break
   395  		}
   396  		tokens := strings.Fields(query[:pholderIdx])
   397  		if len(tokens) < 2 {
   398  			return nil, 0, errors.Errorf("assumption that there are at least two tokens before placeholder is wrong")
   399  		}
   400  		column := tokens[len(tokens)-2]
   401  		for column[0] == '(' {
   402  			column = column[1:]
   403  		}
   404  		pholdersColumnNames = append(pholdersColumnNames, column)
   405  	}
   406  	return pholdersColumnNames, 1 /* numRepeats */, nil
   407  }
   408  
   409  // generatePlaceholders populates the values for the placeholders. It utilizes
   410  // two strategies:
   411  // 1. select random samples based on the samples stored in columnInfos.
   412  // 2. generate random values based on the types of the corresponding columns.
   413  // Note: if a deduced column name corresponding to a placeholder is not among
   414  // the columns of the table, we assume that it is of INT type and generate an
   415  // int in [1, 10] range.
   416  func (w *worker) generatePlaceholders(
   417  	ctx context.Context, query string, pholdersColumnNames []string, numRepeats int, tableName string,
   418  ) (placeholders []interface{}, err error) {
   419  	isWriteQuery := isInsertOrUpsert(query)
   420  	placeholders = make([]interface{}, 0, len(pholdersColumnNames)*numRepeats)
   421  	for j := 0; j < numRepeats; j++ {
   422  		for i, column := range pholdersColumnNames {
   423  			columnMatched := false
   424  			actualTableName := true
   425  			if strings.Contains(column, ".") {
   426  				actualTableName = false
   427  				column = strings.Split(column, ".")[1]
   428  			}
   429  			possibleTableNames := make([]string, 0, 1)
   430  			possibleTableNames = append(possibleTableNames, tableName)
   431  			if !actualTableName {
   432  				// column comes from a table that was aliased. In order to not parse
   433  				// the query to figure out actual table name, we simply compare to
   434  				// columns of all used tables giving priority to the table that the
   435  				// query is assigned to.
   436  				for _, n := range w.config.state.tableNames {
   437  					if n == tableName || !w.config.state.tableUsed[n] {
   438  						continue
   439  					}
   440  					possibleTableNames = append(possibleTableNames, n)
   441  				}
   442  			}
   443  			for _, tableName := range possibleTableNames {
   444  				if columnMatched {
   445  					break
   446  				}
   447  				for _, c := range w.config.state.columnsByTableName[tableName] {
   448  					if c.name == column {
   449  						if w.rng.Float64() < w.config.probOfUsingSamples && c.samples != nil && !isWriteQuery {
   450  							// For non-write queries when samples are present, we're using
   451  							// the samples with w.config.probOfUsingSamples probability.
   452  							sampleIdx := w.rng.Intn(len(c.samples))
   453  							placeholders = append(placeholders, c.samples[sampleIdx])
   454  						} else {
   455  							// In all other cases, we generate random values for the
   456  							// placeholders.
   457  							nullPct := 0
   458  							if c.isNullable && w.config.nullPct > 0 {
   459  								nullPct = 100 / w.config.nullPct
   460  							}
   461  							d := sqlbase.RandDatumWithNullChance(w.rng, c.dataType, nullPct)
   462  							if i, ok := d.(*tree.DInt); ok && c.intRange > 0 {
   463  								j := int64(*i) % int64(c.intRange/2)
   464  								d = tree.NewDInt(tree.DInt(j))
   465  							}
   466  							p, err := workloadrand.DatumToGoSQL(d)
   467  							if err != nil {
   468  								return nil, err
   469  							}
   470  							placeholders = append(placeholders, p)
   471  						}
   472  						columnMatched = true
   473  						break
   474  					}
   475  				}
   476  			}
   477  			if !columnMatched {
   478  				d := w.rng.Int31n(10) + 1
   479  				if w.config.verbose {
   480  					log.Infof(ctx, "Couldn't deduce the corresponding to $%d, so generated %d (a small int)", i+1, d)
   481  					printQueryShortened(ctx, query)
   482  				}
   483  				p, err := workloadrand.DatumToGoSQL(tree.NewDInt(tree.DInt(d)))
   484  				if err != nil {
   485  					return nil, err
   486  				}
   487  				placeholders = append(placeholders, p)
   488  			}
   489  		}
   490  	}
   491  	return placeholders, nil
   492  }
   493  
   494  // getTableNames fetches the names of all the tables in db and stores them in
   495  // w.state.
   496  func (w *querylog) getTableNames(db *gosql.DB) error {
   497  	rows, err := db.Query(`SELECT table_name FROM [SHOW TABLES] ORDER BY table_name`)
   498  	if err != nil {
   499  		return err
   500  	}
   501  	defer rows.Close()
   502  	w.state.tableNames = make([]string, 0)
   503  	for rows.Next() {
   504  		var tableName string
   505  		if err = rows.Scan(&tableName); err != nil {
   506  			return err
   507  		}
   508  		w.state.tableNames = append(w.state.tableNames, tableName)
   509  	}
   510  	return nil
   511  }
   512  
   513  // unzip unzips the zip file at src into dest. It was copied (with slight
   514  // modifications) from
   515  // https://stackoverflow.com/questions/20357223/easy-way-to-unzip-file-with-golang.
   516  func unzip(src, dest string) error {
   517  	r, err := zip.OpenReader(src)
   518  	if err != nil {
   519  		return err
   520  	}
   521  	defer func() {
   522  		if err := r.Close(); err != nil {
   523  			panic(err)
   524  		}
   525  	}()
   526  
   527  	if err = os.MkdirAll(dest, 0755); err != nil {
   528  		return err
   529  	}
   530  
   531  	// Closure to address file descriptors issue with all the deferred .Close()
   532  	// methods.
   533  	extractAndWriteFile := func(f *zip.File) error {
   534  		rc, err := f.Open()
   535  		if err != nil {
   536  			return err
   537  		}
   538  		defer func() {
   539  			if err := rc.Close(); err != nil {
   540  				panic(err)
   541  			}
   542  		}()
   543  
   544  		path := filepath.Join(dest, f.Name)
   545  		// Check for ZipSlip. More Info: http://bit.ly/2MsjAWE
   546  		if !strings.HasPrefix(path, filepath.Clean(dest)+string(os.PathSeparator)) {
   547  			return errors.Errorf("%s: illegal file path while extracting the zip. "+
   548  				"Such a file path can be dangerous because of ZipSlip vulnerability. "+
   549  				"Please reconsider whether the zip file is trustworthy.", path)
   550  		}
   551  
   552  		if f.FileInfo().IsDir() {
   553  			if err = os.MkdirAll(path, f.Mode()); err != nil {
   554  				return err
   555  			}
   556  		} else {
   557  			if err = os.MkdirAll(filepath.Dir(path), f.Mode()); err != nil {
   558  				return err
   559  			}
   560  			f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
   561  			if err != nil {
   562  				return err
   563  			}
   564  			defer func() {
   565  				if err := f.Close(); err != nil {
   566  					panic(err)
   567  				}
   568  			}()
   569  
   570  			_, err = io.Copy(f, rc)
   571  			if err != nil {
   572  				return err
   573  			}
   574  		}
   575  		return nil
   576  	}
   577  
   578  	for _, f := range r.File {
   579  		err := extractAndWriteFile(f)
   580  		if err != nil {
   581  			return err
   582  		}
   583  	}
   584  
   585  	return nil
   586  }
   587  
const (
	// regexQueryLogFormat matches a single (non-header) line of a query log
	// file:
	// omitted prefix + "(query)" + \s + (placeholders) + omitted suffix
	// Group 1 captures the query (without the surrounding quotes) and group 2
	// the placeholder values (including the braces).
	regexQueryLogFormat = `[^\s]+\s[^\s]+\s[^\s]+\s[^\s]+\s{2}[^\s]+\s[^\s]+\s[^\s]+\s".*"\s\{.*\}\s` +
		`"(.*)"` + `\s` + `(\{.*\})` + `\s[^\s]+\s[\d]+\s[^\s]+`

	// regexQueryGroupIdx is the index of the query group in
	// regexQueryLogFormat.
	regexQueryGroupIdx = 1
	//regexPlaceholdersGroupIdx = 2
	// queryLogHeaderLines is the number of lines at the start of every query
	// log file that describe the format of the file; parseFile skips them.
	queryLogHeaderLines = 5

	// regexWriteQueryPattern matches INSERT/UPSERT queries of the form
	// `INSERT INTO t(c1, ..., cn) VALUES (...), ..., (...)`; note that no
	// whitespace is allowed between the table name and the column list.
	// Groups: 1 - the statement kind, 2 - the table name, 3 - the column list,
	// 4 - the VALUES tuples preceding the last one (a repeated group; because
	// `.*` is greedy a single repetition normally spans all of them),
	// 5 - the last VALUES tuple.
	regexWriteQueryPattern = `(UPSERT|INSERT)\sINTO\s([^\s]+)\((.*)\)\sVALUES\s(\(.*\)\,\s)*(\(.*\))`
)
   599  
   600  // parseFile parses the file one line at a time. First queryLogHeaderLines are
   601  // skipped, and all other lines are expected to match the pattern of re.
   602  func (w *querylog) parseFile(ctx context.Context, fileInfo os.FileInfo, re *regexp.Regexp) error {
   603  	start := timeutil.Now()
   604  	file, err := os.Open(w.dirPath + "/" + fileInfo.Name())
   605  	if err != nil {
   606  		return err
   607  	}
   608  
   609  	reader := bufio.NewReader(file)
   610  	buffer := make([]byte, 0)
   611  	lineCount := 0
   612  	for {
   613  		line, isPrefix, err := reader.ReadLine()
   614  		if err != nil {
   615  			if err != io.EOF {
   616  				return err
   617  			}
   618  			// We reached EOF, so we're done with this file.
   619  			end := timeutil.Now()
   620  			elapsed := end.Sub(start)
   621  			log.Infof(ctx, "Processing of %s is done in %fs", fileInfo.Name(), elapsed.Seconds())
   622  			break
   623  		}
   624  
   625  		buffer = append(buffer, line...)
   626  		if !isPrefix {
   627  			// We read the full line, so let's process it.
   628  			if lineCount >= queryLogHeaderLines {
   629  				// First queryLogHeaderLines lines of the describe the format of the log
   630  				// file, so we skip them.
   631  				if !re.Match(buffer) {
   632  					return errors.Errorf("Line %d doesn't match the pattern", lineCount)
   633  				}
   634  				// regexQueryLogFormat contains two regexp groups (patterns that are
   635  				// in grouping parenthesis) - the query (without quotes) and actual
   636  				// placeholder values. At the moment, we're only interested in the
   637  				// former.
   638  				groups := re.FindSubmatch(buffer)
   639  				query := string(groups[regexQueryGroupIdx])
   640  				isWriteQuery := isInsertOrUpsert(query)
   641  				skipQuery := strings.HasPrefix(query, "ALTER") ||
   642  					(w.omitWriteQueries && isWriteQuery) ||
   643  					(w.omitDeleteQueries && strings.HasPrefix(query, "DELETE"))
   644  				if !skipQuery {
   645  					tableAssigned := false
   646  					for i, tableName := range w.state.tableNames {
   647  						// TODO(yuzefovich): this simplistic matching doesn't work in all
   648  						// cases.
   649  						if strings.Contains(query, tableName) {
   650  							if !tableAssigned {
   651  								// The query is assigned to one table, namely the one that
   652  								// showed up earlier in `SHOW TABLES`.
   653  								w.state.seenQueriesByTableName[i][query]++
   654  								w.state.queryCountPerTable[i]++
   655  								w.state.totalQueryCount++
   656  								tableAssigned = true
   657  							}
   658  							w.state.tableUsed[tableName] = true
   659  						}
   660  					}
   661  					if !tableAssigned {
   662  						return errors.Errorf("No table matched query %s while processing %s", query, fileInfo.Name())
   663  					}
   664  				}
   665  			}
   666  			buffer = buffer[:0]
   667  			lineCount++
   668  		}
   669  	}
   670  	return file.Close()
   671  }
   672  
   673  // processQueryLog parses the query log and populates most of w.state.
   674  func (w *querylog) processQueryLog(ctx context.Context) error {
   675  	if w.zipPath != "" {
   676  		log.Infof(ctx, "About to start unzipping %s", w.zipPath)
   677  		w.dirPath = w.zipPath[:len(w.zipPath)-4]
   678  		if err := unzip(w.zipPath, w.dirPath); err != nil {
   679  			return err
   680  		}
   681  		log.Infof(ctx, "Unzipping to %s is complete", w.dirPath)
   682  	}
   683  
   684  	files, err := ioutil.ReadDir(w.dirPath)
   685  	if err != nil {
   686  		return err
   687  	}
   688  
   689  	w.state.queryCountPerTable = make([]int, len(w.state.tableNames))
   690  	w.state.seenQueriesByTableName = make([]map[string]int, len(w.state.tableNames))
   691  	for i := range w.state.tableNames {
   692  		w.state.seenQueriesByTableName[i] = make(map[string]int)
   693  	}
   694  	w.state.tableUsed = make(map[string]bool)
   695  
   696  	re := regexp.MustCompile(regexQueryLogFormat)
   697  	log.Infof(ctx, "Starting to parse the query log")
   698  	numFiles := w.filesToParse
   699  	if numFiles > len(files) {
   700  		numFiles = len(files)
   701  	}
   702  	for fileNum, fileInfo := range files {
   703  		if fileNum == w.filesToParse {
   704  			// We have reached the desired number of files to parse, so we don't
   705  			// parse the remaining files.
   706  			break
   707  		}
   708  		if fileInfo.IsDir() {
   709  			if w.verbose {
   710  				log.Infof(ctx, "Unexpected: a directory %s is encountered with the query log, skipping it.", fileInfo.Name())
   711  			}
   712  			continue
   713  		}
   714  		log.Infof(ctx, "Processing %d out of %d", fileNum, numFiles)
   715  		if err = w.parseFile(ctx, fileInfo, re); err != nil {
   716  			return err
   717  		}
   718  	}
   719  	log.Infof(ctx, "Query log processed")
   720  	if w.zipPath != "" {
   721  		log.Infof(ctx, "Unzipped files are about to be removed")
   722  		return os.RemoveAll(w.dirPath)
   723  	}
   724  	return nil
   725  }
   726  
   727  // getColumnsInfo populates the information about the columns of the tables
   728  // that at least one query was issued against.
   729  func (w *querylog) getColumnsInfo(db *gosql.DB) error {
   730  	w.state.columnsByTableName = make(map[string][]columnInfo)
   731  	for _, tableName := range w.state.tableNames {
   732  		if !w.state.tableUsed[tableName] {
   733  			// There were no queries operating on this table, so no query will be
   734  			// generated against this table as well, and we don't need the
   735  			// information about the columns.
   736  			continue
   737  		}
   738  
   739  		// columnTypeByColumnName is used only to distinguish between
   740  		// INT2/INT4/INT8 because otherwise they are mapped to the same INT type.
   741  		columnTypeByColumnName := make(map[string]string)
   742  		rows, err := db.Query(fmt.Sprintf("SELECT column_name, data_type FROM [SHOW COLUMNS FROM %s]", tableName))
   743  		if err != nil {
   744  			return err
   745  		}
   746  		for rows.Next() {
   747  			var columnName, dataType string
   748  			if err = rows.Scan(&columnName, &dataType); err != nil {
   749  				return err
   750  			}
   751  			columnTypeByColumnName[columnName] = dataType
   752  		}
   753  		if err = rows.Close(); err != nil {
   754  			return err
   755  		}
   756  
   757  		// This schema introspection was copied from workload/rand.go (with slight
   758  		// modifications).
   759  		// TODO(yuzefovich): probably we need to extract it.
   760  		var relid int
   761  		if err := db.QueryRow(fmt.Sprintf("SELECT '%s'::REGCLASS::OID", tableName)).Scan(&relid); err != nil {
   762  			return err
   763  		}
   764  		rows, err = db.Query(
   765  			`
   766  SELECT attname, atttypid, adsrc, NOT attnotnull
   767  FROM pg_catalog.pg_attribute
   768  LEFT JOIN pg_catalog.pg_attrdef
   769  ON attrelid=adrelid AND attnum=adnum
   770  WHERE attrelid=$1`, relid)
   771  		if err != nil {
   772  			return err
   773  		}
   774  
   775  		var cols []columnInfo
   776  		var numCols = 0
   777  
   778  		defer rows.Close()
   779  		for rows.Next() {
   780  			var c columnInfo
   781  			c.dataPrecision = 0
   782  			c.dataScale = 0
   783  
   784  			var typOid int
   785  			if err := rows.Scan(&c.name, &typOid, &c.cdefault, &c.isNullable); err != nil {
   786  				return err
   787  			}
   788  			c.dataType = types.OidToType[oid.Oid(typOid)]
   789  			if c.dataType.Family() == types.IntFamily {
   790  				actualType := columnTypeByColumnName[c.name]
   791  				if actualType == `INT2` {
   792  					c.intRange = 1 << 16
   793  				} else if actualType == `INT4` {
   794  					c.intRange = 1 << 32
   795  				}
   796  			}
   797  			cols = append(cols, c)
   798  			numCols++
   799  		}
   800  
   801  		if numCols == 0 {
   802  			return errors.Errorf("no columns detected")
   803  		}
   804  		w.state.columnsByTableName[tableName] = cols
   805  	}
   806  	return nil
   807  }
   808  
   809  // populateSamples selects at most w.numSamples of samples from each table that
   810  // at least one query was issued against the query log. The samples are stored
   811  // inside corresponding to the table columnInfo.
   812  func (w *querylog) populateSamples(ctx context.Context, db *gosql.DB) error {
   813  	log.Infof(ctx, "Populating samples started")
   814  	for _, tableName := range w.state.tableNames {
   815  		cols := w.state.columnsByTableName[tableName]
   816  		if cols == nil {
   817  			// There were no queries touching this table, so we skip it.
   818  			continue
   819  		}
   820  		log.Infof(ctx, "Sampling %s", tableName)
   821  		count, err := db.Query(fmt.Sprintf(`SELECT count(*) FROM %s`, tableName))
   822  		if err != nil {
   823  			return err
   824  		}
   825  		count.Next()
   826  		var numRows int
   827  		if err = count.Scan(&numRows); err != nil {
   828  			return err
   829  		}
   830  		if err = count.Close(); err != nil {
   831  			return err
   832  		}
   833  
   834  		// To ensure that samples correspond to the appropriate columns, we fix
   835  		// the order of columns (same to the order in `SHOW COLUMNS` query in
   836  		// getColumnInfo).
   837  		columnNames := make([]string, len(cols))
   838  		for i := range cols {
   839  			columnNames[i] = cols[i].name
   840  		}
   841  		columnsOrdered := strings.Join(columnNames, ", ")
   842  
   843  		var samples *gosql.Rows
   844  		if w.numSamples > numRows {
   845  			samples, err = db.Query(
   846  				fmt.Sprintf(`SELECT %s FROM %s`, columnsOrdered, tableName))
   847  		} else {
   848  			samplingProb := float64(w.numSamples) / float64(numRows)
   849  			if samplingProb < w.minSamplingProb {
   850  				// To speed up the query.
   851  				samplingProb = w.minSamplingProb
   852  			}
   853  			samples, err = db.Query(fmt.Sprintf(`SELECT %s FROM %s WHERE random() < %f LIMIT %d`,
   854  				columnsOrdered, tableName, samplingProb, w.numSamples))
   855  		}
   856  
   857  		if err != nil {
   858  			return err
   859  		}
   860  		for samples.Next() {
   861  			rowOfSamples := make([]interface{}, len(cols))
   862  			for i := range rowOfSamples {
   863  				rowOfSamples[i] = new(interface{})
   864  			}
   865  			if err := samples.Scan(rowOfSamples...); err != nil {
   866  				return err
   867  			}
   868  			for i, sample := range rowOfSamples {
   869  				cols[i].samples = append(cols[i].samples, sample)
   870  			}
   871  		}
   872  		if err = samples.Close(); err != nil {
   873  			return err
   874  		}
   875  	}
   876  	log.Infof(ctx, "Populating samples is complete")
   877  	return nil
   878  }
   879  
   880  // querybenchRun is the main function of a querybench worker. It is run only
   881  // when querylog is used in querybench mode (i.e. querybench-path argument was
   882  // specified). It chooses a query, attempts to generate values for the
   883  // placeholders, and - instead of executing the query - writes it into the
   884  // requested file in format that querybench understands.
   885  func (w *worker) querybenchRun(ctx context.Context) error {
   886  	file, err := os.Create(w.querybenchPath)
   887  	if err != nil {
   888  		return err
   889  	}
   890  	defer file.Close()
   891  
   892  	// We will skip all queries for which the placeholder values contain `$1`
   893  	// and alike to make it easier to replace actual placeholders.
   894  	reToAvoid := regexp.MustCompile(`\$[0-9]+`)
   895  	writer := bufio.NewWriter(file)
   896  	queryCount := uint(0)
   897  	for {
   898  		chosenQuery, tableName := w.chooseQuery()
   899  		pholdersColumnNames, numRepeats, err := w.deduceColumnNamesForPlaceholders(ctx, chosenQuery)
   900  		if err != nil {
   901  			if w.config.verbose {
   902  				log.Infof(ctx, "Encountered an error %s while deducing column names corresponding to the placeholders", err.Error())
   903  				printQueryShortened(ctx, chosenQuery)
   904  			}
   905  			continue
   906  		}
   907  
   908  		placeholders, err := w.generatePlaceholders(ctx, chosenQuery, pholdersColumnNames, numRepeats, tableName)
   909  		if err != nil {
   910  			if w.config.verbose {
   911  				log.Infof(ctx, "Encountered an error %s while generating values for the placeholders", err.Error())
   912  				printQueryShortened(ctx, chosenQuery)
   913  			}
   914  			continue
   915  		}
   916  
   917  		query := chosenQuery
   918  		skipQuery := false
   919  		// We're iterating over placeholders in reverse order so that `$1` does not
   920  		// replace the first two characters of `$10`.
   921  		for i := len(placeholders) - 1; i >= 0; i-- {
   922  			pholderString := fmt.Sprintf("$%d", i+1)
   923  			if !strings.Contains(chosenQuery, pholderString) {
   924  				skipQuery = true
   925  				break
   926  			}
   927  			pholderValue := printPlaceholder(placeholders[i])
   928  			if reToAvoid.MatchString(pholderValue) {
   929  				skipQuery = true
   930  				break
   931  			}
   932  			query = strings.Replace(query, pholderString, pholderValue, -1)
   933  		}
   934  		if skipQuery {
   935  			if w.config.verbose {
   936  				log.Infof(ctx, "Could not replace placeholders with values on query")
   937  				printQueryShortened(ctx, chosenQuery)
   938  			}
   939  			continue
   940  		}
   941  		if _, err = writer.WriteString(query + "\n\n"); err != nil {
   942  			return err
   943  		}
   944  
   945  		queryCount++
   946  		if queryCount%250 == 0 {
   947  			log.Infof(ctx, "%d queries have been written", queryCount)
   948  		}
   949  		if queryCount == w.config.count {
   950  			writer.Flush()
   951  			return nil
   952  		}
   953  	}
   954  }
   955  
   956  func printPlaceholder(i interface{}) string {
   957  	if ptr, ok := i.(*interface{}); ok {
   958  		return printPlaceholder(*ptr)
   959  	}
   960  	switch p := i.(type) {
   961  	case bool:
   962  		return fmt.Sprintf("%v", p)
   963  	case int64:
   964  		return fmt.Sprintf("%d", p)
   965  	case float64:
   966  		return fmt.Sprintf("%f", p)
   967  	case []uint8:
   968  		u, err := uuid.FromString(string(p))
   969  		if err != nil {
   970  			panic(err)
   971  		}
   972  		return fmt.Sprintf("'%s'", u.String())
   973  	case uuid.UUID:
   974  		return fmt.Sprintf("'%s'", p.String())
   975  	case string:
   976  		s := strings.Replace(p, "'", "''", -1)
   977  		// querybench assumes that each query is on a single line, so we remove
   978  		// line breaks.
   979  		s = strings.Replace(s, "\n", "", -1)
   980  		s = strings.Replace(s, "\r", "", -1)
   981  		return fmt.Sprintf("'%s'", s)
   982  	case time.Time:
   983  		timestamp := p.String()
   984  		// timestamp can be of the format '1970-01-09 00:14:01.000812453 +0000 UTC'
   985  		// or '2019-02-26 20:52:01.65434 +0000 +0000', and the parser complains
   986  		// that it could not parse it, so we remove all stuff that comes after
   987  		// first +0000 to make the parser happy.
   988  		idx := strings.Index(timestamp, `+0000`)
   989  		timestamp = timestamp[:idx+5]
   990  		return fmt.Sprintf("'%s':::TIMESTAMP", timestamp)
   991  	case nil:
   992  		return fmt.Sprintf("NULL")
   993  	default:
   994  		panic(fmt.Sprintf("unsupported type: %T", i))
   995  	}
   996  }
   997  
// columnInfo describes one column of a table that at least one query in the
// log was issued against, together with values sampled from that column.
//
// TODO(yuzefovich): columnInfo is copied from workload/rand package and
// extended. Should we export workload/rand.col?
type columnInfo struct {
	// name is the column name (pg_attribute.attname).
	name          string
	// dataType is the column type looked up by the column's type OID.
	dataType      *types.T
	// dataPrecision and dataScale are currently always left at 0 by
	// getColumnsInfo.
	dataPrecision int
	dataScale     int
	// cdefault is the column's default expression (pg_attrdef.adsrc); it is
	// NULL when the column has no default.
	cdefault      gosql.NullString
	// isNullable is true when the column does not have a NOT NULL constraint.
	isNullable    bool

	// intRange is 1<<16 for INT2 and 1<<32 for INT4 columns; getColumnsInfo
	// leaves it at 0 otherwise.
	intRange uint64 // To distinguish between INT2, INT4, and INT8.
	// samples holds values read from the table by populateSamples, one per
	// sampled row, each stored as a *interface{}.
	samples  []interface{}
}
  1011  
  1012  func printQueryShortened(ctx context.Context, query string) {
  1013  	if len(query) > 1000 {
  1014  		log.Infof(ctx, "%s...%s", query[:500], query[len(query)-500:])
  1015  	} else {
  1016  		log.Infof(ctx, "%s", query)
  1017  	}
  1018  }
  1019  
  1020  func isInsertOrUpsert(query string) bool {
  1021  	return strings.HasPrefix(query, "INSERT") || strings.HasPrefix(query, "UPSERT")
  1022  }
  1023  
  1024  func probInInterval(start, inc, total int, p float64) bool {
  1025  	return float64(start)/float64(total) <= p &&
  1026  		float64(start+inc)/float64(total) > p
  1027  }