github.com/jackc/pgx/v5@v5.5.5/pgtype/numeric_test.go (about)

     1  package pgtype_test
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"math"
     7  	"math/big"
     8  	"math/rand"
     9  	"reflect"
    10  	"strconv"
    11  	"testing"
    12  
    13  	pgx "github.com/jackc/pgx/v5"
    14  	"github.com/jackc/pgx/v5/pgtype"
    15  	"github.com/jackc/pgx/v5/pgxtest"
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  )
    19  
    20  func mustParseBigInt(t *testing.T, src string) *big.Int {
    21  	i := &big.Int{}
    22  	if _, ok := i.SetString(src, 10); !ok {
    23  		t.Fatalf("could not parse big.Int: %s", src)
    24  	}
    25  	return i
    26  }
    27  
    28  func isExpectedEqNumeric(a any) func(any) bool {
    29  	return func(v any) bool {
    30  		aa := a.(pgtype.Numeric)
    31  		vv := v.(pgtype.Numeric)
    32  
    33  		if aa.Valid != vv.Valid {
    34  			return false
    35  		}
    36  
    37  		// If NULL doesn't matter what the rest of the values are.
    38  		if !aa.Valid {
    39  			return true
    40  		}
    41  
    42  		if !(aa.NaN == vv.NaN && aa.InfinityModifier == vv.InfinityModifier) {
    43  			return false
    44  		}
    45  
    46  		// If NaN or InfinityModifier are set then Int and Exp don't matter.
    47  		if aa.NaN || aa.InfinityModifier != pgtype.Finite {
    48  			return true
    49  		}
    50  
    51  		aaInt := (&big.Int{}).Set(aa.Int)
    52  		vvInt := (&big.Int{}).Set(vv.Int)
    53  
    54  		if aa.Exp < vv.Exp {
    55  			mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(vv.Exp-aa.Exp)), nil)
    56  			vvInt.Mul(vvInt, mul)
    57  		} else if aa.Exp > vv.Exp {
    58  			mul := (&big.Int{}).Exp(big.NewInt(10), big.NewInt(int64(aa.Exp-vv.Exp)), nil)
    59  			aaInt.Mul(aaInt, mul)
    60  		}
    61  
    62  		return aaInt.Cmp(vvInt) == 0
    63  	}
    64  }
    65  
    66  func mustParseNumeric(t *testing.T, src string) pgtype.Numeric {
    67  	var n pgtype.Numeric
    68  	plan := pgtype.NumericCodec{}.PlanScan(nil, pgtype.NumericOID, pgtype.TextFormatCode, &n)
    69  	require.NotNil(t, plan)
    70  	err := plan.Scan([]byte(src), &n)
    71  	require.NoError(t, err)
    72  	return n
    73  }
    74  
    75  func TestNumericCodec(t *testing.T) {
    76  	skipCockroachDB(t, "server formats numeric text format differently")
    77  
    78  	max := new(big.Int).Exp(big.NewInt(10), big.NewInt(147454), nil)
    79  	max.Add(max, big.NewInt(1))
    80  	longestNumeric := pgtype.Numeric{Int: max, Exp: -16383, Valid: true}
    81  
    82  	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{
    83  		{mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))},
    84  		{mustParseNumeric(t, "3.14159"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "3.14159"))},
    85  		{mustParseNumeric(t, "100010001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001"))},
    86  		{mustParseNumeric(t, "100010001.0001"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "100010001.0001"))},
    87  		{mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "4237234789234789289347892374324872138321894178943189043890124832108934.43219085471578891547854892438945012347981"))},
    88  		{mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.8925092023480223478923478978978937897879595901237890234789243679037419057877231734823098432903527585734549035904590854890345905434578345789347890402348952348905890489054234237489234987723894789234"))},
    89  		{mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0.000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000123"))},
    90  		{pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "243723409723490243842378942378901237502734019231380123"), Exp: 23790, Valid: true})},
    91  		{pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "2437"), Exp: 23790, Valid: true})},
    92  		{pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 80, Valid: true})},
    93  		{pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 81, Valid: true})},
    94  		{pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 82, Valid: true})},
    95  		{pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 83, Valid: true})},
    96  		{pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "43723409723490243842378942378901237502734019231380123"), Exp: 84, Valid: true})},
    97  		{pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "913423409823409243892349028349023482934092340892390101"), Exp: -14021, Valid: true})},
    98  		{pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -90, Valid: true})},
    99  		{pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -91, Valid: true})},
   100  		{pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -92, Valid: true})},
   101  		{pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{Int: mustParseBigInt(t, "13423409823409243892349028349023482934092340892390101"), Exp: -93, Valid: true})},
   102  		{pgtype.Numeric{NaN: true, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{NaN: true, Valid: true})},
   103  		{longestNumeric, new(pgtype.Numeric), isExpectedEqNumeric(longestNumeric)},
   104  		{mustParseNumeric(t, "1"), new(int64), isExpectedEq(int64(1))},
   105  		{math.NaN(), new(float64), func(a any) bool { return math.IsNaN(a.(float64)) }},
   106  		{float32(math.NaN()), new(float32), func(a any) bool { return math.IsNaN(float64(a.(float32))) }},
   107  		{int64(-1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))},
   108  		{int64(0), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))},
   109  		{int64(1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))},
   110  		{int64(math.MinInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64, 10)))},
   111  		{int64(math.MinInt64 + 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MinInt64+1, 10)))},
   112  		{int64(math.MaxInt64), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64, 10)))},
   113  		{int64(math.MaxInt64 - 1), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, strconv.FormatInt(math.MaxInt64-1, 10)))},
   114  		{uint64(100), new(uint64), isExpectedEq(uint64(100))},
   115  		{uint64(math.MaxUint64), new(uint64), isExpectedEq(uint64(math.MaxUint64))},
   116  		{uint(math.MaxUint), new(uint), isExpectedEq(uint(math.MaxUint))},
   117  		{uint(100), new(uint), isExpectedEq(uint(100))},
   118  		{"1.23", new(string), isExpectedEq("1.23")},
   119  		{pgtype.Numeric{}, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})},
   120  		{nil, new(pgtype.Numeric), isExpectedEq(pgtype.Numeric{})},
   121  		{mustParseNumeric(t, "1"), new(string), isExpectedEq("1")},
   122  		{pgtype.Numeric{NaN: true, Valid: true}, new(string), isExpectedEq("NaN")},
   123  	})
   124  
   125  	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "int8", []pgxtest.ValueRoundTripTest{
   126  		{mustParseNumeric(t, "-1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "-1"))},
   127  		{mustParseNumeric(t, "0"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "0"))},
   128  		{mustParseNumeric(t, "1"), new(pgtype.Numeric), isExpectedEqNumeric(mustParseNumeric(t, "1"))},
   129  	})
   130  }
   131  
   132  func TestNumericCodecInfinity(t *testing.T) {
   133  	skipCockroachDB(t, "server formats numeric text format differently")
   134  	skipPostgreSQLVersionLessThan(t, 14)
   135  
   136  	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", []pgxtest.ValueRoundTripTest{
   137  		{math.Inf(1), new(float64), isExpectedEq(math.Inf(1))},
   138  		{float32(math.Inf(1)), new(float32), isExpectedEq(float32(math.Inf(1)))},
   139  		{math.Inf(-1), new(float64), isExpectedEq(math.Inf(-1))},
   140  		{float32(math.Inf(-1)), new(float32), isExpectedEq(float32(math.Inf(-1)))},
   141  		{pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true})},
   142  		{pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(pgtype.Numeric), isExpectedEqNumeric(pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true})},
   143  		{pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, new(string), isExpectedEq("Infinity")},
   144  		{pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, new(string), isExpectedEq("-Infinity")},
   145  	})
   146  }
   147  
   148  func TestNumericFloat64Valuer(t *testing.T) {
   149  	for i, tt := range []struct {
   150  		n pgtype.Numeric
   151  		f pgtype.Float8
   152  	}{
   153  		{mustParseNumeric(t, "1"), pgtype.Float8{Float64: 1, Valid: true}},
   154  		{mustParseNumeric(t, "0.0000000000000000001"), pgtype.Float8{Float64: 0.0000000000000000001, Valid: true}},
   155  		{mustParseNumeric(t, "-99999999999"), pgtype.Float8{Float64: -99999999999, Valid: true}},
   156  		{pgtype.Numeric{InfinityModifier: pgtype.Infinity, Valid: true}, pgtype.Float8{Float64: math.Inf(1), Valid: true}},
   157  		{pgtype.Numeric{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, pgtype.Float8{Float64: math.Inf(-1), Valid: true}},
   158  		{pgtype.Numeric{Valid: true}, pgtype.Float8{Valid: true}},
   159  		{pgtype.Numeric{}, pgtype.Float8{}},
   160  	} {
   161  		f, err := tt.n.Float64Value()
   162  		assert.NoErrorf(t, err, "%d", i)
   163  		assert.Equalf(t, tt.f, f, "%d", i)
   164  	}
   165  
   166  	f, err := pgtype.Numeric{NaN: true, Valid: true}.Float64Value()
   167  	assert.NoError(t, err)
   168  	assert.True(t, math.IsNaN(f.Float64))
   169  	assert.True(t, f.Valid)
   170  }
   171  
   172  func TestNumericCodecFuzz(t *testing.T) {
   173  	skipCockroachDB(t, "server formats numeric text format differently")
   174  
   175  	r := rand.New(rand.NewSource(0))
   176  	max := &big.Int{}
   177  	max.SetString("9999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", 10)
   178  
   179  	tests := make([]pgxtest.ValueRoundTripTest, 0, 2000)
   180  	for i := 0; i < 10; i++ {
   181  		for j := -50; j < 50; j++ {
   182  			num := (&big.Int{}).Rand(r, max)
   183  
   184  			n := pgtype.Numeric{Int: num, Exp: int32(j), Valid: true}
   185  			tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)})
   186  
   187  			negNum := &big.Int{}
   188  			negNum.Neg(num)
   189  			n = pgtype.Numeric{Int: negNum, Exp: int32(j), Valid: true}
   190  			tests = append(tests, pgxtest.ValueRoundTripTest{n, new(pgtype.Numeric), isExpectedEqNumeric(n)})
   191  		}
   192  	}
   193  
   194  	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "numeric", tests)
   195  }
   196  
   197  func TestNumericMarshalJSON(t *testing.T) {
   198  	skipCockroachDB(t, "server formats numeric text format differently")
   199  
   200  	defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
   201  
   202  		for i, tt := range []struct {
   203  			decString string
   204  		}{
   205  			{"NaN"},
   206  			{"0"},
   207  			{"1"},
   208  			{"-1"},
   209  			{"1000000000000000000"},
   210  			{"1234.56789"},
   211  			{"1.56789"},
   212  			{"0.00000000000056789"},
   213  			{"0.00123000"},
   214  			{"123e-3"},
   215  			{"243723409723490243842378942378901237502734019231380123e23790"},
   216  			{"3409823409243892349028349023482934092340892390101e-14021"},
   217  			{"-1.1"},
   218  			{"-1.0231"},
   219  			{"-10.0231"},
   220  			{"-0.1"},   // failed with "invalid character '.' in numeric literal"
   221  			{"-0.01"},  // failed with "invalid character '-' after decimal point in numeric literal"
   222  			{"-0.001"}, // failed with "invalid character '-' after top-level value"
   223  		} {
   224  			var num pgtype.Numeric
   225  			var pgJSON string
   226  			err := conn.QueryRow(ctx, `select $1::numeric, to_json($1::numeric)`, tt.decString).Scan(&num, &pgJSON)
   227  			require.NoErrorf(t, err, "%d", i)
   228  
   229  			goJSON, err := json.Marshal(num)
   230  			require.NoErrorf(t, err, "%d", i)
   231  
   232  			require.Equal(t, pgJSON, string(goJSON))
   233  		}
   234  	})
   235  }
   236  
   237  func TestNumericUnmarshalJSON(t *testing.T) {
   238  	tests := []struct {
   239  		name    string
   240  		want    *pgtype.Numeric
   241  		src     []byte
   242  		wantErr bool
   243  	}{
   244  		{
   245  			name:    "null",
   246  			want:    &pgtype.Numeric{},
   247  			src:     []byte(`null`),
   248  			wantErr: false,
   249  		},
   250  		{
   251  			name:    "NaN",
   252  			want:    &pgtype.Numeric{Valid: true, NaN: true},
   253  			src:     []byte(`"NaN"`),
   254  			wantErr: false,
   255  		},
   256  		{
   257  			name:    "0",
   258  			want:    &pgtype.Numeric{Valid: true, Int: big.NewInt(0)},
   259  			src:     []byte("0"),
   260  			wantErr: false,
   261  		},
   262  		{
   263  			name:    "1",
   264  			want:    &pgtype.Numeric{Valid: true, Int: big.NewInt(1)},
   265  			src:     []byte("1"),
   266  			wantErr: false,
   267  		},
   268  		{
   269  			name:    "-1",
   270  			want:    &pgtype.Numeric{Valid: true, Int: big.NewInt(-1)},
   271  			src:     []byte("-1"),
   272  			wantErr: false,
   273  		},
   274  		{
   275  			name:    "bigInt",
   276  			want:    &pgtype.Numeric{Valid: true, Int: big.NewInt(1), Exp: 30},
   277  			src:     []byte("1000000000000000000000000000000"),
   278  			wantErr: false,
   279  		},
   280  		{
   281  			name:    "float: 1234.56789",
   282  			want:    &pgtype.Numeric{Valid: true, Int: big.NewInt(123456789), Exp: -5},
   283  			src:     []byte("1234.56789"),
   284  			wantErr: false,
   285  		},
   286  		{
   287  			name:    "invalid value",
   288  			want:    &pgtype.Numeric{},
   289  			src:     []byte("0xffff"),
   290  			wantErr: true,
   291  		},
   292  	}
   293  	for _, tt := range tests {
   294  		t.Run(tt.name, func(t *testing.T) {
   295  			got := &pgtype.Numeric{}
   296  			if err := got.UnmarshalJSON(tt.src); (err != nil) != tt.wantErr {
   297  				t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
   298  			}
   299  			if !reflect.DeepEqual(got, tt.want) {
   300  				t.Errorf("UnmarshalJSON() got = %v, want %v", got, tt.want)
   301  			}
   302  		})
   303  	}
   304  }