
     1  // Copyright 2016 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     5  // Package db provides the high-level database interface for the
     6  // storage app.
     7  package db
     9  import (
    10  	"bytes"
    11  	"database/sql"
    12  	"fmt"
    13  	"io"
    14  	"regexp"
    15  	"sort"
    16  	"strconv"
    17  	"strings"
    18  	"text/template"
    19  	"time"
    21  	""
    22  	""
    23  	""
    24  	""
    25  )
    27  // TODO(quentin): Add Context to every function when App Engine supports Go >=1.8.
// DB is a high-level interface to a database for the storage
// app. It's safe for concurrent use by multiple goroutines.
//
// A DB is created by OpenSQL, which fills in the prepared statements
// below via prepareStatements; Close releases them along with the
// underlying connection.
type DB struct {
	sql        *sql.DB // underlying database connection
	driverName string  // name of underlying driver for SQL differences
	// prepared statements
	lastUpload    *sql.Stmt // selects the UploadID with the greatest (Day, Seq)
	insertUpload  *sql.Stmt // inserts an Uploads row (UploadID, Day, Seq)
	checkUpload   *sql.Stmt // selects 1 if an UploadID already exists
	deleteRecords *sql.Stmt // deletes the Records rows of one UploadID
}
    41  // OpenSQL creates a DB backed by a SQL database. The parameters are
    42  // the same as the parameters for sql.Open. Only mysql and sqlite3 are
    43  // explicitly supported; other database engines will receive MySQL
    44  // query syntax which may or may not be compatible.
    45  func OpenSQL(driverName, dataSourceName string) (*DB, error) {
    46  	db, err := sql.Open(driverName, dataSourceName)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	if hook := openHooks[driverName]; hook != nil {
    51  		if err := hook(db); err != nil {
    52  			return nil, err
    53  		}
    54  	}
    55  	d := &DB{sql: db, driverName: driverName}
    56  	if err := d.createTables(driverName); err != nil {
    57  		return nil, err
    58  	}
    59  	if err := d.prepareStatements(driverName); err != nil {
    60  		return nil, err
    61  	}
    62  	return d, nil
    63  }
    65  var openHooks = make(map[string]func(*sql.DB) error)
    67  // RegisterOpenHook registers a hook to be called after opening a connection to driverName.
    68  // This is used by the sqlite3 package to register a ConnectHook.
    69  // It must be called from an init function.
    70  func RegisterOpenHook(driverName string, hook func(*sql.DB) error) {
    71  	openHooks[driverName] = hook
    72  }
// createTmpl is the template used to prepare the CREATE statements
// for the database. It is evaluated with . as a map containing one
// entry whose key is the driver name. So {{if .sqlite3}} selects the
// SQLite form (separate CREATE INDEX statements), and
// {{if not .sqlite3}} selects the inline "Index (...)" clauses.
// createTables splits the rendered text on ";", so the template must
// not contain a semicolon except as a statement terminator.
//
// NOTE(review): as visible here, the template appears to be missing lines.
// There is no "CREATE TABLE ..." header for Uploads or Records, and no
// Seq or RecordID column definitions, although prepareStatements inserts
// Uploads.Seq and the primary keys reference RecordID. The trailing ","
// before ");" in the Records and RecordLabels bodies also suggests that
// a final clause was dropped. Confirm against version control.
var createTmpl = template.Must(template.New("create").Parse(`
	UploadID VARCHAR(20) PRIMARY KEY,
	Day VARCHAR(8),
{{if not .sqlite3}}
	, Index (Day, Seq)
{{end}}
);
{{if .sqlite3}}
CREATE INDEX IF NOT EXISTS UploadDaySeq ON Uploads(Day, Seq);
{{end}}
	UploadID VARCHAR(20) NOT NULL,
	Content BLOB NOT NULL,
	PRIMARY KEY (UploadID, RecordID),
);
CREATE TABLE IF NOT EXISTS RecordLabels (
	UploadID VARCHAR(20) NOT NULL,
	Name VARCHAR(255) NOT NULL,
	Value VARCHAR(8192) NOT NULL,
{{if not .sqlite3}}
	Index (Name(100), Value(100)),
{{end}}
	PRIMARY KEY (UploadID, RecordID, Name),
);
{{if .sqlite3}}
CREATE INDEX IF NOT EXISTS RecordLabelsNameValue ON RecordLabels(Name, Value);
{{end}}
`))
   112  // createTables creates any missing tables on the connection in
   113  // db.sql. driverName is the same driver name passed to sql.Open and
   114  // is used to select the correct syntax.
   115  func (db *DB) createTables(driverName string) error {
   116  	var buf bytes.Buffer
   117  	if err := createTmpl.Execute(&buf, map[string]bool{driverName: true}); err != nil {
   118  		return err
   119  	}
   120  	for _, q := range strings.Split(buf.String(), ";") {
   121  		if strings.TrimSpace(q) == "" {
   122  			continue
   123  		}
   124  		if _, err := db.sql.Exec(q); err != nil {
   125  			return fmt.Errorf("create table: %v", err)
   126  		}
   127  	}
   128  	return nil
   129  }
   131  // prepareStatements calls db.sql.Prepare on reusable SQL statements.
   132  func (db *DB) prepareStatements(driverName string) error {
   133  	var err error
   134  	query := "SELECT UploadID FROM Uploads ORDER BY Day DESC, Seq DESC LIMIT 1"
   135  	if driverName != "sqlite3" {
   136  		query += " FOR UPDATE"
   137  	}
   138  	db.lastUpload, err = db.sql.Prepare(query)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	db.insertUpload, err = db.sql.Prepare("INSERT INTO Uploads(UploadID, Day, Seq) VALUES (?, ?, ?)")
   143  	if err != nil {
   144  		return err
   145  	}
   146  	db.checkUpload, err = db.sql.Prepare("SELECT 1 FROM Uploads WHERE UploadID = ?")
   147  	if err != nil {
   148  		return err
   149  	}
   150  	db.deleteRecords, err = db.sql.Prepare("DELETE FROM Records WHERE UploadID = ?")
   151  	if err != nil {
   152  		return err
   153  	}
   154  	return nil
   155  }
// An Upload is a collection of files that share an upload ID.
// It is returned by DB.NewUpload or DB.ReplaceUpload, receives results
// through InsertRecord, and is finished with Commit or Abort.
type Upload struct {
	// ID is the value of the "upload" key that should be
	// associated with every record in this upload.
	ID string
	// recordid is the index of the next record to insert.
	recordid int64
	// db is the underlying database that this upload is going to.
	db *DB
	// tx is the transaction used by the upload.
	tx *sql.Tx
	// pending arguments for flush:
	// insertRecordArgs holds flattened (UploadID, RecordID, Content)
	// triples for Records. insertLabelArgs holds flattened (UploadID,
	// RecordID, key, value) quadruples for RecordLabels. lastResult is
	// the result whose row is last in insertRecordArgs; a following
	// result with the same labels is appended to that row. flush resets
	// lastResult to nil.
	insertRecordArgs []interface{}
	insertLabelArgs  []interface{}
	lastResult       *benchfmt.Result
}
   176  // now is a hook for testing
   177  var now = time.Now
   179  // ReplaceUpload removes the records associated with id if any and
   180  // allows insertion of new records.
   181  func (db *DB) ReplaceUpload(id string) (*Upload, error) {
   182  	if _, err := db.deleteRecords.Exec(id); err != nil {
   183  		return nil, err
   184  	}
   185  	var found bool
   186  	err := db.checkUpload.QueryRow(id).Scan(&found)
   187  	switch err {
   188  	case sql.ErrNoRows:
   189  		var day sql.NullString
   190  		var num sql.NullInt64
   191  		if m := regexp.MustCompile(`^(\d+)\.(\d+)$`).FindStringSubmatch(id); m != nil {
   192  			day.Valid, num.Valid = true, true
   193  			day.String = m[1]
   194  			num.Int64, _ = strconv.ParseInt(m[2], 10, 64)
   195  		}
   196  		if _, err := db.insertUpload.Exec(id, day, num); err != nil {
   197  			return nil, err
   198  		}
   199  	case nil:
   200  	default:
   201  		return nil, err
   202  	}
   203  	tx, err := db.sql.Begin()
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	u := &Upload{
   208  		ID: id,
   209  		db: db,
   210  		tx: tx,
   211  	}
   212  	return u, nil
   213  }
   215  // NewUpload returns an upload for storing new files.
   216  // All records written to the Upload will have the same upload ID.
   217  func (db *DB) NewUpload(ctx context.Context) (*Upload, error) {
   218  	day := now().UTC().Format("20060102")
   220  	num := 0
   222  	tx, err := db.sql.Begin()
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	defer func() {
   227  		if tx != nil {
   228  			tx.Rollback()
   229  		}
   230  	}()
   231  	var lastID string
   232  	err = tx.Stmt(db.lastUpload).QueryRow().Scan(&lastID)
   233  	switch err {
   234  	case sql.ErrNoRows:
   235  	case nil:
   236  		if strings.HasPrefix(lastID, day) {
   237  			num, err = strconv.Atoi(lastID[len(day)+1:])
   238  			if err != nil {
   239  				return nil, err
   240  			}
   241  		}
   242  	default:
   243  		return nil, err
   244  	}
   246  	num++
   248  	id := fmt.Sprintf("%s.%d", day, num)
   250  	_, err = tx.Stmt(db.insertUpload).Exec(id, day, num)
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  	if err := tx.Commit(); err != nil {
   255  		return nil, err
   256  	}
   257  	tx = nil
   259  	utx, err := db.sql.Begin()
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	u := &Upload{
   264  		ID: id,
   265  		db: db,
   266  		tx: utx,
   267  	}
   268  	return u, nil
   269  }
   271  // InsertRecord inserts a single record in an existing upload.
   272  // If InsertRecord returns a non-nil error, the Upload has failed and u.Abort() must be called.
   273  func (u *Upload) InsertRecord(r *benchfmt.Result) error {
   274  	if u.lastResult != nil && u.lastResult.SameLabels(r) {
   275  		data := u.insertRecordArgs[len(u.insertRecordArgs)-1].([]byte)
   276  		data = append(data, r.Content...)
   277  		data = append(data, '\n')
   278  		u.insertRecordArgs[len(u.insertRecordArgs)-1] = data
   279  		return nil
   280  	}
   281  	// TODO(quentin): Support multiple lines (slice of results?)
   282  	var buf bytes.Buffer
   283  	if err := benchfmt.NewPrinter(&buf).Print(r); err != nil {
   284  		return err
   285  	}
   286  	u.lastResult = r
   287  	u.insertRecordArgs = append(u.insertRecordArgs, u.ID, u.recordid, buf.Bytes())
   288  	for _, k := range r.Labels.Keys() {
   289  		if err := u.insertLabel(k, r.Labels[k]); err != nil {
   290  			return err
   291  		}
   292  	}
   293  	for _, k := range r.NameLabels.Keys() {
   294  		if err := u.insertLabel(k, r.NameLabels[k]); err != nil {
   295  			return err
   296  		}
   297  	}
   298  	u.recordid++
   300  	return nil
   301  }
   303  // insertLabel queues a label pair for insertion.
   304  // If there are enough labels queued, flush is called.
   305  func (u *Upload) insertLabel(key, value string) error {
   306  	// N.B. sqlite3 has a max of 999 arguments.
   307  	//
   308  	if len(u.insertLabelArgs) >= 990 {
   309  		if err := u.flush(); err != nil {
   310  			return err
   311  		}
   312  	}
   313  	u.insertLabelArgs = append(u.insertLabelArgs, u.ID, u.recordid, key, value)
   314  	return nil
   315  }
   317  // repeatDelim returns a string consisting of n copies of s with delim between each copy.
   318  func repeatDelim(s, delim string, n int) string {
   319  	return strings.TrimSuffix(strings.Repeat(s+delim, n), delim)
   320  }
   322  // insertMultiple executes a single INSERT statement to insert multiple rows.
   323  func insertMultiple(tx *sql.Tx, sqlPrefix string, argsPerRow int, args []interface{}) error {
   324  	if len(args) == 0 {
   325  		return nil
   326  	}
   327  	query := sqlPrefix + repeatDelim("("+repeatDelim("?", ", ", argsPerRow)+")", ", ", len(args)/argsPerRow)
   328  	_, err := tx.Exec(query, args...)
   329  	return err
   330  }
   332  // flush sends INSERT statements for any pending data in u.insertRecordArgs and u.insertLabelArgs.
   333  func (u *Upload) flush() error {
   334  	if n := len(u.insertRecordArgs); n > 0 {
   335  		if err := insertMultiple(u.tx, "INSERT INTO Records(UploadID, RecordID, Content) VALUES ", 3, u.insertRecordArgs); err != nil {
   336  			return err
   337  		}
   338  		u.insertRecordArgs = nil
   339  	}
   340  	if n := len(u.insertLabelArgs); n > 0 {
   341  		if err := insertMultiple(u.tx, "INSERT INTO RecordLabels VALUES ", 4, u.insertLabelArgs); err != nil {
   342  			return err
   343  		}
   344  		u.insertLabelArgs = nil
   345  	}
   346  	u.lastResult = nil
   347  	return nil
   348  }
   350  // Commit finishes processing the upload.
   351  func (u *Upload) Commit() error {
   352  	if err := u.flush(); err != nil {
   353  		return err
   354  	}
   355  	return u.tx.Commit()
   356  }
   358  // Abort cleans up resources associated with the upload.
   359  // It does not attempt to clean up partial database state.
   360  func (u *Upload) Abort() error {
   361  	return u.tx.Rollback()
   362  }
   364  // parseQuery parses a query into a slice of SQL subselects and a slice of arguments.
   365  // The subselects must be joined with INNER JOIN in the order returned.
   366  func parseQuery(q string) (sql []string, args []interface{}, err error) {
   367  	var keys []string
   368  	parts := make(map[string]part)
   369  	for _, word := range query.SplitWords(q) {
   370  		p, err := parseWord(word)
   371  		if err != nil {
   372  			return nil, nil, err
   373  		}
   374  		if _, ok := parts[p.key]; ok {
   375  			parts[p.key], err = parts[p.key].merge(p)
   376  			if err != nil {
   377  				return nil, nil, err
   378  			}
   379  		} else {
   380  			keys = append(keys, p.key)
   381  			parts[p.key] = p
   382  		}
   383  	}
   384  	// Process each key
   385  	sort.Strings(keys)
   386  	for _, key := range keys {
   387  		s, a, err := parts[key].sql()
   388  		if err != nil {
   389  			return nil, nil, err
   390  		}
   391  		sql = append(sql, s)
   392  		args = append(args, a...)
   393  	}
   394  	return
   395  }
   397  // Query searches for results matching the given query string.
   398  //
   399  // The query string is first parsed into quoted words (as in the shell)
   400  // and then each word must be formatted as one of the following:
   401  // key:value - exact match on label "key" = "value"
   402  // key>value - value greater than (useful for dates)
   403  // key<value - value less than (also useful for dates)
   404  func (db *DB) Query(q string) *Query {
   405  	ret := &Query{q: q}
   407  	query := "SELECT r.Content FROM "
   409  	sql, args, err := parseQuery(q)
   410  	if err != nil {
   411  		ret.err = err
   412  		return ret
   413  	}
   414  	for i, part := range sql {
   415  		if i > 0 {
   416  			query += " INNER JOIN "
   417  		}
   418  		query += fmt.Sprintf("(%s) t%d", part, i)
   419  		if i > 0 {
   420  			query += " USING (UploadID, RecordID)"
   421  		}
   422  	}
   424  	if len(sql) > 0 {
   425  		query += " LEFT JOIN"
   426  	}
   427  	query += " Records r"
   428  	if len(sql) > 0 {
   429  		query += " USING (UploadID, RecordID)"
   430  	}
   432  	ret.sqlQuery, ret.sqlArgs = query, args
   433  	ret.rows, ret.err = db.sql.Query(query, args...)
   434  	return ret
   435  }
// Query is the result of a query.
// Use Next to advance through the rows, making sure to call Close when done:
//
//	q := db.Query("key:value")
//	defer q.Close()
//	for q.Next() {
//	  res := q.Result()
//	  ...
//	}
//	err = q.Err() // get any error encountered during iteration
//	...
type Query struct {
	// rows yields the stored Records.Content of each matching record.
	// It is nil when parsing the query failed (err is set instead).
	rows *sql.Rows
	// for Debug
	q        string
	sqlQuery string
	sqlArgs  []interface{}
	// from last call to Next: br reads results out of the current row's
	// content; err is the sticky error state (Err maps io.EOF to nil).
	br  *benchfmt.Reader
	err error
}
   459  // Debug returns the human-readable state of the query.
   460  func (q *Query) Debug() string {
   461  	ret := fmt.Sprintf("q=%q", q.q)
   462  	if q.sqlQuery != "" || len(q.sqlArgs) > 0 {
   463  		ret += fmt.Sprintf(" sql={%q %#v}", q.sqlQuery, q.sqlArgs)
   464  	}
   465  	if q.err != nil {
   466  		ret += fmt.Sprintf(" err=%v", q.err)
   467  	}
   468  	return ret
   469  }
   471  // Next prepares the next result for reading with the Result
   472  // method. It returns false when there are no more results, either by
   473  // reaching the end of the input or an error.
   474  func (q *Query) Next() bool {
   475  	if q.err != nil {
   476  		return false
   477  	}
   478  	if != nil {
   479  		if {
   480  			return true
   481  		}
   482  		q.err =
   483  		if q.err != nil {
   484  			return false
   485  		}
   486  	}
   487  	if !q.rows.Next() {
   488  		return false
   489  	}
   490  	var content []byte
   491  	q.err = q.rows.Scan(&content)
   492  	if q.err != nil {
   493  		return false
   494  	}
   495 = benchfmt.NewReader(bytes.NewReader(content))
   496  	if ! {
   497  		q.err =
   498  		if q.err == nil {
   499  			q.err = io.ErrUnexpectedEOF
   500  		}
   501  		return false
   502  	}
   503  	return q.err == nil
   504  }
   506  // Result returns the most recent result generated by a call to Next.
   507  func (q *Query) Result() *benchfmt.Result {
   508  	return
   509  }
   511  // Err returns the error state of the query.
   512  func (q *Query) Err() error {
   513  	if q.err == io.EOF {
   514  		return nil
   515  	}
   516  	return q.err
   517  }
   519  // Close frees resources associated with the query.
   520  func (q *Query) Close() error {
   521  	if q.rows != nil {
   522  		return q.rows.Close()
   523  	}
   524  	return q.Err()
   525  }
   527  // CountUploads returns the number of uploads in the database.
   528  func (db *DB) CountUploads() (int, error) {
   529  	var uploads int
   530  	err := db.sql.QueryRow("SELECT COUNT(*) FROM Uploads").Scan(&uploads)
   531  	return uploads, err
   532  }
   534  // Close closes the database connections, releasing any open resources.
   535  func (db *DB) Close() error {
   536  	for _, stmt := range []*sql.Stmt{db.lastUpload, db.insertUpload, db.checkUpload, db.deleteRecords} {
   537  		if err := stmt.Close(); err != nil {
   538  			return err
   539  		}
   540  	}
   541  	return db.sql.Close()
   542  }
// UploadList is the result of ListUploads.
// Use Next to advance through the rows, making sure to call Close when done:
//
//	q := db.ListUploads("key:value", nil, 0)
//	defer q.Close()
//	for q.Next() {
//	  info := q.Info()
//	  ...
//	}
//	err = q.Err() // get any error encountered during iteration
//	...
type UploadList struct {
	// rows is nil when parsing the query failed (err is set instead).
	// extraLabels lists the labels whose values are fetched per upload;
	// it is parallel to labelValues.
	rows        *sql.Rows
	extraLabels []string
	// for Debug
	q        string
	sqlQuery string
	sqlArgs  []interface{}
	// from last call to Next: count is the upload's rCount (number of
	// matching records); labelValues has one entry per extraLabels entry,
	// invalid when the label subquery produced NULL.
	count       int
	uploadID    string
	labelValues []sql.NullString
	err         error
}
   569  // Debug returns the human-readable state of ul.
   570  func (ul *UploadList) Debug() string {
   571  	ret := fmt.Sprintf("q=%q", ul.q)
   572  	if ul.sqlQuery != "" || len(ul.sqlArgs) > 0 {
   573  		ret += fmt.Sprintf(" sql={%q %#v}", ul.sqlQuery, ul.sqlArgs)
   574  	}
   575  	if ul.err != nil {
   576  		ret += fmt.Sprintf(" err=%v", ul.err)
   577  	}
   578  	return ret
   579  }
   581  // ListUploads searches for uploads containing results matching the given query string.
   582  // The query may be empty, in which case all uploads will be returned.
   583  // For each label in extraLabels, one unspecified record's value will be obtained for each upload.
   584  // If limit is non-zero, only the limit most recent uploads will be returned.
   585  func (db *DB) ListUploads(q string, extraLabels []string, limit int) *UploadList {
   586  	ret := &UploadList{q: q, extraLabels: extraLabels}
   588  	var args []interface{}
   589  	query := "SELECT j.UploadID, rCount"
   590  	for i, label := range extraLabels {
   591  		query += fmt.Sprintf(", (SELECT l%d.Value FROM RecordLabels l%d WHERE l%d.UploadID = j.UploadID AND Name = ? LIMIT 1)", i, i, i)
   592  		args = append(args, label)
   593  	}
   594  	sql, qArgs, err := parseQuery(q)
   595  	if err != nil {
   596  		ret.err = err
   597  		return ret
   598  	}
   599  	if len(sql) == 0 {
   600  		// Optimize empty query.
   601  		query += " FROM (SELECT UploadID, (SELECT COUNT(*) FROM Records r WHERE r.UploadID = u.UploadID) AS rCount FROM Uploads u "
   602  		switch db.driverName {
   603  		case "sqlite3":
   604  			query += "WHERE"
   605  		default:
   606  			query += "HAVING"
   607  		}
   608  		query += " rCount > 0 ORDER BY u.Day DESC, u.Seq DESC, u.UploadID DESC"
   609  		if limit != 0 {
   610  			query += fmt.Sprintf(" LIMIT %d", limit)
   611  		}
   612  		query += ") j"
   613  	} else {
   614  		// Join individual queries.
   615  		query += " FROM (SELECT UploadID, COUNT(*) as rCount FROM "
   616  		args = append(args, qArgs...)
   617  		for i, part := range sql {
   618  			if i > 0 {
   619  				query += " INNER JOIN "
   620  			}
   621  			query += fmt.Sprintf("(%s) t%d", part, i)
   622  			if i > 0 {
   623  				query += " USING (UploadID, RecordID)"
   624  			}
   625  		}
   627  		query += " LEFT JOIN Records r USING (UploadID, RecordID)"
   628  		query += " GROUP BY UploadID) j LEFT JOIN Uploads u USING (UploadID) ORDER BY u.Day DESC, u.Seq DESC, u.UploadID DESC"
   629  		if limit != 0 {
   630  			query += fmt.Sprintf(" LIMIT %d", limit)
   631  		}
   632  	}
   634  	ret.sqlQuery, ret.sqlArgs = query, args
   635  	ret.rows, ret.err = db.sql.Query(query, args...)
   636  	return ret
   637  }
   639  // Next prepares the next result for reading with the Result
   640  // method. It returns false when there are no more results, either by
   641  // reaching the end of the input or an error.
   642  func (ul *UploadList) Next() bool {
   643  	if ul.err != nil {
   644  		return false
   645  	}
   646  	if !ul.rows.Next() {
   647  		return false
   648  	}
   649  	args := []interface{}{&ul.uploadID, &ul.count}
   650  	ul.labelValues = make([]sql.NullString, len(ul.extraLabels))
   651  	for i := range ul.labelValues {
   652  		args = append(args, &ul.labelValues[i])
   653  	}
   654  	ul.err = ul.rows.Scan(args...)
   655  	if ul.err != nil {
   656  		return false
   657  	}
   658  	return ul.err == nil
   659  }
   661  // Info returns the most recent UploadInfo generated by a call to Next.
   662  func (ul *UploadList) Info() storage.UploadInfo {
   663  	l := make(benchfmt.Labels)
   664  	for i := range ul.extraLabels {
   665  		if ul.labelValues[i].Valid {
   666  			l[ul.extraLabels[i]] = ul.labelValues[i].String
   667  		}
   668  	}
   669  	return storage.UploadInfo{UploadID: ul.uploadID, Count: ul.count, LabelValues: l}
   670  }
   672  // Err returns the error state of the query.
   673  func (ul *UploadList) Err() error {
   674  	return ul.err
   675  }
   677  // Close frees resources associated with the query.
   678  func (ul *UploadList) Close() error {
   679  	if ul.rows != nil {
   680  		return ul.rows.Close()
   681  	}
   682  	return ul.err
   683  }