github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cr2pg/sqlstream/stream.go (about) 1 // Copyright 2019 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 // Package sqlstream streams an io.Reader into SQL statements. 12 package sqlstream 13 14 import ( 15 "bufio" 16 "io" 17 18 "github.com/cockroachdb/cockroach/pkg/sql/parser" 19 // Include this because the parser assumes builtin functions exist. 20 _ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" 21 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 22 "github.com/cockroachdb/errors" 23 ) 24 25 // Modified from importccl/read_import_pgdump.go. 26 27 // Stream streams an io.Reader into tree.Statements. 28 type Stream struct { 29 scan *bufio.Scanner 30 } 31 32 // NewStream returns a new Stream to read from r. 33 func NewStream(r io.Reader) *Stream { 34 const defaultMax = 1024 * 1024 * 32 35 s := bufio.NewScanner(r) 36 s.Buffer(make([]byte, 0, defaultMax), defaultMax) 37 p := &Stream{scan: s} 38 s.Split(splitSQLSemicolon) 39 return p 40 } 41 42 // splitSQLSemicolon is a bufio.SplitFunc that splits on SQL semicolon tokens. 43 func splitSQLSemicolon(data []byte, atEOF bool) (advance int, token []byte, err error) { 44 if atEOF && len(data) == 0 { 45 return 0, nil, nil 46 } 47 48 if pos, ok := parser.SplitFirstStatement(string(data)); ok { 49 return pos, data[:pos], nil 50 } 51 // If we're at EOF, we have a final, non-terminated line. Return it. 52 if atEOF { 53 return len(data), data, nil 54 } 55 // Request more data. 56 return 0, nil, nil 57 } 58 59 // Next returns the next statement, or io.EOF if complete. 60 func (s *Stream) Next() (tree.Statement, error) { 61 for s.scan.Scan() { 62 t := s.scan.Text() 63 stmts, err := parser.Parse(t) 64 if err != nil { 65 return nil, err 66 } 67 switch len(stmts) { 68 case 0: 69 // Got whitespace or comments; try again. 70 case 1: 71 return stmts[0].AST, nil 72 default: 73 return nil, errors.Errorf("unexpected: got %d statements", len(stmts)) 74 } 75 } 76 if err := s.scan.Err(); err != nil { 77 if errors.Is(err, bufio.ErrTooLong) { 78 err = errors.HandledWithMessage(err, "line too long") 79 } 80 return nil, err 81 } 82 return nil, io.EOF 83 }