github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/importccl/read_import_pgcopy.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Licensed as a CockroachDB Enterprise file under the Cockroach Community 4 // License (the "License"); you may not use this file except in compliance with 5 // the License. You may obtain a copy of the License at 6 // 7 // https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt 8 9 package importccl 10 11 import ( 12 "bufio" 13 "bytes" 14 "context" 15 "fmt" 16 "io" 17 "strconv" 18 "unicode" 19 20 "github.com/cockroachdb/cockroach/pkg/roachpb" 21 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" 22 "github.com/cockroachdb/cockroach/pkg/sql/row" 23 "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" 24 "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" 25 "github.com/cockroachdb/cockroach/pkg/storage/cloud" 26 "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" 27 "github.com/cockroachdb/errors" 28 ) 29 30 // defaultScanBuffer is the default max row size of the PGCOPY and PGDUMP 31 // scanner. 32 const defaultScanBuffer = 1024 * 1024 * 4 33 34 type pgCopyReader struct { 35 conv row.DatumRowConverter 36 opts roachpb.PgCopyOptions 37 } 38 39 var _ inputConverter = &pgCopyReader{} 40 41 func newPgCopyReader( 42 ctx context.Context, 43 kvCh chan row.KVBatch, 44 opts roachpb.PgCopyOptions, 45 tableDesc *sqlbase.TableDescriptor, 46 evalCtx *tree.EvalContext, 47 ) (*pgCopyReader, error) { 48 conv, err := row.NewDatumRowConverter(ctx, tableDesc, nil /* targetColNames */, evalCtx, kvCh) 49 if err != nil { 50 return nil, err 51 } 52 return &pgCopyReader{ 53 conv: *conv, 54 opts: opts, 55 }, nil 56 } 57 58 func (d *pgCopyReader) start(ctx ctxgroup.Group) { 59 } 60 61 func (d *pgCopyReader) readFiles( 62 ctx context.Context, 63 dataFiles map[int32]string, 64 resumePos map[int32]int64, 65 format roachpb.IOFileFormat, 66 makeExternalStorage cloud.ExternalStorageFactory, 67 ) error { 68 return readInputFiles(ctx, dataFiles, resumePos, format, d.readFile, makeExternalStorage) 69 } 70 71 type postgreStreamCopy struct { 72 s *bufio.Scanner 73 delimiter rune 74 null string 75 } 76 77 // newPostgreStreamCopy streams COPY data rows from s with specified delimiter 78 // and null string. If readEmptyLine is true, the first line is expected to 79 // be empty as the newline following the previous COPY statement. 80 // 81 // s must be an existing bufio.Scanner configured to split on 82 // bufio.ScanLines. It must be done outside of this function because we cannot 83 // change the split func mid scan. And in order to use the same scanner as 84 // the SQL parser, which can change the splitter using a closure. 85 func newPostgreStreamCopy(s *bufio.Scanner, delimiter rune, null string) *postgreStreamCopy { 86 return &postgreStreamCopy{ 87 s: s, 88 delimiter: delimiter, 89 null: null, 90 } 91 } 92 93 var errCopyDone = errors.New("COPY done") 94 95 // Next returns the next row. io.EOF is the returned error upon EOF. The 96 // errCopyDone error is returned if "\." is encountered, indicating the end 97 // of the COPY section. 98 func (p *postgreStreamCopy) Next() (copyData, error) { 99 var row copyData 100 // the current field being read. 101 var field []byte 102 103 addField := func() error { 104 // COPY does its backslash removal after the entire field has been extracted 105 // so it can compare to the NULL setting earlier. 106 if string(field) == p.null { 107 row = append(row, nil) 108 } else { 109 // Use the same backing store since we are guaranteed to only remove 110 // characters. 111 nf := field[:0] 112 for i := 0; i < len(field); i++ { 113 if field[i] == '\\' { 114 i++ 115 if len(field) <= i { 116 return errors.New("unmatched escape") 117 } 118 // See https://www.postgresql.org/docs/current/static/sql-copy.html 119 c := field[i] 120 switch c { 121 case 'b': 122 nf = append(nf, '\b') 123 case 'f': 124 nf = append(nf, '\f') 125 case 'n': 126 nf = append(nf, '\n') 127 case 'r': 128 nf = append(nf, '\r') 129 case 't': 130 nf = append(nf, '\t') 131 case 'v': 132 nf = append(nf, '\v') 133 default: 134 if c == 'x' || (c >= '0' && c <= '9') { 135 // Handle \xNN and \NNN hex and octal escapes.) 136 if len(field) <= i+2 { 137 return errors.Errorf("unsupported escape sequence: \\%s", field[i:]) 138 } 139 base := 8 140 idx := 0 141 if c == 'x' { 142 base = 16 143 idx = 1 144 } 145 v, err := strconv.ParseInt(string(field[i+idx:i+3]), base, 8) 146 if err != nil { 147 return err 148 } 149 i += 2 150 nf = append(nf, byte(v)) 151 } else { 152 nf = append(nf, string(c)...) 153 } 154 } 155 } else { 156 nf = append(nf, field[i]) 157 } 158 } 159 ns := string(nf) 160 row = append(row, &ns) 161 } 162 field = field[:0] 163 return nil 164 } 165 166 // Attempt to read an entire line. 167 scanned := p.s.Scan() 168 if err := p.s.Err(); err != nil { 169 if errors.Is(err, bufio.ErrTooLong) { 170 err = errors.New("line too long") 171 } 172 return nil, err 173 } 174 if !scanned { 175 return nil, io.EOF 176 } 177 // Check for the copy done marker. 178 if bytes.Equal(p.s.Bytes(), []byte(`\.`)) { 179 return nil, errCopyDone 180 } 181 reader := bytes.NewReader(p.s.Bytes()) 182 183 var sawBackslash bool 184 // Start by finding field delimiters. 185 for { 186 c, w, err := reader.ReadRune() 187 if err == io.EOF { 188 break 189 } 190 if err != nil { 191 return nil, err 192 } 193 if c == unicode.ReplacementChar && w == 1 { 194 return nil, errors.New("error decoding UTF-8 Rune") 195 } 196 197 // We only care about backslashes here if they are followed by a field 198 // delimiter. Otherwise we pass them through. They will be escaped later on. 199 if sawBackslash { 200 sawBackslash = false 201 if c == p.delimiter { 202 field = append(field, string(p.delimiter)...) 203 } else { 204 field = append(field, '\\') 205 field = append(field, string(c)...) 206 } 207 continue 208 } else if c == '\\' { 209 sawBackslash = true 210 continue 211 } 212 213 const rowSeparator = '\n' 214 // Are we done with the field? 215 if c == p.delimiter || c == rowSeparator { 216 if err := addField(); err != nil { 217 return nil, err 218 } 219 } else { 220 field = append(field, string(c)...) 221 } 222 } 223 if sawBackslash { 224 return nil, errors.Errorf("unmatched escape") 225 } 226 // We always want to call this because there's at least 1 field per row. If 227 // the row is empty we should return a row with a single, empty field. 228 if err := addField(); err != nil { 229 return nil, err 230 } 231 return row, nil 232 } 233 234 const ( 235 copyDefaultDelimiter = '\t' 236 copyDefaultNull = `\N` 237 ) 238 239 type copyData []*string 240 241 func (c copyData) String() string { 242 var buf bytes.Buffer 243 for i, s := range c { 244 if i > 0 { 245 buf.WriteByte(copyDefaultDelimiter) 246 } 247 if s == nil { 248 buf.WriteString(copyDefaultNull) 249 } else { 250 // TODO(mjibson): this isn't correct COPY syntax, but it's only used in tests. 251 fmt.Fprintf(&buf, "%q", *s) 252 } 253 } 254 return buf.String() 255 } 256 257 func (d *pgCopyReader) readFile( 258 ctx context.Context, input *fileReader, inputIdx int32, resumePos int64, rejected chan string, 259 ) error { 260 s := bufio.NewScanner(input) 261 s.Split(bufio.ScanLines) 262 s.Buffer(nil, int(d.opts.MaxRowSize)) 263 c := newPostgreStreamCopy( 264 s, 265 d.opts.Delimiter, 266 d.opts.Null, 267 ) 268 d.conv.KvBatch.Source = inputIdx 269 d.conv.FractionFn = input.ReadFraction 270 count := int64(1) 271 d.conv.CompletedRowFn = func() int64 { 272 return count 273 } 274 275 for ; ; count++ { 276 row, err := c.Next() 277 if err == io.EOF { 278 break 279 } 280 if err != nil { 281 return wrapRowErr(err, "", count, pgcode.Uncategorized, "") 282 } 283 284 if count <= resumePos { 285 continue 286 } 287 288 if len(row) != len(d.conv.VisibleColTypes) { 289 return makeRowErr("", count, pgcode.Syntax, 290 "expected %d values, got %d", len(d.conv.VisibleColTypes), len(row)) 291 } 292 for i, s := range row { 293 if s == nil { 294 d.conv.Datums[i] = tree.DNull 295 } else { 296 d.conv.Datums[i], err = sqlbase.ParseDatumStringAs(d.conv.VisibleColTypes[i], *s, d.conv.EvalCtx) 297 if err != nil { 298 col := d.conv.VisibleCols[i] 299 return wrapRowErr(err, "", count, pgcode.Syntax, 300 "parse %q as %s", col.Name, col.Type.SQLString()) 301 } 302 } 303 } 304 305 if err := d.conv.Row(ctx, inputIdx, count); err != nil { 306 return wrapRowErr(err, "", count, pgcode.Uncategorized, "") 307 } 308 } 309 310 return d.conv.SendBatch(ctx) 311 }