github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/copy.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"io"
    17  	"strconv"
    18  	"time"
    19  	"unsafe"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/kv"
    22  	"github.com/cockroachdb/cockroach/pkg/sql/catalog/resolver"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
    24  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/privilege"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    29  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    30  	"github.com/cockroachdb/cockroach/pkg/util/log"
    31  	"github.com/cockroachdb/cockroach/pkg/util/mon"
    32  	"github.com/cockroachdb/errors"
    33  )
    34  
// copyMachineInterface abstracts a state machine that drives a COPY
// operation. *copyMachine, defined below, provides the run method this
// interface requires.
type copyMachineInterface interface {
	// run drives the COPY operation and returns the first error encountered,
	// if any.
	run(ctx context.Context) error
}
    38  
// copyMachine supports the Copy-in pgwire subprotocol (COPY...FROM STDIN). The
// machine is created by the Executor when that statement is executed; from that
// moment on, the machine takes control of the pgwire connection until
// copyMachine.run() returns. During this time, the machine is responsible for
// sending all the protocol messages (including the messages that are usually
// associated with statement results). Errors however are not sent on the
// connection by the machine; the higher layer is responsible for sending them.
//
// Incoming data is buffered and batched; batches are turned into insertNodes
// that are executed. INSERT privileges are required on the destination table.
//
// See: https://www.postgresql.org/docs/current/static/sql-copy.html
// and: https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-COPY
type copyMachine struct {
	// table and columns are taken from the COPY statement and are used as the
	// target table and column list of every INSERT built by insertRows.
	// resultColumns describes the target columns (name, type, table ID and
	// attribute number); it is sent to the client when the copy begins and
	// determines how each incoming field is parsed.
	table         tree.TableExpr
	columns       tree.NameList
	resultColumns sqlbase.ResultColumns
	// buf is used to parse input data into rows. It also accumulates a partial
	// row between protocol messages.
	buf bytes.Buffer
	// rows accumulates a batch of rows to be eventually inserted.
	rows []tree.Exprs
	// insertedRows keeps track of the total number of rows inserted by the
	// machine.
	insertedRows int
	// rowsMemAcc accounts for memory used by `rows`.
	rowsMemAcc mon.BoundAccount
	// bufMemAcc accounts for memory used by `buf`; it is kept in sync with
	// buf.Cap().
	bufMemAcc mon.BoundAccount

	// conn is the pgwire connection from which data is to be read.
	conn pgwirebase.Conn

	// execInsertPlan is a function to be used to execute the plan (stored in the
	// planner) which performs an INSERT.
	execInsertPlan func(ctx context.Context, p *planner, res RestrictedCommandResult) error

	// txnOpt describes the transaction (if any) the copy must run in; it is
	// passed to preparePlannerForCopy every time the planner is prepared.
	txnOpt copyTxnOpt

	// p is the planner used to plan inserts. preparePlanner() needs to be called
	// before preparing each new statement.
	p planner

	// parsingEvalCtx is an EvalContext used for the very limited needs of
	// string parsing. It is not correctly initialized with timestamps,
	// transactions and other things that statements more generally need.
	parsingEvalCtx *tree.EvalContext

	// processRows is invoked by processCopyData to flush the rows accumulated
	// in `rows`. newCopyMachine sets it to insertRows.
	processRows func(ctx context.Context) error
}
    90  
    91  // newCopyMachine creates a new copyMachine.
    92  func newCopyMachine(
    93  	ctx context.Context,
    94  	conn pgwirebase.Conn,
    95  	n *tree.CopyFrom,
    96  	txnOpt copyTxnOpt,
    97  	execCfg *ExecutorConfig,
    98  	execInsertPlan func(ctx context.Context, p *planner, res RestrictedCommandResult) error,
    99  ) (_ *copyMachine, retErr error) {
   100  	c := &copyMachine{
   101  		conn: conn,
   102  		// TODO(georgiah): Currently, insertRows depends on Table and Columns,
   103  		//  but that dependency can be removed by refactoring it.
   104  		table:   &n.Table,
   105  		columns: n.Columns,
   106  		txnOpt:  txnOpt,
   107  		// The planner will be prepared before use.
   108  		p:              planner{execCfg: execCfg, alloc: &sqlbase.DatumAlloc{}},
   109  		execInsertPlan: execInsertPlan,
   110  	}
   111  
   112  	// We need a planner to do the initial planning, in addition
   113  	// to those used for the main execution of the COPY afterwards.
   114  	cleanup := c.p.preparePlannerForCopy(ctx, txnOpt)
   115  	defer func() {
   116  		retErr = cleanup(ctx, retErr)
   117  	}()
   118  	c.parsingEvalCtx = c.p.EvalContext()
   119  
   120  	tableDesc, err := resolver.ResolveExistingTableObject(ctx, &c.p, &n.Table, tree.ObjectLookupFlagsWithRequired(), resolver.ResolveRequireTableDesc)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	if err := c.p.CheckPrivilege(ctx, tableDesc, privilege.INSERT); err != nil {
   125  		return nil, err
   126  	}
   127  	cols, err := sqlbase.ProcessTargetColumns(tableDesc, n.Columns,
   128  		true /* ensureColumns */, false /* allowMutations */)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	c.resultColumns = make(sqlbase.ResultColumns, len(cols))
   133  	for i := range cols {
   134  		c.resultColumns[i] = sqlbase.ResultColumn{
   135  			Name:           cols[i].Name,
   136  			Typ:            cols[i].Type,
   137  			TableID:        tableDesc.GetID(),
   138  			PGAttributeNum: cols[i].GetLogicalColumnID(),
   139  		}
   140  	}
   141  	c.rowsMemAcc = c.p.extendedEvalCtx.Mon.MakeBoundAccount()
   142  	c.bufMemAcc = c.p.extendedEvalCtx.Mon.MakeBoundAccount()
   143  	c.processRows = c.insertRows
   144  	return c, nil
   145  }
   146  
// copyTxnOpt contains information about the transaction in which the copying
// should take place. Can be empty, in which case the copyMachine is responsible
// for managing its own transactions.
type copyTxnOpt struct {
	// If set, txn is the transaction within which all writes have to be
	// performed. Committing the txn is left to the higher layer.  If not set, the
	// machine will split writes between multiple transactions that it will
	// initiate.
	//
	// txnTimestamp and stmtTimestamp are the timestamps handed to resetPlanner
	// along with txn. When txn is not set, preparePlannerForCopy does not use
	// them and passes the clock's current physical time for both instead.
	//
	// resetPlanner is called by preparePlannerForCopy with the planner, the
	// transaction to use and the two timestamps, before the planner is used.
	txn           *kv.Txn
	txnTimestamp  time.Time
	stmtTimestamp time.Time
	resetPlanner  func(ctx context.Context, p *planner, txn *kv.Txn, txnTS time.Time, stmtTS time.Time)
}
   160  
   161  // run consumes all the copy-in data from the network connection and inserts it
   162  // in the database.
   163  func (c *copyMachine) run(ctx context.Context) error {
   164  	defer c.rowsMemAcc.Close(ctx)
   165  	defer c.bufMemAcc.Close(ctx)
   166  
   167  	// Send the message describing the columns to the client.
   168  	if err := c.conn.BeginCopyIn(ctx, c.resultColumns); err != nil {
   169  		return err
   170  	}
   171  
   172  	// Read from the connection until we see an ClientMsgCopyDone.
   173  	readBuf := pgwirebase.ReadBuffer{}
   174  
   175  Loop:
   176  	for {
   177  		typ, _, err := readBuf.ReadTypedMsg(c.conn.Rd())
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		switch typ {
   183  		case pgwirebase.ClientMsgCopyData:
   184  			if err := c.processCopyData(
   185  				ctx, string(readBuf.Msg), c.p.EvalContext(), false, /* final */
   186  			); err != nil {
   187  				return err
   188  			}
   189  		case pgwirebase.ClientMsgCopyDone:
   190  			if err := c.processCopyData(
   191  				ctx, "" /* data */, c.p.EvalContext(), true, /* final */
   192  			); err != nil {
   193  				return err
   194  			}
   195  			break Loop
   196  		case pgwirebase.ClientMsgCopyFail:
   197  			return errors.Newf("client canceled COPY")
   198  		case pgwirebase.ClientMsgFlush, pgwirebase.ClientMsgSync:
   199  			// Spec says to "ignore Flush and Sync messages received during copy-in mode".
   200  		default:
   201  			return pgwirebase.NewUnrecognizedMsgTypeErr(typ)
   202  		}
   203  	}
   204  
   205  	// Finalize execution by sending the statement tag and number of rows
   206  	// inserted.
   207  	dummy := tree.CopyFrom{}
   208  	tag := []byte(dummy.StatementTag())
   209  	tag = append(tag, ' ')
   210  	tag = strconv.AppendInt(tag, int64(c.insertedRows), 10 /* base */)
   211  	return c.conn.SendCommandComplete(tag)
   212  }
   213  
const (
	// nullString is the field text that addRow converts to a SQL NULL instead
	// of parsing it as a value of the column's type.
	nullString = `\N`
	// lineDelim terminates each row of the incoming copy data.
	lineDelim  = '\n'
)

var (
	// fieldDelim separates the column values within one row.
	fieldDelim = []byte{'\t'}
)
   222  
   223  // processCopyData buffers incoming data and, once the buffer fills up, inserts
   224  // the accumulated rows.
   225  //
   226  // Args:
   227  // final: If set, buffered data is written even if the buffer is not full.
   228  func (c *copyMachine) processCopyData(
   229  	ctx context.Context, data string, evalCtx *tree.EvalContext, final bool,
   230  ) (retErr error) {
   231  	// At the end, adjust the mem accounting to reflect what's left in the buffer.
   232  	defer func() {
   233  		if err := c.bufMemAcc.ResizeTo(ctx, int64(c.buf.Cap())); err != nil && retErr == nil {
   234  			retErr = err
   235  		}
   236  	}()
   237  
   238  	// When this many rows are in the copy buffer, they are inserted.
   239  	const copyBatchRowSize = 100
   240  
   241  	if len(data) > (c.buf.Cap() - c.buf.Len()) {
   242  		// If it looks like the buffer will need to allocate to accommodate data,
   243  		// account for the memory here. This is not particularly accurate - we don't
   244  		// know how much the buffer will actually grow by.
   245  		if err := c.bufMemAcc.ResizeTo(ctx, int64(len(data))); err != nil {
   246  			return err
   247  		}
   248  	}
   249  	c.buf.WriteString(data)
   250  	for c.buf.Len() > 0 {
   251  		line, err := c.buf.ReadBytes(lineDelim)
   252  		if err != nil {
   253  			if err != io.EOF {
   254  				return err
   255  			} else if !final {
   256  				// Put the incomplete row back in the buffer, to be processed next time.
   257  				c.buf.Write(line)
   258  				break
   259  			}
   260  		} else {
   261  			// Remove lineDelim from end.
   262  			line = line[:len(line)-1]
   263  			// Remove a single '\r' at EOL, if present.
   264  			if len(line) > 0 && line[len(line)-1] == '\r' {
   265  				line = line[:len(line)-1]
   266  			}
   267  		}
   268  		if c.buf.Len() == 0 && bytes.Equal(line, []byte(`\.`)) {
   269  			break
   270  		}
   271  		if err := c.addRow(ctx, line); err != nil {
   272  			return err
   273  		}
   274  	}
   275  	// Only do work if we have a full batch of rows or this is the end.
   276  	if ln := len(c.rows); !final && (ln == 0 || ln < copyBatchRowSize) {
   277  		return nil
   278  	}
   279  	return c.processRows(ctx)
   280  }
   281  
   282  // preparePlannerForCopy resets the planner so that it can be used during
   283  // a COPY operation (either COPY to table, or COPY to file).
   284  //
   285  // Depending on how the requesting COPY machine was configured, a new
   286  // transaction might be created.
   287  //
   288  // It returns a cleanup function that needs to be called when we're
   289  // done with the planner (before preparePlannerForCopy is called
   290  // again). The cleanup function commits the txn (if it hasn't already
   291  // been committed) or rolls it back depending on whether it is passed
   292  // an error. If an error is passed in to the cleanup function, the
   293  // same error is returned.
   294  func (p *planner) preparePlannerForCopy(
   295  	ctx context.Context, txnOpt copyTxnOpt,
   296  ) func(context.Context, error) error {
   297  	txn := txnOpt.txn
   298  	txnTs := txnOpt.txnTimestamp
   299  	stmtTs := txnOpt.stmtTimestamp
   300  	autoCommit := false
   301  	if txn == nil {
   302  		txn = kv.NewTxnWithSteppingEnabled(ctx, p.execCfg.DB, p.execCfg.NodeID.Get())
   303  		txnTs = p.execCfg.Clock.PhysicalTime()
   304  		stmtTs = txnTs
   305  		autoCommit = true
   306  	}
   307  	txnOpt.resetPlanner(ctx, p, txn, txnTs, stmtTs)
   308  	p.autoCommit = autoCommit
   309  
   310  	return func(ctx context.Context, err error) error {
   311  		if err == nil {
   312  			// Ensure that the txn is committed if the copyMachine is in charge of
   313  			// committing its transactions and the execution didn't already commit it
   314  			// (through the planner.autoCommit optimization).
   315  			if autoCommit && !txn.IsCommitted() {
   316  				return txn.CommitOrCleanup(ctx)
   317  			}
   318  			return nil
   319  		}
   320  		txn.CleanupOnError(ctx, err)
   321  		return err
   322  	}
   323  }
   324  
   325  // insertRows transforms the buffered rows into an insertNode and executes it.
   326  func (c *copyMachine) insertRows(ctx context.Context) (retErr error) {
   327  	if len(c.rows) == 0 {
   328  		return nil
   329  	}
   330  	cleanup := c.p.preparePlannerForCopy(ctx, c.txnOpt)
   331  	defer func() {
   332  		retErr = cleanup(ctx, retErr)
   333  	}()
   334  
   335  	vc := &tree.ValuesClause{Rows: c.rows}
   336  	numRows := len(c.rows)
   337  	// Reuse the same backing array once the Insert is complete.
   338  	c.rows = c.rows[:0]
   339  	c.rowsMemAcc.Clear(ctx)
   340  
   341  	c.p.stmt = &Statement{}
   342  	c.p.stmt.AST = &tree.Insert{
   343  		Table:   c.table,
   344  		Columns: c.columns,
   345  		Rows: &tree.Select{
   346  			Select: vc,
   347  		},
   348  		Returning: tree.AbsentReturningClause,
   349  	}
   350  	if err := c.p.makeOptimizerPlan(ctx); err != nil {
   351  		return err
   352  	}
   353  
   354  	var res bufferedCommandResult
   355  	err := c.execInsertPlan(ctx, &c.p, &res)
   356  	if err != nil {
   357  		return err
   358  	}
   359  	if err := res.Err(); err != nil {
   360  		return err
   361  	}
   362  
   363  	if rows := res.RowsAffected(); rows != numRows {
   364  		log.Fatalf(ctx, "didn't insert all buffered rows and yet no error was reported. "+
   365  			"Inserted %d out of %d rows.", rows, numRows)
   366  	}
   367  	c.insertedRows += numRows
   368  
   369  	return nil
   370  }
   371  
   372  func (c *copyMachine) addRow(ctx context.Context, line []byte) error {
   373  	var err error
   374  	parts := bytes.Split(line, fieldDelim)
   375  	if len(parts) != len(c.resultColumns) {
   376  		return pgerror.Newf(pgcode.ProtocolViolation,
   377  			"expected %d values, got %d", len(c.resultColumns), len(parts))
   378  	}
   379  	exprs := make(tree.Exprs, len(parts))
   380  	for i, part := range parts {
   381  		s := string(part)
   382  		if s == nullString {
   383  			exprs[i] = tree.DNull
   384  			continue
   385  		}
   386  		switch t := c.resultColumns[i].Typ; t.Family() {
   387  		case types.BytesFamily,
   388  			types.DateFamily,
   389  			types.IntervalFamily,
   390  			types.INetFamily,
   391  			types.StringFamily,
   392  			types.TimestampFamily,
   393  			types.TimestampTZFamily,
   394  			types.UuidFamily:
   395  			s, err = decodeCopy(s)
   396  			if err != nil {
   397  				return err
   398  			}
   399  		}
   400  		d, err := sqlbase.ParseDatumStringAsWithRawBytes(c.resultColumns[i].Typ, s, c.parsingEvalCtx)
   401  		if err != nil {
   402  			return err
   403  		}
   404  
   405  		sz := d.Size()
   406  		if err := c.rowsMemAcc.Grow(ctx, int64(sz)); err != nil {
   407  			return err
   408  		}
   409  
   410  		exprs[i] = d
   411  	}
   412  	if err := c.rowsMemAcc.Grow(ctx, int64(unsafe.Sizeof(exprs))); err != nil {
   413  		return err
   414  	}
   415  
   416  	c.rows = append(c.rows, exprs)
   417  	return nil
   418  }
   419  
   420  // decodeCopy unescapes a single COPY field.
   421  //
   422  // See: https://www.postgresql.org/docs/9.5/static/sql-copy.html#AEN74432
   423  func decodeCopy(in string) (string, error) {
   424  	var buf bytes.Buffer
   425  	start := 0
   426  	for i, n := 0, len(in); i < n; i++ {
   427  		if in[i] != '\\' {
   428  			continue
   429  		}
   430  		buf.WriteString(in[start:i])
   431  		i++
   432  		if i >= n {
   433  			return "", pgerror.Newf(pgcode.Syntax,
   434  				"unknown escape sequence: %q", in[i-1:])
   435  		}
   436  
   437  		ch := in[i]
   438  		if decodedChar := decodeMap[ch]; decodedChar != 0 {
   439  			buf.WriteByte(decodedChar)
   440  		} else if ch == 'x' {
   441  			// \x can be followed by 1 or 2 hex digits.
   442  			i++
   443  			if i >= n {
   444  				return "", pgerror.Newf(pgcode.Syntax,
   445  					"unknown escape sequence: %q", in[i-2:])
   446  			}
   447  			ch = in[i]
   448  			digit, ok := decodeHexDigit(ch)
   449  			if !ok {
   450  				return "", pgerror.Newf(pgcode.Syntax,
   451  					"unknown escape sequence: %q", in[i-2:i])
   452  			}
   453  			if i+1 < n {
   454  				if v, ok := decodeHexDigit(in[i+1]); ok {
   455  					i++
   456  					digit <<= 4
   457  					digit += v
   458  				}
   459  			}
   460  			buf.WriteByte(digit)
   461  		} else if ch >= '0' && ch <= '7' {
   462  			digit, _ := decodeOctDigit(ch)
   463  			// 1 to 2 more octal digits follow.
   464  			if i+1 < n {
   465  				if v, ok := decodeOctDigit(in[i+1]); ok {
   466  					i++
   467  					digit <<= 3
   468  					digit += v
   469  				}
   470  			}
   471  			if i+1 < n {
   472  				if v, ok := decodeOctDigit(in[i+1]); ok {
   473  					i++
   474  					digit <<= 3
   475  					digit += v
   476  				}
   477  			}
   478  			buf.WriteByte(digit)
   479  		} else {
   480  			return "", pgerror.Newf(pgcode.Syntax,
   481  				"unknown escape sequence: %q", in[i-1:i+1])
   482  		}
   483  		start = i + 1
   484  	}
   485  	buf.WriteString(in[start:])
   486  	return buf.String(), nil
   487  }
   488  
   489  func decodeDigit(c byte, onlyOctal bool) (byte, bool) {
   490  	switch {
   491  	case c >= '0' && c <= '7':
   492  		return c - '0', true
   493  	case !onlyOctal && c >= '8' && c <= '9':
   494  		return c - '0', true
   495  	case !onlyOctal && c >= 'a' && c <= 'f':
   496  		return c - 'a' + 10, true
   497  	case !onlyOctal && c >= 'A' && c <= 'F':
   498  		return c - 'A' + 10, true
   499  	default:
   500  		return 0, false
   501  	}
   502  }
   503  
   504  func decodeOctDigit(c byte) (byte, bool) { return decodeDigit(c, true) }
   505  func decodeHexDigit(c byte) (byte, bool) { return decodeDigit(c, false) }
   506  
// decodeMap maps the character that follows a backslash in a single-character
// escape sequence to the byte that sequence stands for. decodeCopy consults
// it first for every escape; a character that is absent from the map yields
// the zero byte on lookup, which decodeCopy takes to mean "not a
// single-character escape".
var decodeMap = map[byte]byte{
	'b':  '\b',
	'f':  '\f',
	'n':  '\n',
	'r':  '\r',
	't':  '\t',
	'v':  '\v',
	'\\': '\\',
}