github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/xid8.go (about)

     1  package postgres
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"strconv"
     8  
     9  	"github.com/jackc/pgio"
    10  	"github.com/jackc/pgx/v5/pgtype"
    11  )
    12  
// Uint64Scanner is implemented by scan targets that receive a decoded value
// as an xid8. The scanner-based scan plans in this file call ScanUint64 with
// the zero xid8 (Valid == false) when the source bytes are nil.
//
// Adapted from https://github.com/jackc/pgx/blob/ca022267dbbfe7a8ba7070557352a5cd08f6cb37/pgtype/uint32.go
type Uint64Scanner interface {
	ScanUint64(v xid8) error
}
    17  
// Uint64Valuer is implemented by values that can present themselves as an
// xid8 for encoding. The binary valuer encode plan returns a nil buffer when
// the returned xid8 has Valid == false.
type Uint64Valuer interface {
	Uint64Value() (xid8, error)
}
    21  
    22  // xid8 is the core type that is used to XID.
    23  type xid8 struct {
    24  	Uint64 uint64
    25  	Valid  bool
    26  }
    27  
    28  func newXid8(u uint64) xid8 {
    29  	return xid8{
    30  		Uint64: u,
    31  		Valid:  true,
    32  	}
    33  }
    34  
    35  func (n *xid8) ScanUint64(v xid8) error {
    36  	*n = v
    37  	return nil
    38  }
    39  
    40  func (n xid8) Uint64Value() (xid8, error) {
    41  	return n, nil
    42  }
    43  
// Uint64Codec encodes and decodes unsigned 64-bit integers in the text and
// binary formats. It carries no state and is passed to codecScan as a
// pgtype.Codec.
type Uint64Codec struct{}
    45  
    46  func (Uint64Codec) FormatSupported(format int16) bool {
    47  	return format == pgtype.TextFormatCode || format == pgtype.BinaryFormatCode
    48  }
    49  
// PreferredFormat returns the format this codec prefers, which is always the
// binary format code.
func (Uint64Codec) PreferredFormat() int16 {
	return pgtype.BinaryFormatCode
}
    53  
    54  func (Uint64Codec) PlanEncode(_ *pgtype.Map, _ uint32, format int16, value any) pgtype.EncodePlan {
    55  	switch format {
    56  	case pgtype.BinaryFormatCode:
    57  		switch value.(type) {
    58  		case uint64:
    59  			return encodePlanUint64CodecBinaryUint64{}
    60  		case Uint64Valuer:
    61  			return encodePlanUint64CodecBinaryUint64Valuer{}
    62  		}
    63  	case pgtype.TextFormatCode:
    64  		switch value.(type) {
    65  		case uint64:
    66  			return encodePlanUint64CodecTextUint64{}
    67  		}
    68  	}
    69  
    70  	return nil
    71  }
    72  
    73  type encodePlanUint64CodecBinaryUint64 struct{}
    74  
    75  func (encodePlanUint64CodecBinaryUint64) Encode(value any, buf []byte) (newBuf []byte, err error) {
    76  	v := value.(uint64)
    77  	return pgio.AppendUint64(buf, v), nil
    78  }
    79  
    80  type encodePlanUint64CodecBinaryUint64Valuer struct{}
    81  
    82  func (encodePlanUint64CodecBinaryUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
    83  	v, err := value.(Uint64Valuer).Uint64Value()
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	if !v.Valid {
    89  		return nil, nil
    90  	}
    91  
    92  	return pgio.AppendUint64(buf, v.Uint64), nil
    93  }
    94  
    95  type encodePlanUint64CodecTextUint64 struct{}
    96  
    97  func (encodePlanUint64CodecTextUint64) Encode(value any, buf []byte) (newBuf []byte, err error) {
    98  	v := value.(uint64)
    99  	return append(buf, strconv.FormatUint(v, 10)...), nil
   100  }
   101  
   102  func (Uint64Codec) PlanScan(_ *pgtype.Map, _ uint32, format int16, target any) pgtype.ScanPlan {
   103  	switch format {
   104  	case pgtype.BinaryFormatCode:
   105  		switch target.(type) {
   106  		case *uint64:
   107  			return scanPlanBinaryUint64ToUint64{}
   108  		case Uint64Scanner:
   109  			return scanPlanBinaryUint64ToUint64Scanner{}
   110  		}
   111  	case pgtype.TextFormatCode:
   112  		switch target.(type) {
   113  		case *uint64:
   114  			return scanPlanTextAnyToUint64{}
   115  		case Uint64Scanner:
   116  			return scanPlanTextAnyToUint64Scanner{}
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  func (c Uint64Codec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
   124  	if src == nil {
   125  		return nil, nil
   126  	}
   127  
   128  	var n uint64
   129  	err := codecScan(c, m, oid, format, src, &n)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	return int64(n), nil
   134  }
   135  
   136  func (c Uint64Codec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
   137  	if src == nil {
   138  		return nil, nil
   139  	}
   140  
   141  	var n uint64
   142  	err := codecScan(c, m, oid, format, src, &n)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	return n, nil
   147  }
   148  
   149  type scanPlanBinaryUint64ToUint64 struct{}
   150  
   151  func (scanPlanBinaryUint64ToUint64) Scan(src []byte, dst any) error {
   152  	if src == nil {
   153  		return fmt.Errorf("cannot scan NULL into %T", dst)
   154  	}
   155  
   156  	if len(src) != 8 {
   157  		return fmt.Errorf("invalid length for uint64: %v", len(src))
   158  	}
   159  
   160  	p := (dst).(*uint64)
   161  	*p = binary.BigEndian.Uint64(src)
   162  
   163  	return nil
   164  }
   165  
   166  type scanPlanBinaryUint64ToUint64Scanner struct{}
   167  
   168  func (scanPlanBinaryUint64ToUint64Scanner) Scan(src []byte, dst any) error {
   169  	s, ok := (dst).(Uint64Scanner)
   170  	if !ok {
   171  		return pgtype.ErrScanTargetTypeChanged
   172  	}
   173  
   174  	if src == nil {
   175  		return s.ScanUint64(xid8{})
   176  	}
   177  
   178  	if len(src) != 8 {
   179  		return fmt.Errorf("invalid length for uint64: %v", len(src))
   180  	}
   181  
   182  	n := binary.BigEndian.Uint64(src)
   183  
   184  	return s.ScanUint64(xid8{Uint64: n, Valid: true})
   185  }
   186  
   187  type scanPlanTextAnyToUint64Scanner struct{}
   188  
   189  func (scanPlanTextAnyToUint64Scanner) Scan(src []byte, dst any) error {
   190  	s, ok := (dst).(Uint64Scanner)
   191  	if !ok {
   192  		return pgtype.ErrScanTargetTypeChanged
   193  	}
   194  
   195  	if src == nil {
   196  		return s.ScanUint64(xid8{})
   197  	}
   198  
   199  	n, err := strconv.ParseUint(string(src), 10, 64)
   200  	if err != nil {
   201  		return err
   202  	}
   203  
   204  	return s.ScanUint64(xid8{Uint64: n, Valid: true})
   205  }
   206  
   207  func codecScan(codec pgtype.Codec, m *pgtype.Map, oid uint32, format int16, src []byte, dst any) error {
   208  	scanPlan := codec.PlanScan(m, oid, format, dst)
   209  	if scanPlan == nil {
   210  		return fmt.Errorf("PlanScan did not find a plan")
   211  	}
   212  	return scanPlan.Scan(src, dst)
   213  }
   214  
   215  type scanPlanTextAnyToUint64 struct{}
   216  
   217  func (scanPlanTextAnyToUint64) Scan(src []byte, dst any) error {
   218  	if src == nil {
   219  		return fmt.Errorf("cannot scan NULL into %T", dst)
   220  	}
   221  
   222  	p, ok := (dst).(*uint64)
   223  	if !ok {
   224  		return pgtype.ErrScanTargetTypeChanged
   225  	}
   226  
   227  	n, err := strconv.ParseUint(string(src), 10, 32)
   228  	if err != nil {
   229  		return err
   230  	}
   231  
   232  	*p = n
   233  	return nil
   234  }