github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/pgwire/encoding_test.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package pgwire
    12  
    13  import (
    14  	"bytes"
    15  	"context"
    16  	"encoding/binary"
    17  	"encoding/json"
    18  	"fmt"
    19  	"os"
    20  	"path/filepath"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/cockroachdb/apd"
    25  	"github.com/cockroachdb/cockroach/pkg/server/telemetry"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/parser"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    29  	"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
    30  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    31  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    32  	"github.com/cockroachdb/cockroach/pkg/util/metric"
    33  	"github.com/lib/pq/oid"
    34  )
    35  
    36  type encodingTest struct {
    37  	SQL          string
    38  	Datum        tree.Datum
    39  	Oid          oid.Oid
    40  	Text         string
    41  	TextAsBinary []byte
    42  	Binary       []byte
    43  }
    44  
    45  func readEncodingTests(t testing.TB) []*encodingTest {
    46  	var tests []*encodingTest
    47  	f, err := os.Open(filepath.Join("testdata", "encodings.json"))
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	if err := json.NewDecoder(f).Decode(&tests); err != nil {
    52  		t.Fatal(err)
    53  	}
    54  	f.Close()
    55  
    56  	ctx := context.Background()
    57  	sema := tree.MakeSemaContext()
    58  	evalCtx := tree.MakeTestingEvalContext(nil)
    59  
    60  	for _, tc := range tests {
    61  		// Convert the SQL expression to a Datum.
    62  		stmt, err := parser.ParseOne(fmt.Sprintf("SELECT %s", tc.SQL))
    63  		if err != nil {
    64  			t.Fatal(err)
    65  		}
    66  		selectStmt, ok := stmt.AST.(*tree.Select)
    67  		if !ok {
    68  			t.Fatal("not select")
    69  		}
    70  		selectClause, ok := selectStmt.Select.(*tree.SelectClause)
    71  		if !ok {
    72  			t.Fatal("not select clause")
    73  		}
    74  		if len(selectClause.Exprs) != 1 {
    75  			t.Fatal("expected 1 expr")
    76  		}
    77  		expr := selectClause.Exprs[0].Expr
    78  		te, err := expr.TypeCheck(ctx, &sema, types.Any)
    79  		if err != nil {
    80  			t.Fatal(err)
    81  		}
    82  		d, err := te.Eval(&evalCtx)
    83  		if err != nil {
    84  			t.Fatal(err)
    85  		}
    86  		tc.Datum = d
    87  	}
    88  
    89  	return tests
    90  }
    91  
    92  // TestEncodings uses testdata/encodings.json to test expected pgwire encodings
    93  // and ensure they are identical to what Postgres produces. Regenerate that
    94  // file by:
    95  //   Starting a postgres server on :5432 then running:
    96  //   cd pkg/cmd/generate-binary; go run main.go > ../../sql/pgwire/testdata/encodings.json
    97  func TestEncodings(t *testing.T) {
    98  	defer leaktest.AfterTest(t)()
    99  
   100  	tests := readEncodingTests(t)
   101  	buf := newWriteBuffer(metric.NewCounter(metric.Metadata{}))
   102  
   103  	verifyLen := func(t *testing.T) []byte {
   104  		t.Helper()
   105  		b := buf.wrapped.Bytes()
   106  		if len(b) < 4 {
   107  			t.Fatal("short buffer")
   108  		}
   109  		n := binary.BigEndian.Uint32(b)
   110  		// The first 4 bytes are the length prefix.
   111  		data := b[4:]
   112  		if len(data) != int(n) {
   113  			t.Logf("%v", b)
   114  			t.Errorf("expected %d bytes, got %d", n, len(data))
   115  		}
   116  		return data
   117  	}
   118  
   119  	var conv sessiondata.DataConversionConfig
   120  	ctx := context.Background()
   121  	evalCtx := tree.MakeTestingEvalContext(nil)
   122  	t.Run("encode", func(t *testing.T) {
   123  		t.Run(pgwirebase.FormatText.String(), func(t *testing.T) {
   124  			for _, tc := range tests {
   125  				d := tc.Datum
   126  
   127  				buf.reset()
   128  				buf.textFormatter.Buffer.Reset()
   129  				buf.writeTextDatum(ctx, d, conv)
   130  				if buf.err != nil {
   131  					t.Fatal(buf.err)
   132  				}
   133  				got := verifyLen(t)
   134  				if !bytes.Equal(got, tc.TextAsBinary) {
   135  					t.Errorf("unexpected text encoding:\n\t%q found,\n\t%q expected", got, tc.Text)
   136  				}
   137  			}
   138  		})
   139  		t.Run(pgwirebase.FormatBinary.String(), func(t *testing.T) {
   140  			for _, tc := range tests {
   141  				d := tc.Datum
   142  				buf.reset()
   143  				buf.writeBinaryDatum(ctx, d, time.UTC, tc.Oid)
   144  				if buf.err != nil {
   145  					t.Fatal(buf.err)
   146  				}
   147  				got := verifyLen(t)
   148  				if !bytes.Equal(got, tc.Binary) {
   149  					t.Errorf("unexpected binary encoding:\n\t%v found,\n\t%v expected", got, tc.Binary)
   150  				}
   151  			}
   152  		})
   153  	})
   154  	t.Run("decode", func(t *testing.T) {
   155  		for _, tc := range tests {
   156  			switch tc.Datum.(type) {
   157  			case *tree.DFloat:
   158  				// Skip floats because postgres rounds them different than Go.
   159  				continue
   160  			case *tree.DTuple:
   161  				// Unsupported.
   162  				continue
   163  			}
   164  			for code, value := range map[pgwirebase.FormatCode][]byte{
   165  				pgwirebase.FormatText:   tc.TextAsBinary,
   166  				pgwirebase.FormatBinary: tc.Binary,
   167  			} {
   168  				d, err := pgwirebase.DecodeOidDatum(nil, tc.Oid, code, value)
   169  				if err != nil {
   170  					t.Fatal(err)
   171  				}
   172  				// Text decoding returns a string for some kinds of arrays. If that's
   173  				// the case, manually do the conversion to array.
   174  				darr, isdarr := tc.Datum.(*tree.DArray)
   175  				if isdarr && d.ResolvedType().Family() == types.StringFamily {
   176  					d, err = tree.ParseDArrayFromString(&evalCtx, string(value), darr.ParamTyp)
   177  					if err != nil {
   178  						t.Fatal(err)
   179  					}
   180  				}
   181  				if d.Compare(&evalCtx, tc.Datum) != 0 {
   182  					t.Fatalf("%v != %v", d, tc.Datum)
   183  				}
   184  			}
   185  		}
   186  	})
   187  }
   188  
   189  // TestExoticNumericEncodings goes through specific, legal pgwire encodings
   190  // that Postgres itself would usually choose to not produce, which therefore
   191  // would not be covered by TestEncodings. Of course, being valid encodings
   192  // they'd still be accepted and correctly parsed by Postgres.
   193  func TestExoticNumericEncodings(t *testing.T) {
   194  	defer leaktest.AfterTest(t)()
   195  
   196  	testCases := []struct {
   197  		Value    *apd.Decimal
   198  		Encoding []byte
   199  	}{
   200  		{apd.New(0, 0), []byte{0, 0, 0, 0, 0, 0, 0, 0}},
   201  		{apd.New(0, 0), []byte{0, 1, 0, 0, 0, 0, 0, 0, 0, 0}},
   202  		{apd.New(10000, 0), []byte{0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0}},
   203  		{apd.New(10001, 0), []byte{0, 2, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1}},
   204  		{apd.New(1000000, 0), []byte{0, 2, 0, 1, 0, 0, 0, 0, 0, 100, 0, 0}},
   205  		{apd.New(1000001, 0), []byte{0, 2, 0, 1, 0, 0, 0, 0, 0, 100, 0, 1}},
   206  		{apd.New(100000000, 0), []byte{0, 1, 0, 2, 0, 0, 0, 0, 0, 1}},
   207  		{apd.New(100000000, 0), []byte{0, 2, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0}},
   208  		{apd.New(100000000, 0), []byte{0, 3, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0}},
   209  		{apd.New(100000001, 0), []byte{0, 3, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1}},
   210  		// Elixir/Postgrex combinations.
   211  		{apd.New(1234, 0), []byte{0, 2, 0, 0, 0, 0, 0, 0, 0x4, 0xd2, 0, 0}},
   212  		{apd.New(12340, -1), []byte{0, 2, 0, 0, 0, 0, 0, 1, 0x4, 0xd2, 0, 0}},
   213  		{apd.New(1234123400, -2), []byte{0, 3, 0, 1, 0, 0, 0, 2, 0x4, 0xd2, 0x4, 0xd2, 0, 0}},
   214  		{apd.New(12340000, 0), []byte{0, 3, 0, 1, 0, 0, 0, 0, 0x4, 0xd2, 0, 0, 0, 0}},
   215  		{apd.New(123400000, -1), []byte{0, 3, 0, 1, 0, 0, 0, 1, 0x4, 0xd2, 0, 0, 0, 0}},
   216  		{apd.New(12341234000000, -2), []byte{0, 4, 0, 2, 0, 0, 0, 2, 0x4, 0xd2, 0x4, 0xd2, 0, 0, 0, 0}},
   217  		// Postgrex inspired -- even more trailing zeroes!
   218  		{apd.New(0, 0), []byte{0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
   219  		{apd.New(1234123400, -2), []byte{0, 4, 0, 1, 0, 0, 0, 2, 0x4, 0xd2, 0x4, 0xd2, 0, 0, 0, 0}},
   220  	}
   221  
   222  	evalCtx := tree.MakeTestingEvalContext(nil)
   223  	for i, c := range testCases {
   224  		t.Run(fmt.Sprintf("%d_%s", i, c.Value), func(t *testing.T) {
   225  			d, err := pgwirebase.DecodeOidDatum(nil, oid.T_numeric, pgwirebase.FormatBinary, c.Encoding)
   226  			if err != nil {
   227  				t.Fatal(err)
   228  			}
   229  
   230  			expected := &tree.DDecimal{Decimal: *c.Value}
   231  			if d.Compare(&evalCtx, expected) != 0 {
   232  				t.Fatalf("%v != %v", d, expected)
   233  			}
   234  		})
   235  	}
   236  }
   237  
   238  func BenchmarkEncodings(b *testing.B) {
   239  	tests := readEncodingTests(b)
   240  	buf := newWriteBuffer(metric.NewCounter(metric.Metadata{}))
   241  	var conv sessiondata.DataConversionConfig
   242  	ctx := context.Background()
   243  
   244  	for _, tc := range tests {
   245  		b.Run(tc.SQL, func(b *testing.B) {
   246  			d := tc.Datum
   247  
   248  			b.Run("text", func(b *testing.B) {
   249  				for i := 0; i < b.N; i++ {
   250  					buf.reset()
   251  					buf.textFormatter.Buffer.Reset()
   252  					buf.writeTextDatum(ctx, d, conv)
   253  				}
   254  			})
   255  			b.Run("binary", func(b *testing.B) {
   256  				for i := 0; i < b.N; i++ {
   257  					buf.reset()
   258  					buf.writeBinaryDatum(ctx, d, time.UTC, tc.Oid)
   259  				}
   260  			})
   261  		})
   262  	}
   263  }
   264  
   265  func TestEncodingErrorCounts(t *testing.T) {
   266  	defer leaktest.AfterTest(t)()
   267  
   268  	buf := newWriteBuffer(metric.NewCounter(metric.Metadata{}))
   269  	d, _ := tree.ParseDDecimal("Inf")
   270  	buf.writeBinaryDatum(context.Background(), d, nil, d.ResolvedType().Oid())
   271  	if count := telemetry.GetRawFeatureCounts()["pgwire.#32489.binary_decimal_infinity"]; count != 1 {
   272  		t.Fatalf("expected 1 encoding error, got %d", count)
   273  	}
   274  }