github.com/jackc/pgx/v5@v5.5.5/pgtype/multirange.go (about) 1 package pgtype 2 3 import ( 4 "bytes" 5 "database/sql/driver" 6 "encoding/binary" 7 "fmt" 8 "reflect" 9 10 "github.com/jackc/pgx/v5/internal/pgio" 11 ) 12 13 // MultirangeGetter is a type that can be converted into a PostgreSQL multirange. 14 type MultirangeGetter interface { 15 // IsNull returns true if the value is SQL NULL. 16 IsNull() bool 17 18 // Len returns the number of elements in the multirange. 19 Len() int 20 21 // Index returns the element at i. 22 Index(i int) any 23 24 // IndexType returns a non-nil scan target of the type Index will return. This is used by MultirangeCodec.PlanEncode. 25 IndexType() any 26 } 27 28 // MultirangeSetter is a type can be set from a PostgreSQL multirange. 29 type MultirangeSetter interface { 30 // ScanNull sets the value to SQL NULL. 31 ScanNull() error 32 33 // SetLen prepares the value such that ScanIndex can be called for each element. This will remove any existing 34 // elements. 35 SetLen(n int) error 36 37 // ScanIndex returns a value usable as a scan target for i. SetLen must be called before ScanIndex. 38 ScanIndex(i int) any 39 40 // ScanIndexType returns a non-nil scan target of the type ScanIndex will return. This is used by 41 // MultirangeCodec.PlanScan. 42 ScanIndexType() any 43 } 44 45 // MultirangeCodec is a codec for any multirange type. 46 type MultirangeCodec struct { 47 ElementType *Type 48 } 49 50 func (c *MultirangeCodec) FormatSupported(format int16) bool { 51 return c.ElementType.Codec.FormatSupported(format) 52 } 53 54 func (c *MultirangeCodec) PreferredFormat() int16 { 55 return c.ElementType.Codec.PreferredFormat() 56 } 57 58 func (c *MultirangeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { 59 multirangeValuer, ok := value.(MultirangeGetter) 60 if !ok { 61 return nil 62 } 63 64 elementType := multirangeValuer.IndexType() 65 66 elementEncodePlan := m.PlanEncode(c.ElementType.OID, format, elementType) 67 if elementEncodePlan == nil { 68 return nil 69 } 70 71 switch format { 72 case BinaryFormatCode: 73 return &encodePlanMultirangeCodecBinary{ac: c, m: m, oid: oid} 74 case TextFormatCode: 75 return &encodePlanMultirangeCodecText{ac: c, m: m, oid: oid} 76 } 77 78 return nil 79 } 80 81 type encodePlanMultirangeCodecText struct { 82 ac *MultirangeCodec 83 m *Map 84 oid uint32 85 } 86 87 func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { 88 multirange := value.(MultirangeGetter) 89 90 if multirange.IsNull() { 91 return nil, nil 92 } 93 94 elementCount := multirange.Len() 95 96 buf = append(buf, '{') 97 98 var encodePlan EncodePlan 99 var lastElemType reflect.Type 100 inElemBuf := make([]byte, 0, 32) 101 for i := 0; i < elementCount; i++ { 102 if i > 0 { 103 buf = append(buf, ',') 104 } 105 106 elem := multirange.Index(i) 107 var elemBuf []byte 108 if elem != nil { 109 elemType := reflect.TypeOf(elem) 110 if lastElemType != elemType { 111 lastElemType = elemType 112 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, TextFormatCode, elem) 113 if encodePlan == nil { 114 return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) 115 } 116 } 117 elemBuf, err = encodePlan.Encode(elem, inElemBuf) 118 if err != nil { 119 return nil, err 120 } 121 } 122 123 if elemBuf == nil { 124 return nil, fmt.Errorf("multirange cannot contain NULL element") 125 } else { 126 buf = append(buf, elemBuf...) 127 } 128 } 129 130 buf = append(buf, '}') 131 132 return buf, nil 133 } 134 135 type encodePlanMultirangeCodecBinary struct { 136 ac *MultirangeCodec 137 m *Map 138 oid uint32 139 } 140 141 func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { 142 multirange := value.(MultirangeGetter) 143 144 if multirange.IsNull() { 145 return nil, nil 146 } 147 148 elementCount := multirange.Len() 149 150 buf = pgio.AppendInt32(buf, int32(elementCount)) 151 152 var encodePlan EncodePlan 153 var lastElemType reflect.Type 154 for i := 0; i < elementCount; i++ { 155 sp := len(buf) 156 buf = pgio.AppendInt32(buf, -1) 157 158 elem := multirange.Index(i) 159 var elemBuf []byte 160 if elem != nil { 161 elemType := reflect.TypeOf(elem) 162 if lastElemType != elemType { 163 lastElemType = elemType 164 encodePlan = p.m.PlanEncode(p.ac.ElementType.OID, BinaryFormatCode, elem) 165 if encodePlan == nil { 166 return nil, fmt.Errorf("unable to encode %v", multirange.Index(i)) 167 } 168 } 169 elemBuf, err = encodePlan.Encode(elem, buf) 170 if err != nil { 171 return nil, err 172 } 173 } 174 175 if elemBuf == nil { 176 return nil, fmt.Errorf("multirange cannot contain NULL element") 177 } else { 178 buf = elemBuf 179 pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) 180 } 181 } 182 183 return buf, nil 184 } 185 186 func (c *MultirangeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { 187 multirangeScanner, ok := target.(MultirangeSetter) 188 if !ok { 189 return nil 190 } 191 192 elementType := multirangeScanner.ScanIndexType() 193 194 elementScanPlan := m.PlanScan(c.ElementType.OID, format, elementType) 195 if _, ok := elementScanPlan.(*scanPlanFail); ok { 196 return nil 197 } 198 199 return &scanPlanMultirangeCodec{ 200 multirangeCodec: c, 201 m: m, 202 oid: oid, 203 formatCode: format, 204 } 205 } 206 207 func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { 208 rp := 0 209 210 elementCount := int(binary.BigEndian.Uint32(src[rp:])) 211 rp += 4 212 213 err := multirange.SetLen(elementCount) 214 if err != nil { 215 return err 216 } 217 218 if elementCount == 0 { 219 return nil 220 } 221 222 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) 223 if elementScanPlan == nil { 224 elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) 225 } 226 227 for i := 0; i < elementCount; i++ { 228 elem := multirange.ScanIndex(i) 229 elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) 230 rp += 4 231 var elemSrc []byte 232 if elemLen >= 0 { 233 elemSrc = src[rp : rp+elemLen] 234 rp += elemLen 235 } 236 err = elementScanPlan.Scan(elemSrc, elem) 237 if err != nil { 238 return fmt.Errorf("failed to scan multirange element %d: %w", i, err) 239 } 240 } 241 242 return nil 243 } 244 245 func (c *MultirangeCodec) decodeText(m *Map, multirangeOID uint32, src []byte, multirange MultirangeSetter) error { 246 elements, err := parseUntypedTextMultirange(src) 247 if err != nil { 248 return err 249 } 250 251 err = multirange.SetLen(len(elements)) 252 if err != nil { 253 return err 254 } 255 256 if len(elements) == 0 { 257 return nil 258 } 259 260 elementScanPlan := c.ElementType.Codec.PlanScan(m, c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) 261 if elementScanPlan == nil { 262 elementScanPlan = m.PlanScan(c.ElementType.OID, TextFormatCode, multirange.ScanIndex(0)) 263 } 264 265 for i, s := range elements { 266 elem := multirange.ScanIndex(i) 267 err = elementScanPlan.Scan([]byte(s), elem) 268 if err != nil { 269 return err 270 } 271 } 272 273 return nil 274 } 275 276 type scanPlanMultirangeCodec struct { 277 multirangeCodec *MultirangeCodec 278 m *Map 279 oid uint32 280 formatCode int16 281 elementScanPlan ScanPlan 282 } 283 284 func (spac *scanPlanMultirangeCodec) Scan(src []byte, dst any) error { 285 c := spac.multirangeCodec 286 m := spac.m 287 oid := spac.oid 288 formatCode := spac.formatCode 289 290 multirange := dst.(MultirangeSetter) 291 292 if src == nil { 293 return multirange.ScanNull() 294 } 295 296 switch formatCode { 297 case BinaryFormatCode: 298 return c.decodeBinary(m, oid, src, multirange) 299 case TextFormatCode: 300 return c.decodeText(m, oid, src, multirange) 301 default: 302 return fmt.Errorf("unknown format code %d", formatCode) 303 } 304 } 305 306 func (c *MultirangeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { 307 if src == nil { 308 return nil, nil 309 } 310 311 switch format { 312 case TextFormatCode: 313 return string(src), nil 314 case BinaryFormatCode: 315 buf := make([]byte, len(src)) 316 copy(buf, src) 317 return buf, nil 318 default: 319 return nil, fmt.Errorf("unknown format code %d", format) 320 } 321 } 322 323 func (c *MultirangeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { 324 if src == nil { 325 return nil, nil 326 } 327 328 var multirange Multirange[Range[any]] 329 err := m.PlanScan(oid, format, &multirange).Scan(src, &multirange) 330 return multirange, err 331 } 332 333 func parseUntypedTextMultirange(src []byte) ([]string, error) { 334 elements := make([]string, 0) 335 336 buf := bytes.NewBuffer(src) 337 338 skipWhitespace(buf) 339 340 r, _, err := buf.ReadRune() 341 if err != nil { 342 return nil, fmt.Errorf("invalid array: %w", err) 343 } 344 345 if r != '{' { 346 return nil, fmt.Errorf("invalid multirange, expected '{' got %v", r) 347 } 348 349 parseValueLoop: 350 for { 351 r, _, err = buf.ReadRune() 352 if err != nil { 353 return nil, fmt.Errorf("invalid multirange: %w", err) 354 } 355 356 switch r { 357 case ',': // skip range separator 358 case '}': 359 break parseValueLoop 360 default: 361 buf.UnreadRune() 362 value, err := parseRange(buf) 363 if err != nil { 364 return nil, fmt.Errorf("invalid multirange value: %w", err) 365 } 366 elements = append(elements, value) 367 } 368 } 369 370 skipWhitespace(buf) 371 372 if buf.Len() > 0 { 373 return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) 374 } 375 376 return elements, nil 377 378 } 379 380 func parseRange(buf *bytes.Buffer) (string, error) { 381 s := &bytes.Buffer{} 382 383 boundSepRead := false 384 for { 385 r, _, err := buf.ReadRune() 386 if err != nil { 387 return "", err 388 } 389 390 switch r { 391 case ',', '}': 392 if r == ',' && !boundSepRead { 393 boundSepRead = true 394 break 395 } 396 buf.UnreadRune() 397 return s.String(), nil 398 } 399 400 s.WriteRune(r) 401 } 402 } 403 404 // Multirange is a generic multirange type. 405 // 406 // T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to 407 // enforce the RangeScanner constraint. 408 type Multirange[T RangeValuer] []T 409 410 func (r Multirange[T]) IsNull() bool { 411 return r == nil 412 } 413 414 func (r Multirange[T]) Len() int { 415 return len(r) 416 } 417 418 func (r Multirange[T]) Index(i int) any { 419 return r[i] 420 } 421 422 func (r Multirange[T]) IndexType() any { 423 var zero T 424 return zero 425 } 426 427 func (r *Multirange[T]) ScanNull() error { 428 *r = nil 429 return nil 430 } 431 432 func (r *Multirange[T]) SetLen(n int) error { 433 *r = make([]T, n) 434 return nil 435 } 436 437 func (r Multirange[T]) ScanIndex(i int) any { 438 return &r[i] 439 } 440 441 func (r Multirange[T]) ScanIndexType() any { 442 return new(T) 443 }