github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/mydump/csv_parser.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package mydump
    15  
    16  import (
    17  	"bytes"
    18  	"io"
    19  	"strings"
    20  	"unicode"
    21  
    22  	"github.com/pingcap/br/pkg/utils"
    23  
    24  	"github.com/pingcap/errors"
    25  	"github.com/pingcap/tidb/types"
    26  
    27  	"github.com/pingcap/tidb-lightning/lightning/config"
    28  	"github.com/pingcap/tidb-lightning/lightning/worker"
    29  )
    30  
// Syntax errors reported by the CSV parser. The sentinels are created with
// NewNoStackError; the call sites in this file wrap them with errors.AddStack
// before returning them.
var (
	// errUnterminatedQuotedField is reported when the input ends inside a quoted field.
	errUnterminatedQuotedField = errors.NewNoStackError("syntax error: unterminated quoted field")
	// errDanglingBackslash is reported when the input ends right after a backslash.
	errDanglingBackslash       = errors.NewNoStackError("syntax error: no character after backslash")
	// errUnexpectedQuoteField is reported when a quote appears where a separator
	// or line break was required.
	errUnexpectedQuoteField    = errors.NewNoStackError("syntax error: cannot have consecutive fields without separator")
)
    36  
// CSVParser is basically a copy of encoding/csv, but special-cased for MySQL-like input.
type CSVParser struct {
	blockParser
	// cfg is the CSV configuration passed to NewCSVParser; escFlavor is the
	// backslash-escape mode NewCSVParser derives from it.
	cfg       *config.CSVConfig
	escFlavor backslashEscapeFlavor

	// comma and quote hold the bytes of cfg.Separator and cfg.Delimiter.
	// quoteIndexFunc and unquoteIndexFunc return the index of the first
	// "stop" byte in a buffer while scanning a quoted / unquoted field
	// respectively; the stop sets are built in NewCSVParser.
	comma            []byte
	quote            []byte
	quoteIndexFunc   func([]byte) int
	unquoteIndexFunc func([]byte) int

	// recordBuffer holds the unescaped fields, one after another.
	// The fields can be accessed by using the indexes in fieldIndexes.
	// E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de`
	// and fieldIndexes will contain the indexes [1, 2, 5, 6].
	recordBuffer []byte

	// fieldIndexes is an index of fields inside recordBuffer.
	// The i'th field ends at offset fieldIndexes[i] in recordBuffer.
	fieldIndexes []int

	// lastRecord is the slice returned by the previous readRecord call made
	// from ReadRow; it is passed back in so its capacity can be reused.
	lastRecord []string

	// if set to true, csv parser will treat the first non-empty line as header line
	shouldParseHeader bool
}
    63  
    64  func NewCSVParser(
    65  	cfg *config.CSVConfig,
    66  	reader ReadSeekCloser,
    67  	blockBufSize int64,
    68  	ioWorkers *worker.Pool,
    69  	shouldParseHeader bool,
    70  ) *CSVParser {
    71  	escFlavor := backslashEscapeFlavorNone
    72  	var quoteStopSet []byte
    73  	unquoteStopSet := []byte{'\r', '\n', cfg.Separator[0]}
    74  	if len(cfg.Delimiter) > 0 {
    75  		quoteStopSet = []byte{cfg.Delimiter[0]}
    76  		unquoteStopSet = append(unquoteStopSet, cfg.Delimiter[0])
    77  	}
    78  	if cfg.BackslashEscape {
    79  		escFlavor = backslashEscapeFlavorMySQL
    80  		quoteStopSet = append(quoteStopSet, '\\')
    81  		unquoteStopSet = append(unquoteStopSet, '\\')
    82  		// we need special treatment of the NULL value \N, used by MySQL.
    83  		if !cfg.NotNull && cfg.Null == `\N` {
    84  			escFlavor = backslashEscapeFlavorMySQLWithNull
    85  		}
    86  	}
    87  
    88  	return &CSVParser{
    89  		blockParser:       makeBlockParser(reader, blockBufSize, ioWorkers),
    90  		cfg:               cfg,
    91  		comma:             []byte(cfg.Separator),
    92  		quote:             []byte(cfg.Delimiter),
    93  		escFlavor:         escFlavor,
    94  		quoteIndexFunc:    makeBytesIndexFunc(quoteStopSet),
    95  		unquoteIndexFunc:  makeBytesIndexFunc(unquoteStopSet),
    96  		shouldParseHeader: shouldParseHeader,
    97  	}
    98  }
    99  
   100  func makeBytesIndexFunc(chars []byte) func([]byte) int {
   101  	// chars are guaranteed to be ascii str, so this call will always success
   102  	as := makeByteSet(chars)
   103  	return func(s []byte) int {
   104  		return IndexAnyByte(s, &as)
   105  	}
   106  }
   107  
   108  func (parser *CSVParser) unescapeString(input string) (unescaped string, isNull bool) {
   109  	if parser.escFlavor == backslashEscapeFlavorMySQLWithNull && input == `\N` {
   110  		return input, true
   111  	}
   112  	unescaped = unescape(input, "", parser.escFlavor)
   113  	isNull = parser.escFlavor != backslashEscapeFlavorMySQLWithNull &&
   114  		!parser.cfg.NotNull &&
   115  		unescaped == parser.cfg.Null
   116  	return
   117  }
   118  
   119  func (parser *CSVParser) readByte() (byte, error) {
   120  	if len(parser.buf) == 0 {
   121  		if err := parser.readBlock(); err != nil {
   122  			return 0, err
   123  		}
   124  	}
   125  	if len(parser.buf) == 0 {
   126  		return 0, io.EOF
   127  	}
   128  	b := parser.buf[0]
   129  	parser.buf = parser.buf[1:]
   130  	parser.pos++
   131  	return b, nil
   132  }
   133  
   134  func (parser *CSVParser) readBytes(buf []byte) (int, error) {
   135  	cnt := 0
   136  	for cnt < len(buf) {
   137  		if len(parser.buf) == 0 {
   138  			if err := parser.readBlock(); err != nil {
   139  				return cnt, err
   140  			}
   141  		}
   142  		if len(parser.buf) == 0 {
   143  			parser.pos += int64(cnt)
   144  			return cnt, io.EOF
   145  		}
   146  		readCnt := utils.MinInt(len(buf)-cnt, len(parser.buf))
   147  		copy(buf[cnt:], parser.buf[:readCnt])
   148  		parser.buf = parser.buf[readCnt:]
   149  		cnt += readCnt
   150  	}
   151  	parser.pos += int64(cnt)
   152  	return cnt, nil
   153  }
   154  
   155  func (parser *CSVParser) peekBytes(cnt int) ([]byte, error) {
   156  	if len(parser.buf) < cnt {
   157  		if err := parser.readBlock(); err != nil {
   158  			return nil, err
   159  		}
   160  	}
   161  	if len(parser.buf) == 0 {
   162  		return nil, io.EOF
   163  	}
   164  	cnt = utils.MinInt(cnt, len(parser.buf))
   165  	return parser.buf[:cnt], nil
   166  }
   167  
   168  func (parser *CSVParser) skipBytes(n int) {
   169  	parser.buf = parser.buf[n:]
   170  	parser.pos += int64(n)
   171  }
   172  
   173  // readUntil reads the buffer until any character from the `chars` set is found.
   174  // that character is excluded from the final buffer.
   175  func (parser *CSVParser) readUntil(findIndexFunc func([]byte) int) ([]byte, byte, error) {
   176  	index := findIndexFunc(parser.buf)
   177  	if index >= 0 {
   178  		ret := parser.buf[:index]
   179  		parser.buf = parser.buf[index:]
   180  		parser.pos += int64(index)
   181  		return ret, parser.buf[0], nil
   182  	}
   183  
   184  	// not found in parser.buf, need allocate and loop.
   185  	var buf []byte
   186  	for {
   187  		buf = append(buf, parser.buf...)
   188  		parser.buf = nil
   189  		if err := parser.readBlock(); err != nil || len(parser.buf) == 0 {
   190  			if err == nil {
   191  				err = io.EOF
   192  			}
   193  			parser.pos += int64(len(buf))
   194  			return buf, 0, errors.Trace(err)
   195  		}
   196  		index := findIndexFunc(parser.buf)
   197  		if index >= 0 {
   198  			buf = append(buf, parser.buf[:index]...)
   199  			parser.buf = parser.buf[index:]
   200  			parser.pos += int64(len(buf))
   201  			return buf, parser.buf[0], nil
   202  		}
   203  	}
   204  }
   205  
// readRecord parses the next record from the input and returns its fields.
//
// Field contents are accumulated, still escaped, in parser.recordBuffer and
// split using parser.fieldIndexes. dst is reused for the result when its
// capacity is sufficient. Empty lines are skipped, as are lines that contain
// only whitespace and no separator or quote. If the input ends before any
// field data was read, the read error (io.EOF at end of input) is returned;
// otherwise the end of input terminates the record like a line break.
func (parser *CSVParser) readRecord(dst []string) ([]string, error) {
	parser.recordBuffer = parser.recordBuffer[:0]
	parser.fieldIndexes = parser.fieldIndexes[:0]

	// isEmptyLine stays true until a byte other than a line break has been
	// handled; whitespaceLine stays true until a separator or quote is seen.
	isEmptyLine := true
	whitespaceLine := true

	// processDefault starts an unquoted field whose first byte b has already
	// been consumed, then reads the rest of the field.
	processDefault := func(b byte) error {
		if b == '\\' && parser.escFlavor != backslashEscapeFlavorNone {
			if err := parser.readByteForBackslashEscape(); err != nil {
				return err
			}
		} else {
			parser.recordBuffer = append(parser.recordBuffer, b)
		}
		return parser.readUnquoteField()
	}

	// processQuote is called after a byte equal to quote[0] was consumed.
	processQuote := func(b byte) error {
		return parser.readQuotedField()
	}
	if len(parser.quote) > 1 {
		// Multi-byte quote: the remaining bytes must match too, otherwise
		// the byte is ordinary data starting an unquoted field.
		processQuote = func(b byte) error {
			pb, err := parser.peekBytes(len(parser.quote) - 1)
			if err != nil && errors.Cause(err) != io.EOF {
				return err
			}
			if bytes.Equal(pb, parser.quote[1:]) {
				parser.skipBytes(len(parser.quote) - 1)
				return parser.readQuotedField()
			}
			return processDefault(b)
		}
	}

	// processComma is called after a byte equal to comma[0] was consumed;
	// it closes the current field.
	processComma := func(b byte) error {
		parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer))
		return nil
	}
	if len(parser.comma) > 1 {
		// Multi-byte separator: when the remaining bytes do not match, the
		// byte is either the start of a quote (if both share a first byte)
		// or ordinary data.
		processNotComma := processDefault
		if len(parser.quote) > 0 && parser.comma[0] == parser.quote[0] {
			processNotComma = processQuote
		}
		processComma = func(b byte) error {
			pb, err := parser.peekBytes(len(parser.comma) - 1)
			if err != nil && errors.Cause(err) != io.EOF {
				return err
			}
			if bytes.Equal(pb, parser.comma[1:]) {
				parser.skipBytes(len(parser.comma) - 1)
				parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer))
				return nil
			}
			return processNotComma(b)
		}
	}

outside:
	for {
		firstByte, err := parser.readByte()
		if err != nil {
			if isEmptyLine || errors.Cause(err) != io.EOF {
				return nil, err
			}
			// treat EOF as the same as trailing \n.
			firstByte = '\n'
		}

		switch {
		case firstByte == parser.comma[0]:
			whitespaceLine = false
			if err = processComma(firstByte); err != nil {
				return nil, err
			}

		case len(parser.quote) > 0 && firstByte == parser.quote[0]:
			if err = processQuote(firstByte); err != nil {
				return nil, err
			}
			whitespaceLine = false
		case firstByte == '\r', firstByte == '\n':
			// new line = end of record (ignore empty lines)
			if isEmptyLine {
				continue
			}
			// skip lines only contain whitespaces
			// (err is non-nil only for the synthetic '\n' at EOF; that case
			// must end the record instead of reading again.)
			if err == nil && whitespaceLine && len(bytes.TrimFunc(parser.recordBuffer, unicode.IsSpace)) == 0 {
				parser.recordBuffer = parser.recordBuffer[:0]
				continue
			}
			parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer))
			break outside
		default:
			if err = processDefault(firstByte); err != nil {
				return nil, err
			}
		}
		isEmptyLine = false
	}
	// Create a single string and create slices out of it.
	// This pins the memory of the fields together, but allocates once.
	str := string(parser.recordBuffer) // Convert to string once to batch allocations
	dst = dst[:0]
	if cap(dst) < len(parser.fieldIndexes) {
		dst = make([]string, len(parser.fieldIndexes))
	}
	dst = dst[:len(parser.fieldIndexes)]
	// Field i spans from the end offset of field i-1 to fieldIndexes[i].
	var preIdx int
	for i, idx := range parser.fieldIndexes {
		dst[i] = str[preIdx:idx]
		preIdx = idx
	}

	return dst, nil
}
   323  
   324  func (parser *CSVParser) readByteForBackslashEscape() error {
   325  	b, err := parser.readByte()
   326  	err = parser.replaceEOF(err, errDanglingBackslash)
   327  	if err != nil {
   328  		return err
   329  	}
   330  	parser.recordBuffer = append(parser.recordBuffer, '\\', b)
   331  	return nil
   332  }
   333  
// readQuotedField reads the remainder of a quoted field into recordBuffer.
// The caller (readRecord) has already consumed the opening quote.
//
// A doubled quote inside the field is stored as a single quote sequence, and
// backslash escapes are stored still escaped. The function returns nil once
// the closing quote is followed by a separator, a line break or the end of
// input; the separator or line break itself is left unread. It returns
// errUnterminatedQuotedField if the input ends inside the field and
// errUnexpectedQuoteField if anything else follows the closing quote.
func (parser *CSVParser) readQuotedField() error {
	processDefault := func() error {
		// in all other cases, we've got a syntax error.
		parser.logSyntaxError()
		return errors.AddStack(errUnexpectedQuoteField)
	}

	// processComma validates what follows the closing quote when the next
	// byte equals comma[0]. For a single-byte separator that is already
	// enough; for a multi-byte one the whole separator must be present.
	processComma := func() error { return nil }
	if len(parser.comma) > 1 {
		processComma = func() error {
			b, err := parser.peekBytes(len(parser.comma))
			if err != nil && errors.Cause(err) != io.EOF {
				return err
			}
			if !bytes.Equal(b, parser.comma) {
				return processDefault()
			}
			return nil
		}
	}
	for {
		// Copy everything up to the next byte in the quoted-field stop set,
		// then consume that byte.
		content, terminator, err := parser.readUntil(parser.quoteIndexFunc)
		err = parser.replaceEOF(err, errUnterminatedQuotedField)
		if err != nil {
			return err
		}
		parser.recordBuffer = append(parser.recordBuffer, content...)
		parser.skipBytes(1)
		switch {
		case len(parser.quote) > 0 && terminator == parser.quote[0]:
			if len(parser.quote) > 1 {
				// Multi-byte quote: if the rest does not match, the byte
				// was plain data inside the field.
				b, err := parser.peekBytes(len(parser.quote) - 1)
				if err != nil && err != io.EOF {
					return err
				}
				if !bytes.Equal(b, parser.quote[1:]) {
					parser.recordBuffer = append(parser.recordBuffer, terminator)
					continue
				}
				parser.skipBytes(len(parser.quote) - 1)
			}
			// encountered '"' -> continue if we're seeing '""'.
			b, err := parser.peekBytes(1)
			if err != nil {
				// End of input right after the closing quote ends the field.
				if err == io.EOF {
					err = nil
				}
				return err
			}
			switch b[0] {
			case parser.quote[0]:
				// consume the double quotation mark and continue
				if len(parser.quote) > 1 {
					b, err := parser.peekBytes(len(parser.quote))
					if err != nil && err != io.EOF {
						return err
					}
					if !bytes.Equal(b, parser.quote) {
						// Not a second full quote. If quote and separator
						// share their first byte this may be a separator.
						if parser.quote[0] == parser.comma[0] {
							return processComma()
						} else {
							return processDefault()
						}
					}
				}
				parser.skipBytes(len(parser.quote))
				parser.recordBuffer = append(parser.recordBuffer, parser.quote...)
			case '\r', '\n':
				// end the field if the next is a separator
				return nil
			case parser.comma[0]:
				return processComma()
			default:
				return processDefault()
			}

		case terminator == '\\':
			if err := parser.readByteForBackslashEscape(); err != nil {
				return err
			}
		}
	}
}
   417  
   418  func (parser *CSVParser) readUnquoteField() error {
   419  	addByte := func(b byte) {
   420  		// read the following byte
   421  		parser.recordBuffer = append(parser.recordBuffer, b)
   422  		parser.skipBytes(1)
   423  	}
   424  	parseQuote := func(b byte) error {
   425  		r, err := parser.checkBytes(parser.quote)
   426  		if err != nil {
   427  			return errors.Trace(err)
   428  		}
   429  		if r {
   430  			parser.logSyntaxError()
   431  			return errors.AddStack(errUnexpectedQuoteField)
   432  		}
   433  		addByte(b)
   434  		return nil
   435  	}
   436  
   437  	parserNoComma := func(b byte) error {
   438  		addByte(b)
   439  		return nil
   440  	}
   441  	if len(parser.quote) > 0 && parser.comma[0] == parser.quote[0] {
   442  		parserNoComma = parseQuote
   443  	}
   444  	for {
   445  		content, terminator, err := parser.readUntil(parser.unquoteIndexFunc)
   446  		parser.recordBuffer = append(parser.recordBuffer, content...)
   447  		finished := false
   448  		if err != nil {
   449  			if errors.Cause(err) == io.EOF {
   450  				finished = true
   451  				err = nil
   452  			}
   453  			if err != nil {
   454  				return err
   455  			}
   456  		}
   457  
   458  		switch {
   459  		case terminator == '\r', terminator == '\n', finished:
   460  			return nil
   461  		case terminator == parser.comma[0]:
   462  			r, err := parser.checkBytes(parser.comma)
   463  			if err != nil {
   464  				return errors.Trace(err)
   465  			}
   466  			if r {
   467  				return nil
   468  			}
   469  			if err = parserNoComma(terminator); err != nil {
   470  				return err
   471  			}
   472  		case len(parser.quote) > 0 && terminator == parser.quote[0]:
   473  			r, err := parser.checkBytes(parser.quote)
   474  			if err != nil {
   475  				return errors.Trace(err)
   476  			}
   477  			if r {
   478  				parser.logSyntaxError()
   479  				return errors.AddStack(errUnexpectedQuoteField)
   480  			}
   481  		case terminator == '\\':
   482  			parser.skipBytes(1)
   483  			if err := parser.readByteForBackslashEscape(); err != nil {
   484  				return err
   485  			}
   486  		}
   487  	}
   488  }
   489  
   490  func (parser *CSVParser) checkBytes(b []byte) (bool, error) {
   491  	if len(b) == 1 {
   492  		return true, nil
   493  	}
   494  	pb, err := parser.peekBytes(len(b))
   495  	if err != nil {
   496  		return false, err
   497  	}
   498  	return bytes.Equal(pb, b), nil
   499  }
   500  
   501  func (parser *CSVParser) replaceEOF(err error, replaced error) error {
   502  	if err == nil || errors.Cause(err) != io.EOF {
   503  		return err
   504  	}
   505  	if replaced != nil {
   506  		parser.logSyntaxError()
   507  		replaced = errors.AddStack(replaced)
   508  	}
   509  	return replaced
   510  }
   511  
   512  // ReadRow reads a row from the datafile.
   513  func (parser *CSVParser) ReadRow() error {
   514  	row := &parser.lastRow
   515  	row.RowID++
   516  
   517  	// skip the header first
   518  	if parser.shouldParseHeader {
   519  		err := parser.ReadColumns()
   520  		if err != nil {
   521  			return errors.Trace(err)
   522  		}
   523  		parser.shouldParseHeader = false
   524  	}
   525  
   526  	records, err := parser.readRecord(parser.lastRecord)
   527  	if err != nil {
   528  		return errors.Trace(err)
   529  	}
   530  	parser.lastRecord = records
   531  	// remove the last empty value
   532  	if parser.cfg.TrimLastSep {
   533  		i := len(records) - 1
   534  		if i >= 0 && len(records[i]) == 0 {
   535  			records = records[:i]
   536  		}
   537  	}
   538  
   539  	row.Row = parser.acquireDatumSlice()
   540  	if cap(row.Row) >= len(records) {
   541  		row.Row = row.Row[:len(records)]
   542  	} else {
   543  		row.Row = make([]types.Datum, len(records))
   544  	}
   545  	for i, record := range records {
   546  		unescaped, isNull := parser.unescapeString(record)
   547  		if isNull {
   548  			row.Row[i].SetNull()
   549  		} else {
   550  			row.Row[i].SetString(unescaped, "utf8mb4_bin")
   551  		}
   552  	}
   553  
   554  	return nil
   555  }
   556  
   557  func (parser *CSVParser) ReadColumns() error {
   558  	columns, err := parser.readRecord(nil)
   559  	if err != nil {
   560  		return errors.Trace(err)
   561  	}
   562  	parser.columns = make([]string, 0, len(columns))
   563  	for _, colName := range columns {
   564  		colName, _ = parser.unescapeString(colName)
   565  		parser.columns = append(parser.columns, strings.ToLower(colName))
   566  	}
   567  	return nil
   568  }
   569  
// newLineAsciiSet is the byte set containing '\r' and '\n'.
var newLineAsciiSet = makeByteSet([]byte{'\r', '\n'})

// indexOfNewLine passes b and newLineAsciiSet to IndexAnyByte and returns the
// result; readUntil treats a negative result as "no line break in b".
func indexOfNewLine(b []byte) int {
	return IndexAnyByte(b, &newLineAsciiSet)
}
   575  func (parser *CSVParser) ReadUntilTokNewLine() (int64, error) {
   576  	_, _, err := parser.readUntil(indexOfNewLine)
   577  	if err != nil {
   578  		return 0, err
   579  	}
   580  	parser.skipBytes(1)
   581  	return parser.pos, nil
   582  }