github.com/jackc/pgx/v5@v5.5.5/pgtype/composite.go (about) 1 package pgtype 2 3 import ( 4 "database/sql/driver" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 "strings" 9 10 "github.com/jackc/pgx/v5/internal/pgio" 11 ) 12 13 // CompositeIndexGetter is a type accessed by index that can be converted into a PostgreSQL composite. 14 type CompositeIndexGetter interface { 15 // IsNull returns true if the value is SQL NULL. 16 IsNull() bool 17 18 // Index returns the element at i. 19 Index(i int) any 20 } 21 22 // CompositeIndexScanner is a type accessed by index that can be scanned from a PostgreSQL composite. 23 type CompositeIndexScanner interface { 24 // ScanNull sets the value to SQL NULL. 25 ScanNull() error 26 27 // ScanIndex returns a value usable as a scan target for i. 28 ScanIndex(i int) any 29 } 30 31 type CompositeCodecField struct { 32 Name string 33 Type *Type 34 } 35 36 type CompositeCodec struct { 37 Fields []CompositeCodecField 38 } 39 40 func (c *CompositeCodec) FormatSupported(format int16) bool { 41 for _, f := range c.Fields { 42 if !f.Type.Codec.FormatSupported(format) { 43 return false 44 } 45 } 46 47 return true 48 } 49 50 func (c *CompositeCodec) PreferredFormat() int16 { 51 if c.FormatSupported(BinaryFormatCode) { 52 return BinaryFormatCode 53 } 54 return TextFormatCode 55 } 56 57 func (c *CompositeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { 58 if _, ok := value.(CompositeIndexGetter); !ok { 59 return nil 60 } 61 62 switch format { 63 case BinaryFormatCode: 64 return &encodePlanCompositeCodecCompositeIndexGetterToBinary{cc: c, m: m} 65 case TextFormatCode: 66 return &encodePlanCompositeCodecCompositeIndexGetterToText{cc: c, m: m} 67 } 68 69 return nil 70 } 71 72 type encodePlanCompositeCodecCompositeIndexGetterToBinary struct { 73 cc *CompositeCodec 74 m *Map 75 } 76 77 func (plan *encodePlanCompositeCodecCompositeIndexGetterToBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { 78 getter := 
value.(CompositeIndexGetter) 79 80 if getter.IsNull() { 81 return nil, nil 82 } 83 84 builder := NewCompositeBinaryBuilder(plan.m, buf) 85 for i, field := range plan.cc.Fields { 86 builder.AppendValue(field.Type.OID, getter.Index(i)) 87 } 88 89 return builder.Finish() 90 } 91 92 type encodePlanCompositeCodecCompositeIndexGetterToText struct { 93 cc *CompositeCodec 94 m *Map 95 } 96 97 func (plan *encodePlanCompositeCodecCompositeIndexGetterToText) Encode(value any, buf []byte) (newBuf []byte, err error) { 98 getter := value.(CompositeIndexGetter) 99 100 if getter.IsNull() { 101 return nil, nil 102 } 103 104 b := NewCompositeTextBuilder(plan.m, buf) 105 for i, field := range plan.cc.Fields { 106 b.AppendValue(field.Type.OID, getter.Index(i)) 107 } 108 109 return b.Finish() 110 } 111 112 func (c *CompositeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { 113 switch format { 114 case BinaryFormatCode: 115 switch target.(type) { 116 case CompositeIndexScanner: 117 return &scanPlanBinaryCompositeToCompositeIndexScanner{cc: c, m: m} 118 } 119 case TextFormatCode: 120 switch target.(type) { 121 case CompositeIndexScanner: 122 return &scanPlanTextCompositeToCompositeIndexScanner{cc: c, m: m} 123 } 124 } 125 126 return nil 127 } 128 129 type scanPlanBinaryCompositeToCompositeIndexScanner struct { 130 cc *CompositeCodec 131 m *Map 132 } 133 134 func (plan *scanPlanBinaryCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { 135 targetScanner := (target).(CompositeIndexScanner) 136 137 if src == nil { 138 return targetScanner.ScanNull() 139 } 140 141 scanner := NewCompositeBinaryScanner(plan.m, src) 142 for i, field := range plan.cc.Fields { 143 if scanner.Next() { 144 fieldTarget := targetScanner.ScanIndex(i) 145 if fieldTarget != nil { 146 fieldPlan := plan.m.PlanScan(field.Type.OID, BinaryFormatCode, fieldTarget) 147 if fieldPlan == nil { 148 return fmt.Errorf("unable to encode %v into OID %d in binary format", field, field.Type.OID) 
149 } 150 151 err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) 152 if err != nil { 153 return err 154 } 155 } 156 } else { 157 return errors.New("read past end of composite") 158 } 159 } 160 161 if err := scanner.Err(); err != nil { 162 return err 163 } 164 165 return nil 166 } 167 168 type scanPlanTextCompositeToCompositeIndexScanner struct { 169 cc *CompositeCodec 170 m *Map 171 } 172 173 func (plan *scanPlanTextCompositeToCompositeIndexScanner) Scan(src []byte, target any) error { 174 targetScanner := (target).(CompositeIndexScanner) 175 176 if src == nil { 177 return targetScanner.ScanNull() 178 } 179 180 scanner := NewCompositeTextScanner(plan.m, src) 181 for i, field := range plan.cc.Fields { 182 if scanner.Next() { 183 fieldTarget := targetScanner.ScanIndex(i) 184 if fieldTarget != nil { 185 fieldPlan := plan.m.PlanScan(field.Type.OID, TextFormatCode, fieldTarget) 186 if fieldPlan == nil { 187 return fmt.Errorf("unable to encode %v into OID %d in text format", field, field.Type.OID) 188 } 189 190 err := fieldPlan.Scan(scanner.Bytes(), fieldTarget) 191 if err != nil { 192 return err 193 } 194 } 195 } else { 196 return errors.New("read past end of composite") 197 } 198 } 199 200 if err := scanner.Err(); err != nil { 201 return err 202 } 203 204 return nil 205 } 206 207 func (c *CompositeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { 208 if src == nil { 209 return nil, nil 210 } 211 212 switch format { 213 case TextFormatCode: 214 return string(src), nil 215 case BinaryFormatCode: 216 buf := make([]byte, len(src)) 217 copy(buf, src) 218 return buf, nil 219 default: 220 return nil, fmt.Errorf("unknown format code %d", format) 221 } 222 } 223 224 func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { 225 if src == nil { 226 return nil, nil 227 } 228 229 switch format { 230 case TextFormatCode: 231 scanner := NewCompositeTextScanner(m, src) 232 values := 
make(map[string]any, len(c.Fields)) 233 for i := 0; scanner.Next() && i < len(c.Fields); i++ { 234 var v any 235 fieldPlan := m.PlanScan(c.Fields[i].Type.OID, TextFormatCode, &v) 236 if fieldPlan == nil { 237 return nil, fmt.Errorf("unable to scan OID %d in text format into %v", c.Fields[i].Type.OID, v) 238 } 239 240 err := fieldPlan.Scan(scanner.Bytes(), &v) 241 if err != nil { 242 return nil, err 243 } 244 245 values[c.Fields[i].Name] = v 246 } 247 248 if err := scanner.Err(); err != nil { 249 return nil, err 250 } 251 252 return values, nil 253 case BinaryFormatCode: 254 scanner := NewCompositeBinaryScanner(m, src) 255 values := make(map[string]any, len(c.Fields)) 256 for i := 0; scanner.Next() && i < len(c.Fields); i++ { 257 var v any 258 fieldPlan := m.PlanScan(scanner.OID(), BinaryFormatCode, &v) 259 if fieldPlan == nil { 260 return nil, fmt.Errorf("unable to scan OID %d in binary format into %v", scanner.OID(), v) 261 } 262 263 err := fieldPlan.Scan(scanner.Bytes(), &v) 264 if err != nil { 265 return nil, err 266 } 267 268 values[c.Fields[i].Name] = v 269 } 270 271 if err := scanner.Err(); err != nil { 272 return nil, err 273 } 274 275 return values, nil 276 default: 277 return nil, fmt.Errorf("unknown format code %d", format) 278 } 279 280 } 281 282 type CompositeBinaryScanner struct { 283 m *Map 284 rp int 285 src []byte 286 287 fieldCount int32 288 fieldBytes []byte 289 fieldOID uint32 290 err error 291 } 292 293 // NewCompositeBinaryScanner a scanner over a binary encoded composite balue. 294 func NewCompositeBinaryScanner(m *Map, src []byte) *CompositeBinaryScanner { 295 rp := 0 296 if len(src[rp:]) < 4 { 297 return &CompositeBinaryScanner{err: fmt.Errorf("Record incomplete %v", src)} 298 } 299 300 fieldCount := int32(binary.BigEndian.Uint32(src[rp:])) 301 rp += 4 302 303 return &CompositeBinaryScanner{ 304 m: m, 305 rp: rp, 306 src: src, 307 fieldCount: fieldCount, 308 } 309 } 310 311 // Next advances the scanner to the next field. 
// It returns false once every field has been consumed or when the input is malformed. When Next has returned
// false, call Err to tell the two cases apart.
//
// Each field is laid out as a big-endian uint32 OID, a big-endian int32 length, and then that many bytes. A negative
// length marks SQL NULL, and Bytes returns nil for such a field.
func (cfs *CompositeBinaryScanner) Next() bool {
	if cfs.err != nil || cfs.rp == len(cfs.src) {
		return false
	}

	rest := cfs.src[cfs.rp:]
	if len(rest) < 8 {
		cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
		return false
	}

	cfs.fieldOID = binary.BigEndian.Uint32(rest[0:4])
	size := int(int32(binary.BigEndian.Uint32(rest[4:8])))
	cfs.rp += 8

	if size < 0 {
		cfs.fieldBytes = nil
		return true
	}

	if len(cfs.src)-cfs.rp < size {
		cfs.err = fmt.Errorf("Record incomplete rp=%d src=%v", cfs.rp, cfs.src)
		return false
	}

	cfs.fieldBytes = cfs.src[cfs.rp : cfs.rp+size]
	cfs.rp += size
	return true
}

// FieldCount returns the field count read from the composite's header. Next does not consult it.
func (cfs *CompositeBinaryScanner) FieldCount() int {
	return int(cfs.fieldCount)
}

// Bytes returns the bytes of the field most recently read by Next. The result is nil for SQL NULL and otherwise aliases
// the scanner's source buffer.
func (cfs *CompositeBinaryScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// OID returns the OID of the field most recently read by Next.
func (cfs *CompositeBinaryScanner) OID() uint32 {
	return cfs.fieldOID
}

// Err returns the error, if any, that stopped the scanner.
func (cfs *CompositeBinaryScanner) Err() error {
	return cfs.err
}

// CompositeTextScanner iterates over the fields of a text encoded composite value such as `(1,"a b",)`.
type CompositeTextScanner struct {
	m   *Map
	rp  int    // read position in src
	src []byte // the whole encoded composite, including the surrounding parentheses

	fieldBytes []byte // current field's bytes; nil for SQL NULL
	err        error
}

// NewCompositeTextScanner returns a scanner over a text encoded composite value. src must be wrapped in parentheses;
// otherwise the returned scanner reports an error from Err.
func NewCompositeTextScanner(m *Map, src []byte) *CompositeTextScanner {
	if len(src) < 2 {
		return &CompositeTextScanner{err: fmt.Errorf("Record incomplete %v", src)}
	}

	if src[0] != '(' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must start with '('")}
	}

	if src[len(src)-1] != ')' {
		return &CompositeTextScanner{err: fmt.Errorf("composite text format must end with ')'")}
	}

	return &CompositeTextScanner{
		m:   m,
		rp:  1, // skip the opening '('
		src: src,
	}
}

// Next advances the scanner to the next field. It returns false after the last field is read or an error occurs. After
// Next returns false, the Err method can be called to check if any errors occurred.
//
// Fields are separated by ','. An empty field is SQL NULL. A field that starts with '"' is quoted: inside it, `""`
// is a literal '"' and '\' takes the following byte literally. Malformed input, such as an unterminated quoted field
// or a trailing '\', sets Err instead of reading past the end of the input.
func (cfs *CompositeTextScanner) Next() bool {
	if cfs.err != nil {
		return false
	}

	// >= rather than ==: after a closing quote rp skips one delimiter byte, so stay safe even if it overshoots.
	if cfs.rp >= len(cfs.src) {
		return false
	}

	switch cfs.src[cfs.rp] {
	case ',', ')': // null
		cfs.rp++
		cfs.fieldBytes = nil
		return true
	case '"': // quoted value
		cfs.rp++
		cfs.fieldBytes = make([]byte, 0, 16)
		for {
			if cfs.rp >= len(cfs.src) {
				cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
				return false
			}
			ch := cfs.src[cfs.rp]
			cfs.rp++

			switch ch {
			case '"':
				if cfs.rp < len(cfs.src) && cfs.src[cfs.rp] == '"' {
					// A doubled quote inside a quoted field is an escaped '"'.
					cfs.fieldBytes = append(cfs.fieldBytes, '"')
					cfs.rp++
					continue
				}
				// Closing quote: also skip the delimiter (',' or ')') that follows it.
				cfs.rp++
				return true
			case '\\':
				if cfs.rp >= len(cfs.src) {
					cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
					return false
				}
				cfs.fieldBytes = append(cfs.fieldBytes, cfs.src[cfs.rp])
				cfs.rp++
			default:
				cfs.fieldBytes = append(cfs.fieldBytes, ch)
			}
		}
	default: // unquoted value
		start := cfs.rp
		for cfs.rp < len(cfs.src) && cfs.src[cfs.rp] != ',' && cfs.src[cfs.rp] != ')' {
			cfs.rp++
		}
		if cfs.rp >= len(cfs.src) {
			// Unreachable for scanners built by NewCompositeTextScanner, which guarantees a trailing ')'.
			cfs.err = fmt.Errorf("Record incomplete %v", cfs.src)
			return false
		}
		cfs.fieldBytes = cfs.src[start:cfs.rp]
		cfs.rp++ // skip the delimiter
		return true
	}
}

// Bytes returns the bytes of the field most recently read by Next. It is nil when that field was SQL NULL. Unquoted
// fields alias the source buffer; quoted fields are unescaped into a fresh slice.
func (cfs *CompositeTextScanner) Bytes() []byte {
	return cfs.fieldBytes
}

// Err returns any error encountered by the scanner.
457 func (cfs *CompositeTextScanner) Err() error { 458 return cfs.err 459 } 460 461 type CompositeBinaryBuilder struct { 462 m *Map 463 buf []byte 464 startIdx int 465 fieldCount uint32 466 err error 467 } 468 469 func NewCompositeBinaryBuilder(m *Map, buf []byte) *CompositeBinaryBuilder { 470 startIdx := len(buf) 471 buf = append(buf, 0, 0, 0, 0) // allocate room for number of fields 472 return &CompositeBinaryBuilder{m: m, buf: buf, startIdx: startIdx} 473 } 474 475 func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { 476 if b.err != nil { 477 return 478 } 479 480 if field == nil { 481 b.buf = pgio.AppendUint32(b.buf, oid) 482 b.buf = pgio.AppendInt32(b.buf, -1) 483 b.fieldCount++ 484 return 485 } 486 487 plan := b.m.PlanEncode(oid, BinaryFormatCode, field) 488 if plan == nil { 489 b.err = fmt.Errorf("unable to encode %v into OID %d in binary format", field, oid) 490 return 491 } 492 493 b.buf = pgio.AppendUint32(b.buf, oid) 494 lengthPos := len(b.buf) 495 b.buf = pgio.AppendInt32(b.buf, -1) 496 fieldBuf, err := plan.Encode(field, b.buf) 497 if err != nil { 498 b.err = err 499 return 500 } 501 if fieldBuf != nil { 502 binary.BigEndian.PutUint32(fieldBuf[lengthPos:], uint32(len(fieldBuf)-len(b.buf))) 503 b.buf = fieldBuf 504 } 505 506 b.fieldCount++ 507 } 508 509 func (b *CompositeBinaryBuilder) Finish() ([]byte, error) { 510 if b.err != nil { 511 return nil, b.err 512 } 513 514 binary.BigEndian.PutUint32(b.buf[b.startIdx:], b.fieldCount) 515 return b.buf, nil 516 } 517 518 type CompositeTextBuilder struct { 519 m *Map 520 buf []byte 521 startIdx int 522 fieldCount uint32 523 err error 524 fieldBuf [32]byte 525 } 526 527 func NewCompositeTextBuilder(m *Map, buf []byte) *CompositeTextBuilder { 528 buf = append(buf, '(') // allocate room for number of fields 529 return &CompositeTextBuilder{m: m, buf: buf} 530 } 531 532 func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { 533 if b.err != nil { 534 return 535 } 536 537 if field == nil 
{ 538 b.buf = append(b.buf, ',') 539 return 540 } 541 542 plan := b.m.PlanEncode(oid, TextFormatCode, field) 543 if plan == nil { 544 b.err = fmt.Errorf("unable to encode %v into OID %d in text format", field, oid) 545 return 546 } 547 548 fieldBuf, err := plan.Encode(field, b.fieldBuf[0:0]) 549 if err != nil { 550 b.err = err 551 return 552 } 553 if fieldBuf != nil { 554 b.buf = append(b.buf, quoteCompositeFieldIfNeeded(string(fieldBuf))...) 555 } 556 557 b.buf = append(b.buf, ',') 558 } 559 560 func (b *CompositeTextBuilder) Finish() ([]byte, error) { 561 if b.err != nil { 562 return nil, b.err 563 } 564 565 b.buf[len(b.buf)-1] = ')' 566 return b.buf, nil 567 } 568 569 var quoteCompositeReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) 570 571 func quoteCompositeField(src string) string { 572 return `"` + quoteCompositeReplacer.Replace(src) + `"` 573 } 574 575 func quoteCompositeFieldIfNeeded(src string) string { 576 if src == "" || src[0] == ' ' || src[len(src)-1] == ' ' || strings.ContainsAny(src, `(),"\`) { 577 return quoteCompositeField(src) 578 } 579 return src 580 } 581 582 // CompositeFields represents the values of a composite value. It can be used as an encoding source or as a scan target. 583 // It cannot scan a NULL, but the composite fields can be NULL. 584 type CompositeFields []any 585 586 func (cf CompositeFields) SkipUnderlyingTypePlan() {} 587 588 func (cf CompositeFields) IsNull() bool { 589 return cf == nil 590 } 591 592 func (cf CompositeFields) Index(i int) any { 593 return cf[i] 594 } 595 596 func (cf CompositeFields) ScanNull() error { 597 return fmt.Errorf("cannot scan NULL into CompositeFields") 598 } 599 600 func (cf CompositeFields) ScanIndex(i int) any { 601 return cf[i] 602 }