github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/importccl/read_import_pgcopy.go

// Copyright 2018 The Cockroach Authors.
//
// Licensed as a CockroachDB Enterprise file under the Cockroach Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt

package importccl

import (
	"bufio"
	"bytes"
	"context"
	"fmt"
	"io"
	"strconv"
	"unicode"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
	"github.com/cockroachdb/cockroach/pkg/sql/row"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
	"github.com/cockroachdb/cockroach/pkg/storage/cloud"
	"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
	"github.com/cockroachdb/errors"
)

// defaultScanBuffer is the default max row size of the PGCOPY and PGDUMP
// scanner.
const defaultScanBuffer = 1024 * 1024 * 4
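
// An editorial note, not in the original file: bufio.Scanner caps tokens at
// bufio.MaxScanTokenSize (64 KiB) by default, so readers of COPY data must
// raise the cap explicitly, e.g.:
//
//	s := bufio.NewScanner(r)
//	s.Buffer(nil, defaultScanBuffer) // allow rows up to 4 MiB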

type pgCopyReader struct {
	conv row.DatumRowConverter
	opts roachpb.PgCopyOptions
}

var _ inputConverter = &pgCopyReader{}

func newPgCopyReader(
	ctx context.Context,
	kvCh chan row.KVBatch,
	opts roachpb.PgCopyOptions,
	tableDesc *sqlbase.TableDescriptor,
	evalCtx *tree.EvalContext,
) (*pgCopyReader, error) {
	conv, err := row.NewDatumRowConverter(ctx, tableDesc, nil /* targetColNames */, evalCtx, kvCh)
	if err != nil {
		return nil, err
	}
	return &pgCopyReader{
		conv: *conv,
		opts: opts,
	}, nil
}

// start is a no-op for pgCopyReader; it exists to satisfy the
// inputConverter interface.
func (d *pgCopyReader) start(ctx ctxgroup.Group) {
}

func (d *pgCopyReader) readFiles(
	ctx context.Context,
	dataFiles map[int32]string,
	resumePos map[int32]int64,
	format roachpb.IOFileFormat,
	makeExternalStorage cloud.ExternalStorageFactory,
) error {
	return readInputFiles(ctx, dataFiles, resumePos, format, d.readFile, makeExternalStorage)
}

// postgreStreamCopy reads COPY data rows from the wrapped scanner using the
// configured delimiter and null string.
type postgreStreamCopy struct {
	s         *bufio.Scanner
	delimiter rune
	null      string
}

// newPostgreStreamCopy streams COPY data rows from s with the specified
// delimiter and null string.
//
// s must be an existing bufio.Scanner already configured to split on
// bufio.ScanLines. This must be done outside of this function because the
// split func cannot be changed mid-scan, and because the same scanner is
// shared with the SQL parser, which swaps the splitter using a closure.
func newPostgreStreamCopy(s *bufio.Scanner, delimiter rune, null string) *postgreStreamCopy {
	return &postgreStreamCopy{
		s:         s,
		delimiter: delimiter,
		null:      null,
	}
}
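
// A hypothetical usage sketch (editorial, not part of the original file):
// stream rows from an in-memory COPY payload using the default delimiter and
// null marker.
//
//	s := bufio.NewScanner(strings.NewReader("1\tfoo\n2\t\\N\n\\.\n"))
//	s.Split(bufio.ScanLines)
//	s.Buffer(nil, defaultScanBuffer)
//	p := newPostgreStreamCopy(s, copyDefaultDelimiter, copyDefaultNull)
//	for {
//		row, err := p.Next()
//		if err == io.EOF || err == errCopyDone {
//			break
//		}
//		if err != nil {
//			return err // or handle the error
//		}
//		_ = row // one *string per column; nil means SQL NULL
//	}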

var errCopyDone = errors.New("COPY done")

// Next returns the next row. It returns io.EOF at end of input, and
// errCopyDone once the terminator "\." is encountered, marking the end of
// the COPY section.
func (p *postgreStreamCopy) Next() (copyData, error) {
	var row copyData
	// The current field being read.
	var field []byte

	addField := func() error {
		// COPY does its backslash removal only after the entire field has been
		// extracted, so that the raw field can first be compared to the NULL
		// setting.
		if string(field) == p.null {
			row = append(row, nil)
		} else {
			// Use the same backing store since we are guaranteed to only remove
			// characters.
			nf := field[:0]
			for i := 0; i < len(field); i++ {
				if field[i] == '\\' {
					i++
					if len(field) <= i {
						return errors.New("unmatched escape")
					}
					// See https://www.postgresql.org/docs/current/static/sql-copy.html
					c := field[i]
					switch c {
					case 'b':
						nf = append(nf, '\b')
					case 'f':
						nf = append(nf, '\f')
					case 'n':
						nf = append(nf, '\n')
					case 'r':
						nf = append(nf, '\r')
					case 't':
						nf = append(nf, '\t')
					case 'v':
						nf = append(nf, '\v')
					default:
						if c == 'x' || (c >= '0' && c <= '9') {
							// Handle \xNN hex and \NNN octal escapes.
							if len(field) <= i+2 {
								return errors.Errorf("unsupported escape sequence: \\%s", field[i:])
							}
							base := 8
							idx := 0
							if c == 'x' {
								base = 16
								idx = 1
							}
							// Parse unsigned so that bytes above 0x7f (e.g.
							// \xff) fit; ParseInt with bitSize 8 would reject
							// them as out of range.
							v, err := strconv.ParseUint(string(field[i+idx:i+3]), base, 8)
							if err != nil {
								return err
							}
							i += 2
							nf = append(nf, byte(v))
						} else {
							nf = append(nf, string(c)...)
						}
					}
				} else {
					nf = append(nf, field[i])
				}
			}
			ns := string(nf)
			row = append(row, &ns)
		}
		field = field[:0]
		return nil
	}
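
	// An illustrative example (editorial, not in the original): given the
	// raw field bytes `foo\tbar\x41\101`, addField strips the escapes and
	// stores "foo\tbarAA": `\t` decodes to a tab, while `\x41` (hex) and
	// `\101` (octal) both decode to 'A'.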

	// Attempt to read an entire line.
	scanned := p.s.Scan()
	if err := p.s.Err(); err != nil {
		if errors.Is(err, bufio.ErrTooLong) {
			err = errors.New("line too long")
		}
		return nil, err
	}
	if !scanned {
		return nil, io.EOF
	}
	// Check for the copy done marker.
	if bytes.Equal(p.s.Bytes(), []byte(`\.`)) {
		return nil, errCopyDone
	}
	reader := bytes.NewReader(p.s.Bytes())

	var sawBackslash bool
	// Start by finding field delimiters.
	for {
		c, w, err := reader.ReadRune()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, err
		}
		if c == unicode.ReplacementChar && w == 1 {
			return nil, errors.New("error decoding UTF-8 Rune")
		}

		// We only care about backslashes here if they are followed by a field
		// delimiter. Otherwise we pass them through unchanged; the escape
		// sequence is decoded later, in addField.
		if sawBackslash {
			sawBackslash = false
			if c == p.delimiter {
				field = append(field, string(p.delimiter)...)
			} else {
				field = append(field, '\\')
				field = append(field, string(c)...)
			}
			continue
		} else if c == '\\' {
			sawBackslash = true
			continue
		}

		const rowSeparator = '\n'
		// Are we done with the field?
		if c == p.delimiter || c == rowSeparator {
			if err := addField(); err != nil {
				return nil, err
			}
		} else {
			field = append(field, string(c)...)
		}
	}
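
	// For example (an editorial illustration): with delimiter ',', the line
	// `a\,b,c` scans as two fields, `a,b` and `c`: the backslash before a
	// delimiter is consumed in the loop above, while other escapes such as
	// `\n` are left intact for addField to decode.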
	if sawBackslash {
		return nil, errors.New("unmatched escape")
	}
	// We always want to call this because there's at least 1 field per row. If
	// the row is empty we should return a row with a single, empty field.
	if err := addField(); err != nil {
		return nil, err
	}
	return row, nil
}

// These match the defaults of PostgreSQL's COPY text format.
const (
	copyDefaultDelimiter = '\t'
	copyDefaultNull      = `\N`
)

// copyData is a single COPY row: one *string per column, with nil
// representing SQL NULL.
type copyData []*string

func (c copyData) String() string {
	var buf bytes.Buffer
	for i, s := range c {
		if i > 0 {
			buf.WriteByte(copyDefaultDelimiter)
		}
		if s == nil {
			buf.WriteString(copyDefaultNull)
		} else {
			// TODO(mjibson): this isn't correct COPY syntax, but it's only used in tests.
			fmt.Fprintf(&buf, "%q", *s)
		}
	}
	return buf.String()
}
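
// For instance (an editorial illustration):
//
//	v := "foo"
//	s := copyData{&v, nil}.String() // `"foo"`, then a tab, then `\N`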

func (d *pgCopyReader) readFile(
	ctx context.Context, input *fileReader, inputIdx int32, resumePos int64, rejected chan string,
) error {
	s := bufio.NewScanner(input)
	s.Split(bufio.ScanLines)
	s.Buffer(nil, int(d.opts.MaxRowSize))
	c := newPostgreStreamCopy(
		s,
		d.opts.Delimiter,
		d.opts.Null,
	)
	d.conv.KvBatch.Source = inputIdx
	d.conv.FractionFn = input.ReadFraction
	// Rows are counted from 1 so that count compares directly with resumePos.
	count := int64(1)
	d.conv.CompletedRowFn = func() int64 {
		return count
	}

	for ; ; count++ {
		row, err := c.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return wrapRowErr(err, "", count, pgcode.Uncategorized, "")
		}

		if count <= resumePos {
			continue
		}

		if len(row) != len(d.conv.VisibleColTypes) {
			return makeRowErr("", count, pgcode.Syntax,
				"expected %d values, got %d", len(d.conv.VisibleColTypes), len(row))
		}
		for i, s := range row {
			if s == nil {
				d.conv.Datums[i] = tree.DNull
			} else {
				d.conv.Datums[i], err = sqlbase.ParseDatumStringAs(d.conv.VisibleColTypes[i], *s, d.conv.EvalCtx)
				if err != nil {
					col := d.conv.VisibleCols[i]
					return wrapRowErr(err, "", count, pgcode.Syntax,
						"parse %q as %s", col.Name, col.Type.SQLString())
				}
			}
		}

		if err := d.conv.Row(ctx, inputIdx, count); err != nil {
			return wrapRowErr(err, "", count, pgcode.Uncategorized, "")
		}
	}

	return d.conv.SendBatch(ctx)
}
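
// Overall flow, summarized (an editorial sketch, not original code):
//
//	readFiles -> readInputFiles -> readFile (once per data file)
//	readFile: bufio.Scanner line -> postgreStreamCopy.Next -> copyData ->
//	    sqlbase.ParseDatumStringAs per column -> DatumRowConverter.Row ->
//	    SendBatch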