github.com/pingcap/tidb-lightning@v5.0.0-rc.0.20210428090220-84b649866577+incompatible/lightning/mydump/csv_parser.go (about) 1 // Copyright 2020 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package mydump 15 16 import ( 17 "bytes" 18 "io" 19 "strings" 20 "unicode" 21 22 "github.com/pingcap/br/pkg/utils" 23 24 "github.com/pingcap/errors" 25 "github.com/pingcap/tidb/types" 26 27 "github.com/pingcap/tidb-lightning/lightning/config" 28 "github.com/pingcap/tidb-lightning/lightning/worker" 29 ) 30 31 var ( 32 errUnterminatedQuotedField = errors.NewNoStackError("syntax error: unterminated quoted field") 33 errDanglingBackslash = errors.NewNoStackError("syntax error: no character after backslash") 34 errUnexpectedQuoteField = errors.NewNoStackError("syntax error: cannot have consecutive fields without separator") 35 ) 36 37 // CSVParser is basically a copy of encoding/csv, but special-cased for MySQL-like input. 38 type CSVParser struct { 39 blockParser 40 cfg *config.CSVConfig 41 escFlavor backslashEscapeFlavor 42 43 comma []byte 44 quote []byte 45 quoteIndexFunc func([]byte) int 46 unquoteIndexFunc func([]byte) int 47 48 // recordBuffer holds the unescaped fields, one after another. 49 // The fields can be accessed by using the indexes in fieldIndexes. 50 // E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de` 51 // and fieldIndexes will contain the indexes [1, 2, 5, 6]. 52 recordBuffer []byte 53 54 // fieldIndexes is an index of fields inside recordBuffer. 55 // The i'th field ends at offset fieldIndexes[i] in recordBuffer. 56 fieldIndexes []int 57 58 lastRecord []string 59 60 // if set to true, csv parser will treat the first non-empty line as header line 61 shouldParseHeader bool 62 } 63 64 func NewCSVParser( 65 cfg *config.CSVConfig, 66 reader ReadSeekCloser, 67 blockBufSize int64, 68 ioWorkers *worker.Pool, 69 shouldParseHeader bool, 70 ) *CSVParser { 71 escFlavor := backslashEscapeFlavorNone 72 var quoteStopSet []byte 73 unquoteStopSet := []byte{'\r', '\n', cfg.Separator[0]} 74 if len(cfg.Delimiter) > 0 { 75 quoteStopSet = []byte{cfg.Delimiter[0]} 76 unquoteStopSet = append(unquoteStopSet, cfg.Delimiter[0]) 77 } 78 if cfg.BackslashEscape { 79 escFlavor = backslashEscapeFlavorMySQL 80 quoteStopSet = append(quoteStopSet, '\\') 81 unquoteStopSet = append(unquoteStopSet, '\\') 82 // we need special treatment of the NULL value \N, used by MySQL. 83 if !cfg.NotNull && cfg.Null == `\N` { 84 escFlavor = backslashEscapeFlavorMySQLWithNull 85 } 86 } 87 88 return &CSVParser{ 89 blockParser: makeBlockParser(reader, blockBufSize, ioWorkers), 90 cfg: cfg, 91 comma: []byte(cfg.Separator), 92 quote: []byte(cfg.Delimiter), 93 escFlavor: escFlavor, 94 quoteIndexFunc: makeBytesIndexFunc(quoteStopSet), 95 unquoteIndexFunc: makeBytesIndexFunc(unquoteStopSet), 96 shouldParseHeader: shouldParseHeader, 97 } 98 } 99 100 func makeBytesIndexFunc(chars []byte) func([]byte) int { 101 // chars are guaranteed to be ascii str, so this call will always success 102 as := makeByteSet(chars) 103 return func(s []byte) int { 104 return IndexAnyByte(s, &as) 105 } 106 } 107 108 func (parser *CSVParser) unescapeString(input string) (unescaped string, isNull bool) { 109 if parser.escFlavor == backslashEscapeFlavorMySQLWithNull && input == `\N` { 110 return input, true 111 } 112 unescaped = unescape(input, "", parser.escFlavor) 113 isNull = parser.escFlavor != backslashEscapeFlavorMySQLWithNull && 114 !parser.cfg.NotNull && 115 unescaped == parser.cfg.Null 116 return 117 } 118 119 func (parser *CSVParser) readByte() (byte, error) { 120 if len(parser.buf) == 0 { 121 if err := parser.readBlock(); err != nil { 122 return 0, err 123 } 124 } 125 if len(parser.buf) == 0 { 126 return 0, io.EOF 127 } 128 b := parser.buf[0] 129 parser.buf = parser.buf[1:] 130 parser.pos++ 131 return b, nil 132 } 133 134 func (parser *CSVParser) readBytes(buf []byte) (int, error) { 135 cnt := 0 136 for cnt < len(buf) { 137 if len(parser.buf) == 0 { 138 if err := parser.readBlock(); err != nil { 139 return cnt, err 140 } 141 } 142 if len(parser.buf) == 0 { 143 parser.pos += int64(cnt) 144 return cnt, io.EOF 145 } 146 readCnt := utils.MinInt(len(buf)-cnt, len(parser.buf)) 147 copy(buf[cnt:], parser.buf[:readCnt]) 148 parser.buf = parser.buf[readCnt:] 149 cnt += readCnt 150 } 151 parser.pos += int64(cnt) 152 return cnt, nil 153 } 154 155 func (parser *CSVParser) peekBytes(cnt int) ([]byte, error) { 156 if len(parser.buf) < cnt { 157 if err := parser.readBlock(); err != nil { 158 return nil, err 159 } 160 } 161 if len(parser.buf) == 0 { 162 return nil, io.EOF 163 } 164 cnt = utils.MinInt(cnt, len(parser.buf)) 165 return parser.buf[:cnt], nil 166 } 167 168 func (parser *CSVParser) skipBytes(n int) { 169 parser.buf = parser.buf[n:] 170 parser.pos += int64(n) 171 } 172 173 // readUntil reads the buffer until any character from the `chars` set is found. 174 // that character is excluded from the final buffer. 175 func (parser *CSVParser) readUntil(findIndexFunc func([]byte) int) ([]byte, byte, error) { 176 index := findIndexFunc(parser.buf) 177 if index >= 0 { 178 ret := parser.buf[:index] 179 parser.buf = parser.buf[index:] 180 parser.pos += int64(index) 181 return ret, parser.buf[0], nil 182 } 183 184 // not found in parser.buf, need allocate and loop. 185 var buf []byte 186 for { 187 buf = append(buf, parser.buf...) 188 parser.buf = nil 189 if err := parser.readBlock(); err != nil || len(parser.buf) == 0 { 190 if err == nil { 191 err = io.EOF 192 } 193 parser.pos += int64(len(buf)) 194 return buf, 0, errors.Trace(err) 195 } 196 index := findIndexFunc(parser.buf) 197 if index >= 0 { 198 buf = append(buf, parser.buf[:index]...) 199 parser.buf = parser.buf[index:] 200 parser.pos += int64(len(buf)) 201 return buf, parser.buf[0], nil 202 } 203 } 204 } 205 206 func (parser *CSVParser) readRecord(dst []string) ([]string, error) { 207 parser.recordBuffer = parser.recordBuffer[:0] 208 parser.fieldIndexes = parser.fieldIndexes[:0] 209 210 isEmptyLine := true 211 whitespaceLine := true 212 213 processDefault := func(b byte) error { 214 if b == '\\' && parser.escFlavor != backslashEscapeFlavorNone { 215 if err := parser.readByteForBackslashEscape(); err != nil { 216 return err 217 } 218 } else { 219 parser.recordBuffer = append(parser.recordBuffer, b) 220 } 221 return parser.readUnquoteField() 222 } 223 224 processQuote := func(b byte) error { 225 return parser.readQuotedField() 226 } 227 if len(parser.quote) > 1 { 228 processQuote = func(b byte) error { 229 pb, err := parser.peekBytes(len(parser.quote) - 1) 230 if err != nil && errors.Cause(err) != io.EOF { 231 return err 232 } 233 if bytes.Equal(pb, parser.quote[1:]) { 234 parser.skipBytes(len(parser.quote) - 1) 235 return parser.readQuotedField() 236 } 237 return processDefault(b) 238 } 239 } 240 241 processComma := func(b byte) error { 242 parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer)) 243 return nil 244 } 245 if len(parser.comma) > 1 { 246 processNotComma := processDefault 247 if len(parser.quote) > 0 && parser.comma[0] == parser.quote[0] { 248 processNotComma = processQuote 249 } 250 processComma = func(b byte) error { 251 pb, err := parser.peekBytes(len(parser.comma) - 1) 252 if err != nil && errors.Cause(err) != io.EOF { 253 return err 254 } 255 if bytes.Equal(pb, parser.comma[1:]) { 256 parser.skipBytes(len(parser.comma) - 1) 257 parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer)) 258 return nil 259 } 260 return processNotComma(b) 261 } 262 } 263 264 outside: 265 for { 266 firstByte, err := parser.readByte() 267 if err != nil { 268 if isEmptyLine || errors.Cause(err) != io.EOF { 269 return nil, err 270 } 271 // treat EOF as the same as trailing \n. 272 firstByte = '\n' 273 } 274 275 switch { 276 case firstByte == parser.comma[0]: 277 whitespaceLine = false 278 if err = processComma(firstByte); err != nil { 279 return nil, err 280 } 281 282 case len(parser.quote) > 0 && firstByte == parser.quote[0]: 283 if err = processQuote(firstByte); err != nil { 284 return nil, err 285 } 286 whitespaceLine = false 287 case firstByte == '\r', firstByte == '\n': 288 // new line = end of record (ignore empty lines) 289 if isEmptyLine { 290 continue 291 } 292 // skip lines only contain whitespaces 293 if err == nil && whitespaceLine && len(bytes.TrimFunc(parser.recordBuffer, unicode.IsSpace)) == 0 { 294 parser.recordBuffer = parser.recordBuffer[:0] 295 continue 296 } 297 parser.fieldIndexes = append(parser.fieldIndexes, len(parser.recordBuffer)) 298 break outside 299 default: 300 if err = processDefault(firstByte); err != nil { 301 return nil, err 302 } 303 } 304 isEmptyLine = false 305 } 306 // Create a single string and create slices out of it. 307 // This pins the memory of the fields together, but allocates once. 308 str := string(parser.recordBuffer) // Convert to string once to batch allocations 309 dst = dst[:0] 310 if cap(dst) < len(parser.fieldIndexes) { 311 dst = make([]string, len(parser.fieldIndexes)) 312 } 313 dst = dst[:len(parser.fieldIndexes)] 314 var preIdx int 315 for i, idx := range parser.fieldIndexes { 316 dst[i] = str[preIdx:idx] 317 preIdx = idx 318 } 319 320 // Check or update the expected fields per record. 321 return dst, nil 322 } 323 324 func (parser *CSVParser) readByteForBackslashEscape() error { 325 b, err := parser.readByte() 326 err = parser.replaceEOF(err, errDanglingBackslash) 327 if err != nil { 328 return err 329 } 330 parser.recordBuffer = append(parser.recordBuffer, '\\', b) 331 return nil 332 } 333 334 func (parser *CSVParser) readQuotedField() error { 335 processDefault := func() error { 336 // in all other cases, we've got a syntax error. 337 parser.logSyntaxError() 338 return errors.AddStack(errUnexpectedQuoteField) 339 } 340 341 processComma := func() error { return nil } 342 if len(parser.comma) > 1 { 343 processComma = func() error { 344 b, err := parser.peekBytes(len(parser.comma)) 345 if err != nil && errors.Cause(err) != io.EOF { 346 return err 347 } 348 if !bytes.Equal(b, parser.comma) { 349 return processDefault() 350 } 351 return nil 352 } 353 } 354 for { 355 content, terminator, err := parser.readUntil(parser.quoteIndexFunc) 356 err = parser.replaceEOF(err, errUnterminatedQuotedField) 357 if err != nil { 358 return err 359 } 360 parser.recordBuffer = append(parser.recordBuffer, content...) 361 parser.skipBytes(1) 362 switch { 363 case len(parser.quote) > 0 && terminator == parser.quote[0]: 364 if len(parser.quote) > 1 { 365 b, err := parser.peekBytes(len(parser.quote) - 1) 366 if err != nil && err != io.EOF { 367 return err 368 } 369 if !bytes.Equal(b, parser.quote[1:]) { 370 parser.recordBuffer = append(parser.recordBuffer, terminator) 371 continue 372 } 373 parser.skipBytes(len(parser.quote) - 1) 374 } 375 // encountered '"' -> continue if we're seeing '""'. 376 b, err := parser.peekBytes(1) 377 if err != nil { 378 if err == io.EOF { 379 err = nil 380 } 381 return err 382 } 383 switch b[0] { 384 case parser.quote[0]: 385 // consume the double quotation mark and continue 386 if len(parser.quote) > 1 { 387 b, err := parser.peekBytes(len(parser.quote)) 388 if err != nil && err != io.EOF { 389 return err 390 } 391 if !bytes.Equal(b, parser.quote) { 392 if parser.quote[0] == parser.comma[0] { 393 return processComma() 394 } else { 395 return processDefault() 396 } 397 } 398 } 399 parser.skipBytes(len(parser.quote)) 400 parser.recordBuffer = append(parser.recordBuffer, parser.quote...) 401 case '\r', '\n': 402 // end the field if the next is a separator 403 return nil 404 case parser.comma[0]: 405 return processComma() 406 default: 407 return processDefault() 408 } 409 410 case terminator == '\\': 411 if err := parser.readByteForBackslashEscape(); err != nil { 412 return err 413 } 414 } 415 } 416 } 417 418 func (parser *CSVParser) readUnquoteField() error { 419 addByte := func(b byte) { 420 // read the following byte 421 parser.recordBuffer = append(parser.recordBuffer, b) 422 parser.skipBytes(1) 423 } 424 parseQuote := func(b byte) error { 425 r, err := parser.checkBytes(parser.quote) 426 if err != nil { 427 return errors.Trace(err) 428 } 429 if r { 430 parser.logSyntaxError() 431 return errors.AddStack(errUnexpectedQuoteField) 432 } 433 addByte(b) 434 return nil 435 } 436 437 parserNoComma := func(b byte) error { 438 addByte(b) 439 return nil 440 } 441 if len(parser.quote) > 0 && parser.comma[0] == parser.quote[0] { 442 parserNoComma = parseQuote 443 } 444 for { 445 content, terminator, err := parser.readUntil(parser.unquoteIndexFunc) 446 parser.recordBuffer = append(parser.recordBuffer, content...) 447 finished := false 448 if err != nil { 449 if errors.Cause(err) == io.EOF { 450 finished = true 451 err = nil 452 } 453 if err != nil { 454 return err 455 } 456 } 457 458 switch { 459 case terminator == '\r', terminator == '\n', finished: 460 return nil 461 case terminator == parser.comma[0]: 462 r, err := parser.checkBytes(parser.comma) 463 if err != nil { 464 return errors.Trace(err) 465 } 466 if r { 467 return nil 468 } 469 if err = parserNoComma(terminator); err != nil { 470 return err 471 } 472 case len(parser.quote) > 0 && terminator == parser.quote[0]: 473 r, err := parser.checkBytes(parser.quote) 474 if err != nil { 475 return errors.Trace(err) 476 } 477 if r { 478 parser.logSyntaxError() 479 return errors.AddStack(errUnexpectedQuoteField) 480 } 481 case terminator == '\\': 482 parser.skipBytes(1) 483 if err := parser.readByteForBackslashEscape(); err != nil { 484 return err 485 } 486 } 487 } 488 } 489 490 func (parser *CSVParser) checkBytes(b []byte) (bool, error) { 491 if len(b) == 1 { 492 return true, nil 493 } 494 pb, err := parser.peekBytes(len(b)) 495 if err != nil { 496 return false, err 497 } 498 return bytes.Equal(pb, b), nil 499 } 500 501 func (parser *CSVParser) replaceEOF(err error, replaced error) error { 502 if err == nil || errors.Cause(err) != io.EOF { 503 return err 504 } 505 if replaced != nil { 506 parser.logSyntaxError() 507 replaced = errors.AddStack(replaced) 508 } 509 return replaced 510 } 511 512 // ReadRow reads a row from the datafile. 513 func (parser *CSVParser) ReadRow() error { 514 row := &parser.lastRow 515 row.RowID++ 516 517 // skip the header first 518 if parser.shouldParseHeader { 519 err := parser.ReadColumns() 520 if err != nil { 521 return errors.Trace(err) 522 } 523 parser.shouldParseHeader = false 524 } 525 526 records, err := parser.readRecord(parser.lastRecord) 527 if err != nil { 528 return errors.Trace(err) 529 } 530 parser.lastRecord = records 531 // remove the last empty value 532 if parser.cfg.TrimLastSep { 533 i := len(records) - 1 534 if i >= 0 && len(records[i]) == 0 { 535 records = records[:i] 536 } 537 } 538 539 row.Row = parser.acquireDatumSlice() 540 if cap(row.Row) >= len(records) { 541 row.Row = row.Row[:len(records)] 542 } else { 543 row.Row = make([]types.Datum, len(records)) 544 } 545 for i, record := range records { 546 unescaped, isNull := parser.unescapeString(record) 547 if isNull { 548 row.Row[i].SetNull() 549 } else { 550 row.Row[i].SetString(unescaped, "utf8mb4_bin") 551 } 552 } 553 554 return nil 555 } 556 557 func (parser *CSVParser) ReadColumns() error { 558 columns, err := parser.readRecord(nil) 559 if err != nil { 560 return errors.Trace(err) 561 } 562 parser.columns = make([]string, 0, len(columns)) 563 for _, colName := range columns { 564 colName, _ = parser.unescapeString(colName) 565 parser.columns = append(parser.columns, strings.ToLower(colName)) 566 } 567 return nil 568 } 569 570 var newLineAsciiSet = makeByteSet([]byte{'\r', '\n'}) 571 572 func indexOfNewLine(b []byte) int { 573 return IndexAnyByte(b, &newLineAsciiSet) 574 } 575 func (parser *CSVParser) ReadUntilTokNewLine() (int64, error) { 576 _, _, err := parser.readUntil(indexOfNewLine) 577 if err != nil { 578 return 0, err 579 } 580 parser.skipBytes(1) 581 return parser.pos, nil 582 }