// Source: github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/types/array.go
//
// Copyright (c) 2011-2013, 'pq' Contributors Portions Copyright (C) 2011 Blake Mizerany. MIT license.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software
// is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
// HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
// SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 20 package types 21 22 import ( 23 "bytes" 24 "database/sql" 25 "database/sql/driver" 26 "encoding/hex" 27 "errors" 28 "fmt" 29 "math" 30 "reflect" 31 "strconv" 32 "strings" 33 "sync" 34 "time" 35 36 "github.com/ericlagergren/decimal" 37 "github.com/lib/pq/oid" 38 "github.com/volatiletech/randomize" 39 ) 40 41 type parameterStatus struct { 42 // server version in the same format as server_version_num, or 0 if 43 // unavailable 44 serverVersion int 45 46 // the current location based on the TimeZone value of the session, if 47 // available 48 currentLocation *time.Location 49 } 50 51 func errorf(s string, args ...interface{}) { 52 panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) 53 } 54 55 func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { 56 switch v := x.(type) { 57 case int64: 58 return strconv.AppendInt(nil, v, 10) 59 case float64: 60 return strconv.AppendFloat(nil, v, 'f', -1, 64) 61 case []byte: 62 if pgtypOid == oid.T_bytea { 63 return encodeBytea(parameterStatus.serverVersion, v) 64 } 65 66 return v 67 case string: 68 if pgtypOid == oid.T_bytea { 69 return encodeBytea(parameterStatus.serverVersion, []byte(v)) 70 } 71 72 return []byte(v) 73 case bool: 74 return strconv.AppendBool(nil, v) 75 case time.Time: 76 return formatTs(v) 77 78 default: 79 errorf("encode: unknown type for %T", v) 80 } 81 82 panic("not reached") 83 } 84 85 // Parse a bytea value received from the server. Both "hex" and the legacy 86 // "escape" format are supported. 
func parseBytea(s []byte) (result []byte, err error) {
	if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
		// bytea_output = hex: everything after the leading "\x" is hex digits.
		s = s[2:]
		result = make([]byte, hex.DecodedLen(len(s)))
		if _, err = hex.Decode(result, s); err != nil {
			return nil, err
		}
		return result, nil
	}

	// bytea_output = escape
	for len(s) > 0 {
		if s[0] != '\\' {
			// An unescaped run of raw bytes: copy everything up to the next
			// backslash in one go.
			i := bytes.IndexByte(s, '\\')
			if i == -1 {
				result = append(result, s...)
				break
			}
			result = append(result, s[:i]...)
			s = s[i:]
			continue
		}

		// "\\" encodes a single literal backslash.
		if len(s) >= 2 && s[1] == '\\' {
			result = append(result, '\\')
			s = s[2:]
			continue
		}

		// Otherwise the backslash must be followed by a three-digit octal number.
		if len(s) < 4 {
			// %q renders the bytes readably instead of as a list of integers.
			return nil, fmt.Errorf("invalid bytea sequence %q", s)
		}
		// bitSize 9 caps the accepted value at 255 (\377).
		r, perr := strconv.ParseInt(string(s[1:4]), 8, 9)
		if perr != nil {
			return nil, fmt.Errorf("could not parse bytea value: %s", perr.Error())
		}
		result = append(result, byte(r))
		s = s[4:]
	}

	return result, nil
}

// encodeBytea renders v as a bytea literal. When serverVersion (in
// server_version_num format) is at least 90000 the "hex" format is used;
// otherwise it falls back to the legacy "escape" format.
func encodeBytea(serverVersion int, v []byte) (result []byte) {
	if serverVersion >= 90000 {
		// Use the hex format if we know that the server supports it.
		result = make([]byte, 2+hex.EncodedLen(len(v)))
		result[0] = '\\'
		result[1] = 'x'
		hex.Encode(result[2:], v)
		return result
	}

	// ... or resort to "escape".
	for _, b := range v {
		switch {
		case b == '\\':
			result = append(result, '\\', '\\')
		case b < 0x20 || b > 0x7e:
			// A backslash plus three octal digits: the same bytes as
			// fmt.Sprintf("\\%03o", b), without a per-byte allocation.
			result = append(result, '\\', '0'+(b>>6), '0'+((b>>3)&7), '0'+(b&7))
		default:
			result = append(result, b)
		}
	}
	return result
}

// errInvalidTimestamp is recorded when an index into the timestamp string
// falls out of range.
var errInvalidTimestamp = errors.New("invalid timestamp")

// timestampParser remembers the first error hit while picking a timestamp
// apart. Once err is set its helper methods become no-ops, so a parse can run
// straight through and check the error once at the end.
type timestampParser struct {
	err error
}

// expect records an error unless str[pos] == char.
func (p *timestampParser) expect(str string, char byte, pos int) {
	if p.err != nil {
		return
	}
	if pos < 0 || pos >= len(str) {
		p.err = errInvalidTimestamp
		return
	}
	if c := str[pos]; c != char {
		// %q prints the bytes as quoted characters ('-') rather than numbers (45).
		p.err = fmt.Errorf("expected %q at position %v; got %q", char, pos, c)
	}
}

// mustAtoi parses str[begin:end] as a decimal integer, recording an error and
// returning 0 when the range is invalid or the text is not a number.
func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
	if p.err != nil {
		return 0
	}
	if begin < 0 || end < 0 || begin > end || end > len(str) {
		p.err = errInvalidTimestamp
		return 0
	}
	result, err := strconv.Atoi(str[begin:end])
	if err != nil {
		p.err = fmt.Errorf("expected number; got '%v'", str)
		return 0
	}
	return result
}

// The location cache caches the time zones typically used by the client.
type locationCache struct {
	cache map[int]*time.Location
	lock  sync.Mutex
}

// All connections share the same list of timezones. Benchmarking shows that
// about 5% speed could be gained by putting the cache in the connection and
// losing the mutex, at the cost of a small amount of memory and a somewhat
// significant increase in code complexity.
var globalLocationCache = newLocationCache()

func newLocationCache() *locationCache {
	return &locationCache{cache: make(map[int]*time.Location)}
}

// getLocation returns the cached unnamed fixed zone for offset (seconds east
// of UTC), creating and caching it if necessary.
func (c *locationCache) getLocation(offset int) *time.Location {
	c.lock.Lock()
	defer c.lock.Unlock()

	if loc, ok := c.cache[offset]; ok {
		return loc
	}
	loc := time.FixedZone("", offset)
	c.cache[offset] = loc
	return loc
}

var infinityTsEnabled = false
var infinityTsNegative time.Time
var infinityTsPositive time.Time

const (
	infinityTsEnabledAlready        = "pq: infinity timestamp enabled already"
	infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
)

// EnableInfinityTs controls the handling of Postgres' "-infinity" and
// "infinity" "timestamp"s.
//
// If EnableInfinityTs is not called, "-infinity" and "infinity" will return
// []byte("-infinity") and []byte("infinity") respectively, and potentially
// cause error "sql: Scan error on column index 0: unsupported driver -> Scan
// pair: []uint8 -> *time.Time", when scanning into a time.Time value.
//
// Once EnableInfinityTs has been called, all connections created using this
// driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
// "timestamp with time zone" and "date" types to the predefined minimum and
// maximum times, respectively. When encoding time.Time values, any time which
// equals or precedes the predefined minimum time will be encoded to
// "-infinity". Any values at or past the maximum time will similarly be
// encoded to "infinity".
//
// If EnableInfinityTs is called with negative >= positive, it will panic.
// Calling EnableInfinityTs after a connection has been established results in
// undefined behavior. If EnableInfinityTs is called more than once, it will
// panic.
func EnableInfinityTs(negative time.Time, positive time.Time) {
	if infinityTsEnabled {
		panic(infinityTsEnabledAlready)
	}
	if !negative.Before(positive) {
		panic(infinityTsNegativeMustBeSmaller)
	}
	infinityTsEnabled = true
	infinityTsNegative = negative
	infinityTsPositive = positive
}

// disableInfinityTs resets the flag set by EnableInfinityTs; testing might
// want to toggle it.
func disableInfinityTs() {
	infinityTsEnabled = false
}

// parseTs parses str, which must use the Postgres default DateStyle setting
// ("ISO, MDY"), the only one currently supported. "-infinity" and "infinity"
// map to the bounds set by EnableInfinityTs, or are returned verbatim as
// []byte when it has not been called. Malformed input causes a panic.
func parseTs(currentLocation *time.Location, str string) interface{} {
	switch str {
	case "-infinity":
		if infinityTsEnabled {
			return infinityTsNegative
		}
		return []byte(str)
	case "infinity":
		if infinityTsEnabled {
			return infinityTsPositive
		}
		return []byte(str)
	}
	t, err := ParseTimestamp(currentLocation, str)
	if err != nil {
		panic(err)
	}
	return t
}

// ParseTimestamp parses Postgres' text format. It returns a time.Time in
// currentLocation iff that time's offset agrees with the offset sent from the
// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
// fixed offset provided by the Postgres server. On error the zero time.Time
// is returned together with the error.
func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
	p := timestampParser{}

	// The " BC" era designator, when present, is always the final element.
	// Strip it first so a date-only value such as "0001-12-31 BC" is not
	// mistaken for a date followed by a time of day.
	body := str
	isBC := strings.HasSuffix(body, " BC")
	if isBC {
		body = body[:len(body)-len(" BC")]
	}

	monSep := strings.IndexRune(body, '-')
	// This is the Gregorian year, not the ISO year: in the Gregorian system
	// the year 1 BC is followed directly by AD 1.
	year := p.mustAtoi(body, 0, monSep)
	daySep := monSep + 3
	month := p.mustAtoi(body, monSep+1, daySep)
	p.expect(body, '-', daySep)
	timeSep := daySep + 3
	day := p.mustAtoi(body, daySep+1, timeSep)

	var hour, minute, second int
	if len(body) > monSep+len("01-01")+1 {
		p.expect(body, ' ', timeSep)
		minSep := timeSep + 3
		p.expect(body, ':', minSep)
		hour = p.mustAtoi(body, timeSep+1, minSep)
		secSep := minSep + 3
		p.expect(body, ':', secSep)
		minute = p.mustAtoi(body, minSep+1, secSep)
		secEnd := secSep + 3
		second = p.mustAtoi(body, secSep+1, secEnd)
	}

	// Two optional (but ordered) sections may follow: the fractional seconds
	// and the time zone offset.
	remainderIdx := monSep + len("01-01 00:00:00") + 1
	nanoSec := 0
	tzOff := 0

	if remainderIdx < len(body) && body[remainderIdx] == '.' {
		fracStart := remainderIdx + 1
		fracOff := strings.IndexAny(body[fracStart:], "-+ ")
		if fracOff < 0 {
			fracOff = len(body) - fracStart
		}
		fracSec := p.mustAtoi(body, fracStart, fracStart+fracOff)
		// Scale the fraction's digits up to nanoseconds.
		nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))

		remainderIdx += fracOff + 1
	}
	if tzStart := remainderIdx; tzStart < len(body) && (body[tzStart] == '-' || body[tzStart] == '+') {
		// The offset is [+-]HH, optionally followed by :MM and :SS (UTC is +00).
		tzSign := 1
		if body[tzStart] == '-' {
			tzSign = -1
		}
		tzHours := p.mustAtoi(body, tzStart+1, tzStart+3)
		remainderIdx += 3
		var tzMin, tzSec int
		if remainderIdx < len(body) && body[remainderIdx] == ':' {
			tzMin = p.mustAtoi(body, remainderIdx+1, remainderIdx+3)
			remainderIdx += 3
		}
		if remainderIdx < len(body) && body[remainderIdx] == ':' {
			tzSec = p.mustAtoi(body, remainderIdx+1, remainderIdx+3)
			remainderIdx += 3
		}
		tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
	}

	// Report the first parse error before complaining about trailing input,
	// and never hand back a partially built time.
	if p.err != nil {
		return time.Time{}, p.err
	}
	if remainderIdx < len(body) {
		return time.Time{}, fmt.Errorf("expected end of input, got %v", body[remainderIdx:])
	}

	isoYear := year
	if isBC {
		// Go numbers years proleptically with a year 0: 1 BC is year 0,
		// 2 BC is year -1, and so on.
		isoYear = 1 - year
	}

	t := time.Date(isoYear, time.Month(month), day,
		hour, minute, second, nanoSec,
		globalLocationCache.getLocation(tzOff))

	if currentLocation != nil {
		// Set the location of the returned Time based on the session's
		// TimeZone value, but only if the local time zone database agrees with
		// the remote database on the offset.
		lt := t.In(currentLocation)
		if _, newOff := lt.Zone(); newOff == tzOff {
			t = lt
		}
	}

	return t, nil
}

// formatTs formats t into a format postgres understands, mapping times at or
// beyond the EnableInfinityTs bounds to "-infinity"/"infinity" when enabled.
func formatTs(t time.Time) []byte {
	if infinityTsEnabled {
		// t <= -infinity : ! (t > -infinity)
		if !t.After(infinityTsNegative) {
			return []byte("-infinity")
		}
		// t >= infinity : ! (t < infinity)
		if !t.Before(infinityTsPositive) {
			return []byte("infinity")
		}
	}
	return FormatTimestamp(t)
}

// FormatTimestamp formats t into Postgres' text format for timestamps.
//
// Dates before 0001 A.D. are written with a " BC" suffix instead of the minus
// sign preferred by Go. Beware, ISO year "0000" is "1 BC", "-0001" is "2 BC"
// and so on. A UTC offset with a seconds component is written as +HH:MM:SS.
func FormatTimestamp(t time.Time) []byte {
	year := t.Year()
	bc := year <= 0
	if bc {
		// Flip the year's sign and add 1: 0 becomes 1, -10 becomes 11.
		year = 1 - year
	}

	b := make([]byte, 0, 48)
	// Match the zero-padded, minimum four-digit width of the "2006" layout.
	yearDigits := strconv.Itoa(year)
	for i := len(yearDigits); i < 4; i++ {
		b = append(b, '0')
	}
	b = append(b, yearDigits...)
	// Format the remaining fields straight from t. Shifting t into a positive
	// year first (e.g. with AddDate) would turn Feb 29 of a BC leap year into
	// Mar 1.
	b = t.AppendFormat(b, "-01-02 15:04:05.999999999Z07:00")

	_, offset := t.Zone()
	offset %= 60
	if offset != 0 {
		// The layout already printed the sign; append only the seconds.
		if offset < 0 {
			offset = -offset
		}

		b = append(b, ':')
		if offset < 10 {
			b = append(b, '0')
		}
		b = strconv.AppendInt(b, int64(offset), 10)
	}

	if bc {
		b = append(b, " BC"...)
	}
	return b
}

var (
	typeByteSlice    = reflect.TypeOf([]byte{})
	typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
	typeSQLScanner   = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
)

// Array returns the optimal driver.Valuer and sql.Scanner for an array or
// slice of any dimension.
453 // 454 // For example: 455 // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) 456 // 457 // var x []sql.NullInt64 458 // db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) 459 // 460 // Scanning multi-dimensional arrays is not supported. Arrays where the lower 461 // bound is not one (such as `[0:0]={1}') are not supported. 462 func Array(a interface{}) interface { 463 driver.Valuer 464 sql.Scanner 465 } { 466 switch a := a.(type) { 467 case []bool: 468 return (*BoolArray)(&a) 469 case []float64: 470 return (*Float64Array)(&a) 471 case []int64: 472 return (*Int64Array)(&a) 473 case []string: 474 return (*StringArray)(&a) 475 476 case *[]bool: 477 return (*BoolArray)(a) 478 case *[]float64: 479 return (*Float64Array)(a) 480 case *[]int64: 481 return (*Int64Array)(a) 482 case *[]string: 483 return (*StringArray)(a) 484 } 485 486 return GenericArray{a} 487 } 488 489 // ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner 490 // to override the array delimiter used by GenericArray. 491 type ArrayDelimiter interface { 492 // ArrayDelimiter returns the delimiter character(s) for this element's type. 493 ArrayDelimiter() string 494 } 495 496 // BoolArray represents a one-dimensional array of the PostgreSQL boolean type. 497 type BoolArray []bool 498 499 // Scan implements the sql.Scanner interface. 
500 func (a *BoolArray) Scan(src interface{}) error { 501 switch src := src.(type) { 502 case []byte: 503 return a.scanBytes(src) 504 case string: 505 return a.scanBytes([]byte(src)) 506 case nil: 507 *a = nil 508 return nil 509 } 510 511 return fmt.Errorf("boil: cannot convert %T to BoolArray", src) 512 } 513 514 func (a *BoolArray) scanBytes(src []byte) error { 515 elems, err := scanLinearArray(src, []byte{','}, "BoolArray") 516 if err != nil { 517 return err 518 } 519 if *a != nil && len(elems) == 0 { 520 *a = (*a)[:0] 521 } else { 522 b := make(BoolArray, len(elems)) 523 for i, v := range elems { 524 if len(v) < 1 { 525 return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) 526 } 527 switch v[:1][0] { 528 case 't', 'T': 529 b[i] = true 530 case 'f', 'F': 531 b[i] = false 532 default: 533 return fmt.Errorf("boil: could not parse boolean array index %d: invalid boolean %q", i, v) 534 } 535 } 536 *a = b 537 } 538 return nil 539 } 540 541 // Value implements the driver.Valuer interface. 542 func (a BoolArray) Value() (driver.Value, error) { 543 if a == nil { 544 return nil, nil 545 } 546 547 if n := len(a); n > 0 { 548 // There will be exactly two curly brackets, N bytes of values, 549 // and N-1 bytes of delimiters. 550 b := make([]byte, 1+2*n) 551 552 for i := 0; i < n; i++ { 553 b[2*i] = ',' 554 if a[i] { 555 b[1+2*i] = 't' 556 } else { 557 b[1+2*i] = 'f' 558 } 559 } 560 561 b[0] = '{' 562 b[2*n] = '}' 563 564 return string(b), nil 565 } 566 567 return "{}", nil 568 } 569 570 // Randomize for sqlboiler 571 func (a *BoolArray) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 572 *a = BoolArray{nextInt()%2 == 0, nextInt()%2 == 0, nextInt()%2 == 0} 573 } 574 575 // BytesArray represents a one-dimensional array of the PostgreSQL bytea type. 576 type BytesArray [][]byte 577 578 // Scan implements the sql.Scanner interface. 
579 func (a *BytesArray) Scan(src interface{}) error { 580 switch src := src.(type) { 581 case []byte: 582 return a.scanBytes(src) 583 case string: 584 return a.scanBytes([]byte(src)) 585 case nil: 586 *a = nil 587 return nil 588 } 589 590 return fmt.Errorf("boil: cannot convert %T to BytesArray", src) 591 } 592 593 func (a *BytesArray) scanBytes(src []byte) error { 594 elems, err := scanLinearArray(src, []byte{','}, "BytesArray") 595 if err != nil { 596 return err 597 } 598 if *a != nil && len(elems) == 0 { 599 *a = (*a)[:0] 600 } else { 601 b := make(BytesArray, len(elems)) 602 for i, v := range elems { 603 b[i], err = parseBytea(v) 604 if err != nil { 605 return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) 606 } 607 } 608 *a = b 609 } 610 return nil 611 } 612 613 // Value implements the driver.Valuer interface. It uses the "hex" format which 614 // is only supported on PostgreSQL 9.0 or newer. 615 func (a BytesArray) Value() (driver.Value, error) { 616 if a == nil { 617 return nil, nil 618 } 619 620 if n := len(a); n > 0 { 621 // There will be at least two curly brackets, 2*N bytes of quotes, 622 // 3*N bytes of hex formatting, and N-1 bytes of delimiters. 623 size := 1 + 6*n 624 for _, x := range a { 625 size += hex.EncodedLen(len(x)) 626 } 627 628 b := make([]byte, size) 629 630 for i, s := 0, b; i < n; i++ { 631 o := copy(s, `,"\\x`) 632 o += hex.Encode(s[o:], a[i]) 633 s[o] = '"' 634 s = s[o+1:] 635 } 636 637 b[0] = '{' 638 b[size-1] = '}' 639 640 return string(b), nil 641 } 642 643 return "{}", nil 644 } 645 646 // Randomize for sqlboiler 647 func (a *BytesArray) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 648 *a = BytesArray{randomize.ByteSlice(nextInt, 4), randomize.ByteSlice(nextInt, 4), randomize.ByteSlice(nextInt, 4)} 649 } 650 651 // Float64Array represents a one-dimensional array of the PostgreSQL double 652 // precision type. 
653 type Float64Array []float64 654 655 // Scan implements the sql.Scanner interface. 656 func (a *Float64Array) Scan(src interface{}) error { 657 switch src := src.(type) { 658 case []byte: 659 return a.scanBytes(src) 660 case string: 661 return a.scanBytes([]byte(src)) 662 case nil: 663 *a = nil 664 return nil 665 } 666 667 return fmt.Errorf("boil: cannot convert %T to Float64Array", src) 668 } 669 670 func (a *Float64Array) scanBytes(src []byte) error { 671 elems, err := scanLinearArray(src, []byte{','}, "Float64Array") 672 if err != nil { 673 return err 674 } 675 if *a != nil && len(elems) == 0 { 676 *a = (*a)[:0] 677 } else { 678 b := make(Float64Array, len(elems)) 679 for i, v := range elems { 680 if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { 681 return fmt.Errorf("boil: parsing array element index %d: %v", i, err) 682 } 683 } 684 *a = b 685 } 686 return nil 687 } 688 689 // Value implements the driver.Valuer interface. 690 func (a Float64Array) Value() (driver.Value, error) { 691 if a == nil { 692 return nil, nil 693 } 694 695 if n := len(a); n > 0 { 696 // There will be at least two curly brackets, N bytes of values, 697 // and N-1 bytes of delimiters. 698 b := make([]byte, 1, 1+2*n) 699 b[0] = '{' 700 701 b = strconv.AppendFloat(b, a[0], 'f', -1, 64) 702 for i := 1; i < n; i++ { 703 b = append(b, ',') 704 b = strconv.AppendFloat(b, a[i], 'f', -1, 64) 705 } 706 707 return string(append(b, '}')), nil 708 } 709 710 return "{}", nil 711 } 712 713 // Randomize for sqlboiler 714 func (a *Float64Array) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 715 *a = Float64Array{float64(nextInt()), float64(nextInt())} 716 } 717 718 // GenericArray implements the driver.Valuer and sql.Scanner interfaces for 719 // an array or slice of any dimension. 
720 type GenericArray struct{ A interface{} } 721 722 func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { 723 var assign func([]byte, reflect.Value) error 724 var del = "," 725 726 // TODO calculate the assign function for other types 727 // TODO repeat this section on the element type of arrays or slices (multidimensional) 728 { 729 if reflect.PtrTo(rt).Implements(typeSQLScanner) { 730 // dest is always addressable because it is an element of a slice. 731 assign = func(src []byte, dest reflect.Value) (err error) { 732 ss := dest.Addr().Interface().(sql.Scanner) 733 if src == nil { 734 err = ss.Scan(nil) 735 } else { 736 err = ss.Scan(src) 737 } 738 return 739 } 740 goto FoundType 741 } 742 743 assign = func([]byte, reflect.Value) error { 744 return fmt.Errorf("boil: scanning to %s is not implemented; only sql.Scanner", rt) 745 } 746 } 747 748 FoundType: 749 750 if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { 751 del = ad.ArrayDelimiter() 752 } 753 754 return rt, assign, del 755 } 756 757 // Scan implements the sql.Scanner interface. 
758 func (a GenericArray) Scan(src interface{}) error { 759 dpv := reflect.ValueOf(a.A) 760 switch { 761 case dpv.Kind() != reflect.Ptr: 762 return fmt.Errorf("boil: destination %T is not a pointer to array or slice", a.A) 763 case dpv.IsNil(): 764 return fmt.Errorf("boil: destination %T is nil", a.A) 765 } 766 767 dv := dpv.Elem() 768 switch dv.Kind() { 769 case reflect.Slice: 770 case reflect.Array: 771 default: 772 return fmt.Errorf("boil: destination %T is not a pointer to array or slice", a.A) 773 } 774 775 switch src := src.(type) { 776 case []byte: 777 return a.scanBytes(src, dv) 778 case string: 779 return a.scanBytes([]byte(src), dv) 780 case nil: 781 if dv.Kind() == reflect.Slice { 782 dv.Set(reflect.Zero(dv.Type())) 783 return nil 784 } 785 } 786 787 return fmt.Errorf("boil: cannot convert %T to %s", src, dv.Type()) 788 } 789 790 func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { 791 dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) 792 dims, elems, err := parseArray(src, []byte(del)) 793 if err != nil { 794 return err 795 } 796 797 // TODO allow multidimensional 798 799 if len(dims) > 1 { 800 return fmt.Errorf("boil: scanning from multidimensional ARRAY%s is not implemented", 801 strings.ReplaceAll(fmt.Sprint(dims), " ", "][")) 802 } 803 804 // Treat a zero-dimensional array like an array with a single dimension of zero. 
805 if len(dims) == 0 { 806 dims = append(dims, 0) 807 } 808 809 for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { 810 switch rt.Kind() { 811 case reflect.Slice: 812 case reflect.Array: 813 if rt.Len() != dims[i] { 814 return fmt.Errorf("boil: cannot convert ARRAY%s to %s", 815 strings.ReplaceAll(fmt.Sprint(dims), " ", "]["), dv.Type()) 816 } 817 default: 818 // TODO handle multidimensional 819 } 820 } 821 822 values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) 823 for i, e := range elems { 824 if err := assign(e, values.Index(i)); err != nil { 825 return fmt.Errorf("boil: parsing array element index %d: %v", i, err) 826 } 827 } 828 829 // TODO handle multidimensional 830 831 switch dv.Kind() { 832 case reflect.Slice: 833 dv.Set(values.Slice(0, dims[0])) 834 case reflect.Array: 835 for i := 0; i < dims[0]; i++ { 836 dv.Index(i).Set(values.Index(i)) 837 } 838 } 839 840 return nil 841 } 842 843 // Value implements the driver.Valuer interface. 844 func (a GenericArray) Value() (driver.Value, error) { 845 if a.A == nil { 846 return nil, nil 847 } 848 849 rv := reflect.ValueOf(a.A) 850 851 switch rv.Kind() { 852 case reflect.Slice: 853 if rv.IsNil() { 854 return nil, nil 855 } 856 case reflect.Array: 857 default: 858 return nil, fmt.Errorf("boil: Unable to convert %T to array", a.A) 859 } 860 861 if n := rv.Len(); n > 0 { 862 // There will be at least two curly brackets, N bytes of values, 863 // and N-1 bytes of delimiters. 864 b := make([]byte, 0, 1+2*n) 865 866 b, _, err := appendArray(b, rv, n) 867 return string(b), err 868 } 869 870 return "{}", nil 871 } 872 873 // Int64Array represents a one-dimensional array of the PostgreSQL integer types. 874 type Int64Array []int64 875 876 // Scan implements the sql.Scanner interface. 
877 func (a *Int64Array) Scan(src interface{}) error { 878 switch src := src.(type) { 879 case []byte: 880 return a.scanBytes(src) 881 case string: 882 return a.scanBytes([]byte(src)) 883 case nil: 884 *a = nil 885 return nil 886 } 887 888 return fmt.Errorf("boil: cannot convert %T to Int64Array", src) 889 } 890 891 func (a *Int64Array) scanBytes(src []byte) error { 892 elems, err := scanLinearArray(src, []byte{','}, "Int64Array") 893 if err != nil { 894 return err 895 } 896 if *a != nil && len(elems) == 0 { 897 *a = (*a)[:0] 898 } else { 899 b := make(Int64Array, len(elems)) 900 for i, v := range elems { 901 if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { 902 return fmt.Errorf("boil: parsing array element index %d: %v", i, err) 903 } 904 } 905 *a = b 906 } 907 return nil 908 } 909 910 // Value implements the driver.Valuer interface. 911 func (a Int64Array) Value() (driver.Value, error) { 912 if a == nil { 913 return nil, nil 914 } 915 916 if n := len(a); n > 0 { 917 // There will be at least two curly brackets, N bytes of values, 918 // and N-1 bytes of delimiters. 919 b := make([]byte, 1, 1+2*n) 920 b[0] = '{' 921 922 b = strconv.AppendInt(b, a[0], 10) 923 for i := 1; i < n; i++ { 924 b = append(b, ',') 925 b = strconv.AppendInt(b, a[i], 10) 926 } 927 928 return string(append(b, '}')), nil 929 } 930 931 return "{}", nil 932 } 933 934 // Randomize for sqlboiler 935 func (a *Int64Array) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 936 *a = Int64Array{int64(nextInt()), int64(nextInt())} 937 } 938 939 // StringArray represents a one-dimensional array of the PostgreSQL character types. 940 type StringArray []string 941 942 // Scan implements the sql.Scanner interface. 
943 func (a *StringArray) Scan(src interface{}) error { 944 switch src := src.(type) { 945 case []byte: 946 return a.scanBytes(src) 947 case string: 948 return a.scanBytes([]byte(src)) 949 case nil: 950 *a = nil 951 return nil 952 } 953 954 return fmt.Errorf("boil: cannot convert %T to StringArray", src) 955 } 956 957 func (a *StringArray) scanBytes(src []byte) error { 958 elems, err := scanLinearArray(src, []byte{','}, "StringArray") 959 if err != nil { 960 return err 961 } 962 if *a != nil && len(elems) == 0 { 963 *a = (*a)[:0] 964 } else { 965 b := make(StringArray, len(elems)) 966 for i, v := range elems { 967 if b[i] = string(v); v == nil { 968 return fmt.Errorf("boil: parsing array element index %d: cannot convert nil to string", i) 969 } 970 } 971 *a = b 972 } 973 return nil 974 } 975 976 // Value implements the driver.Valuer interface. 977 func (a StringArray) Value() (driver.Value, error) { 978 if a == nil { 979 return nil, nil 980 } 981 982 if n := len(a); n > 0 { 983 // There will be at least two curly brackets, 2*N bytes of quotes, 984 // and N-1 bytes of delimiters. 985 b := make([]byte, 1, 1+3*n) 986 b[0] = '{' 987 988 b = appendArrayQuotedBytes(b, []byte(a[0])) 989 for i := 1; i < n; i++ { 990 b = append(b, ',') 991 b = appendArrayQuotedBytes(b, []byte(a[i])) 992 } 993 994 return string(append(b, '}')), nil 995 } 996 997 return "{}", nil 998 } 999 1000 // Randomize for sqlboiler 1001 func (a *StringArray) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 1002 strs := make([]string, 2) 1003 fieldType = strings.TrimPrefix(fieldType, "ARRAY") 1004 1005 for i := range strs { 1006 val, ok := randomize.FormattedString(nextInt, fieldType) 1007 if ok { 1008 strs[i] = val 1009 continue 1010 } 1011 1012 strs[i] = randomize.Str(nextInt, 1) 1013 } 1014 1015 *a = strs 1016 } 1017 1018 // DecimalArray represents a one-dimensional array of the decimal type. 1019 type DecimalArray []Decimal 1020 1021 // Scan implements the sql.Scanner interface. 
1022 func (a *DecimalArray) Scan(src interface{}) error { 1023 switch src := src.(type) { 1024 case []byte: 1025 return a.scanBytes(src) 1026 case string: 1027 return a.scanBytes([]byte(src)) 1028 case nil: 1029 *a = nil 1030 return nil 1031 } 1032 1033 return fmt.Errorf("boil: cannot convert %T to DecimalArray", src) 1034 } 1035 1036 func (a *DecimalArray) scanBytes(src []byte) error { 1037 elems, err := scanLinearArray(src, []byte{','}, "DecimalArray") 1038 if err != nil { 1039 return err 1040 } 1041 if *a != nil && len(elems) == 0 { 1042 *a = (*a)[:0] 1043 } else { 1044 b := make(DecimalArray, len(elems)) 1045 for i, v := range elems { 1046 var success bool 1047 b[i].Big, success = new(decimal.Big).SetString(string(v)) 1048 if !success { 1049 return fmt.Errorf("boil: parsing decimal element index as decimal %d: %s", i, v) 1050 } 1051 } 1052 *a = b 1053 } 1054 return nil 1055 } 1056 1057 // Value implements the driver.Valuer interface. 1058 func (a DecimalArray) Value() (driver.Value, error) { 1059 if a == nil { 1060 return nil, nil 1061 } else if len(a) == 0 { 1062 return "{}", nil 1063 } 1064 1065 strs := make([]string, len(a)) 1066 for i, d := range a { 1067 strs[i] = d.String() 1068 } 1069 1070 return "{" + strings.Join(strs, ",") + "}", nil 1071 } 1072 1073 // Randomize for sqlboiler 1074 func (a *DecimalArray) Randomize(nextInt func() int64, fieldType string, shouldBeNull bool) { 1075 d1, d2 := NewDecimal(new(decimal.Big)), NewDecimal(new(decimal.Big)) 1076 d1.SetString(fmt.Sprintf("%d.%d", nextInt()%10, nextInt()%10)) 1077 d2.SetString(fmt.Sprintf("%d.%d", nextInt()%10, nextInt()%10)) 1078 *a = DecimalArray{d1, d2} 1079 } 1080 1081 // appendArray appends rv to the buffer, returning the extended buffer and 1082 // the delimiter used between elements. 1083 // 1084 // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. 
1085 func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { 1086 var del string 1087 var err error 1088 1089 b = append(b, '{') 1090 1091 if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { 1092 return b, del, err 1093 } 1094 1095 for i := 1; i < n; i++ { 1096 b = append(b, del...) 1097 if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { 1098 return b, del, err 1099 } 1100 } 1101 1102 return append(b, '}'), del, nil 1103 } 1104 1105 // appendArrayElement appends rv to the buffer, returning the extended buffer 1106 // and the delimiter to use before the next element. 1107 // 1108 // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted 1109 // using driver.DefaultParameterConverter and the resulting []byte or string 1110 // is double-quoted. 1111 // 1112 // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO 1113 func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { 1114 if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { 1115 if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { 1116 if n := rv.Len(); n > 0 { 1117 return appendArray(b, rv, n) 1118 } 1119 1120 return b, "", nil 1121 } 1122 } 1123 1124 var del = "," 1125 var err error 1126 var iv interface{} = rv.Interface() 1127 1128 if ad, ok := iv.(ArrayDelimiter); ok { 1129 del = ad.ArrayDelimiter() 1130 } 1131 1132 if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { 1133 return b, del, err 1134 } 1135 1136 switch v := iv.(type) { 1137 case nil: 1138 return append(b, "NULL"...), del, nil 1139 case []byte: 1140 return appendArrayQuotedBytes(b, v), del, nil 1141 case string: 1142 return appendArrayQuotedBytes(b, []byte(v)), del, nil 1143 } 1144 1145 b, err = appendValue(b, iv) 1146 return b, del, err 1147 } 1148 1149 func appendArrayQuotedBytes(b, v []byte) []byte { 1150 b = append(b, '"') 1151 for { 1152 i := bytes.IndexAny(v, `"\`) 1153 if i < 
0 { 1154 b = append(b, v...) 1155 break 1156 } 1157 if i > 0 { 1158 b = append(b, v[:i]...) 1159 } 1160 b = append(b, '\\', v[i]) 1161 v = v[i+1:] 1162 } 1163 return append(b, '"') 1164 } 1165 1166 func appendValue(b []byte, v driver.Value) ([]byte, error) { 1167 return append(b, encode(nil, v, 0)...), nil 1168 } 1169 1170 // parseArray extracts the dimensions and elements of an array represented in 1171 // text format. Only representations emitted by the backend are supported. 1172 // Notably, whitespace around brackets and delimiters is significant, and NULL 1173 // is case-sensitive. 1174 // 1175 // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO 1176 func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { 1177 var depth, i int 1178 1179 if len(src) < 1 || src[0] != '{' { 1180 return nil, nil, fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '{', 0) 1181 } 1182 1183 Open: 1184 for i < len(src) { 1185 switch src[i] { 1186 case '{': 1187 depth++ 1188 i++ 1189 case '}': 1190 elems = make([][]byte, 0) 1191 goto Close 1192 default: 1193 break Open 1194 } 1195 } 1196 dims = make([]int, i) 1197 1198 Element: 1199 for i < len(src) { 1200 switch src[i] { 1201 case '{': 1202 if depth == len(dims) { 1203 break Element 1204 } 1205 depth++ 1206 dims[depth-1] = 0 1207 i++ 1208 case '"': 1209 var elem = []byte{} 1210 var escape bool 1211 for i++; i < len(src); i++ { 1212 if escape { 1213 elem = append(elem, src[i]) 1214 escape = false 1215 } else { 1216 switch src[i] { 1217 default: 1218 elem = append(elem, src[i]) 1219 case '\\': 1220 escape = true 1221 case '"': 1222 elems = append(elems, elem) 1223 i++ 1224 break Element 1225 } 1226 } 1227 } 1228 default: 1229 for start := i; i < len(src); i++ { 1230 if bytes.HasPrefix(src[i:], del) || src[i] == '}' { 1231 elem := src[start:i] 1232 if len(elem) == 0 { 1233 return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) 1234 } 
1235 if bytes.Equal(elem, []byte("NULL")) { 1236 elem = nil 1237 } 1238 elems = append(elems, elem) 1239 break Element 1240 } 1241 } 1242 } 1243 } 1244 1245 for i < len(src) { 1246 if bytes.HasPrefix(src[i:], del) && depth > 0 { 1247 dims[depth-1]++ 1248 i += len(del) 1249 goto Element 1250 } else if src[i] == '}' && depth > 0 { 1251 dims[depth-1]++ 1252 depth-- 1253 i++ 1254 } else { 1255 return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) 1256 } 1257 } 1258 1259 Close: 1260 for i < len(src) { 1261 if src[i] == '}' && depth > 0 { 1262 depth-- 1263 i++ 1264 } else { 1265 return nil, nil, fmt.Errorf("boil: unable to parse array; unexpected %q at offset %d", src[i], i) 1266 } 1267 } 1268 if depth > 0 { 1269 err = fmt.Errorf("boil: unable to parse array; expected %q at offset %d", '}', i) 1270 } 1271 if err == nil { 1272 for _, d := range dims { 1273 if (len(elems) % d) != 0 { 1274 err = fmt.Errorf("boil: multidimensional arrays must have elements with matching dimensions") 1275 } 1276 } 1277 } 1278 return 1279 } 1280 1281 func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { 1282 dims, elems, err := parseArray(src, del) 1283 if err != nil { 1284 return nil, err 1285 } 1286 if len(dims) > 1 { 1287 return nil, fmt.Errorf("boil: cannot convert ARRAY%s to %s", strings.ReplaceAll(fmt.Sprint(dims), " ", "]["), typ) 1288 } 1289 return elems, err 1290 }