github.com/jackc/pgx/v5@v5.5.5/copy_from.go (about) 1 package pgx 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 9 "github.com/jackc/pgx/v5/internal/pgio" 10 "github.com/jackc/pgx/v5/pgconn" 11 ) 12 13 // CopyFromRows returns a CopyFromSource interface over the provided rows slice 14 // making it usable by *Conn.CopyFrom. 15 func CopyFromRows(rows [][]any) CopyFromSource { 16 return ©FromRows{rows: rows, idx: -1} 17 } 18 19 type copyFromRows struct { 20 rows [][]any 21 idx int 22 } 23 24 func (ctr *copyFromRows) Next() bool { 25 ctr.idx++ 26 return ctr.idx < len(ctr.rows) 27 } 28 29 func (ctr *copyFromRows) Values() ([]any, error) { 30 return ctr.rows[ctr.idx], nil 31 } 32 33 func (ctr *copyFromRows) Err() error { 34 return nil 35 } 36 37 // CopyFromSlice returns a CopyFromSource interface over a dynamic func 38 // making it usable by *Conn.CopyFrom. 39 func CopyFromSlice(length int, next func(int) ([]any, error)) CopyFromSource { 40 return ©FromSlice{next: next, idx: -1, len: length} 41 } 42 43 type copyFromSlice struct { 44 next func(int) ([]any, error) 45 idx int 46 len int 47 err error 48 } 49 50 func (cts *copyFromSlice) Next() bool { 51 cts.idx++ 52 return cts.idx < cts.len 53 } 54 55 func (cts *copyFromSlice) Values() ([]any, error) { 56 values, err := cts.next(cts.idx) 57 if err != nil { 58 cts.err = err 59 } 60 return values, err 61 } 62 63 func (cts *copyFromSlice) Err() error { 64 return cts.err 65 } 66 67 // CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values. 68 // nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil, 69 // or it returns an error. If nxtf returns an error, the copy is aborted. 70 func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource { 71 return ©FromFunc{next: nxtf} 72 } 73 74 type copyFromFunc struct { 75 next func() ([]any, error) 76 valueRow []any 77 err error 78 } 79 80 func (g *copyFromFunc) Next() bool { 81 g.valueRow, g.err = g.next() 82 // only return true if valueRow exists and no error 83 return g.valueRow != nil && g.err == nil 84 } 85 86 func (g *copyFromFunc) Values() ([]any, error) { 87 return g.valueRow, g.err 88 } 89 90 func (g *copyFromFunc) Err() error { 91 return g.err 92 } 93 94 // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. 95 type CopyFromSource interface { 96 // Next returns true if there is another row and makes the next row data 97 // available to Values(). When there are no more rows available or an error 98 // has occurred it returns false. 99 Next() bool 100 101 // Values returns the values for the current row. 102 Values() ([]any, error) 103 104 // Err returns any error that has been encountered by the CopyFromSource. If 105 // this is not nil *Conn.CopyFrom will abort the copy. 106 Err() error 107 } 108 109 type copyFrom struct { 110 conn *Conn 111 tableName Identifier 112 columnNames []string 113 rowSrc CopyFromSource 114 readerErrChan chan error 115 mode QueryExecMode 116 } 117 118 func (ct *copyFrom) run(ctx context.Context) (int64, error) { 119 if ct.conn.copyFromTracer != nil { 120 ctx = ct.conn.copyFromTracer.TraceCopyFromStart(ctx, ct.conn, TraceCopyFromStartData{ 121 TableName: ct.tableName, 122 ColumnNames: ct.columnNames, 123 }) 124 } 125 126 quotedTableName := ct.tableName.Sanitize() 127 cbuf := &bytes.Buffer{} 128 for i, cn := range ct.columnNames { 129 if i != 0 { 130 cbuf.WriteString(", ") 131 } 132 cbuf.WriteString(quoteIdentifier(cn)) 133 } 134 quotedColumnNames := cbuf.String() 135 136 var sd *pgconn.StatementDescription 137 switch ct.mode { 138 case QueryExecModeExec, QueryExecModeSimpleProtocol: 139 // These modes don't support the binary format. Before the inclusion of the 140 // QueryExecModes, Conn.Prepare was called on every COPY operation to get 141 // the OIDs. These prepared statements were not cached. 142 // 143 // Since that's the same behavior provided by QueryExecModeDescribeExec, 144 // we'll default to that mode. 145 ct.mode = QueryExecModeDescribeExec 146 fallthrough 147 case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec: 148 var err error 149 sd, err = ct.conn.getStatementDescription( 150 ctx, 151 ct.mode, 152 fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName), 153 ) 154 if err != nil { 155 return 0, fmt.Errorf("statement description failed: %w", err) 156 } 157 default: 158 return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode) 159 } 160 161 r, w := io.Pipe() 162 doneChan := make(chan struct{}) 163 164 go func() { 165 defer close(doneChan) 166 167 // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. 168 buf := ct.conn.wbuf 169 170 buf = append(buf, "PGCOPY\n\377\r\n\000"...) 171 buf = pgio.AppendInt32(buf, 0) 172 buf = pgio.AppendInt32(buf, 0) 173 174 moreRows := true 175 for moreRows { 176 var err error 177 moreRows, buf, err = ct.buildCopyBuf(buf, sd) 178 if err != nil { 179 w.CloseWithError(err) 180 return 181 } 182 183 if ct.rowSrc.Err() != nil { 184 w.CloseWithError(ct.rowSrc.Err()) 185 return 186 } 187 188 if len(buf) > 0 { 189 _, err = w.Write(buf) 190 if err != nil { 191 w.Close() 192 return 193 } 194 } 195 196 buf = buf[:0] 197 } 198 199 w.Close() 200 }() 201 202 commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) 203 204 r.Close() 205 <-doneChan 206 207 if ct.conn.copyFromTracer != nil { 208 ct.conn.copyFromTracer.TraceCopyFromEnd(ctx, ct.conn, TraceCopyFromEndData{ 209 CommandTag: commandTag, 210 Err: err, 211 }) 212 } 213 214 return commandTag.RowsAffected(), err 215 } 216 217 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { 218 const sendBufSize = 65536 - 5 // The packet has a 5-byte header 219 lastBufLen := 0 220 largestRowLen := 0 221 222 for ct.rowSrc.Next() { 223 lastBufLen = len(buf) 224 225 values, err := ct.rowSrc.Values() 226 if err != nil { 227 return false, nil, err 228 } 229 if len(values) != len(ct.columnNames) { 230 return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) 231 } 232 233 buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) 234 for i, val := range values { 235 buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) 236 if err != nil { 237 return false, nil, err 238 } 239 } 240 241 rowLen := len(buf) - lastBufLen 242 if rowLen > largestRowLen { 243 largestRowLen = rowLen 244 } 245 246 // Try not to overflow size of the buffer PgConn.CopyFrom will be reading into. If that happens then the nature of 247 // io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as 65531, 13, 65531 248 // 13, 65531, 13, 65531, 13. 249 if len(buf) > sendBufSize-largestRowLen { 250 return true, buf, nil 251 } 252 } 253 254 return false, buf, nil 255 } 256 257 // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and 258 // an error. 259 // 260 // CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered 261 // for the type of each column. Almost all types implemented by pgx support the binary format. 262 // 263 // Even though enum types appear to be strings they still must be registered to use with CopyFrom. This can be done with 264 // Conn.LoadType and pgtype.Map.RegisterType. 265 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { 266 ct := ©From{ 267 conn: c, 268 tableName: tableName, 269 columnNames: columnNames, 270 rowSrc: rowSrc, 271 readerErrChan: make(chan error), 272 mode: c.config.DefaultQueryExecMode, 273 } 274 275 return ct.run(ctx) 276 }