github.com/jackc/pgx/v5@v5.5.5/pgproto3/backend_test.go (about) 1 package pgproto3_test 2 3 import ( 4 "io" 5 "testing" 6 7 "github.com/jackc/pgx/v5/internal/pgio" 8 "github.com/jackc/pgx/v5/pgproto3" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 ) 12 13 func TestBackendReceiveInterrupted(t *testing.T) { 14 t.Parallel() 15 16 server := &interruptReader{} 17 server.push([]byte{'Q', 0, 0, 0, 6}) 18 19 backend := pgproto3.NewBackend(server, nil) 20 21 msg, err := backend.Receive() 22 if err == nil { 23 t.Fatal("expected err") 24 } 25 if msg != nil { 26 t.Fatalf("did not expect msg, but %v", msg) 27 } 28 29 server.push([]byte{'I', 0}) 30 31 msg, err = backend.Receive() 32 if err != nil { 33 t.Fatal(err) 34 } 35 if msg, ok := msg.(*pgproto3.Query); !ok || msg.String != "I" { 36 t.Fatalf("unexpected msg: %v", msg) 37 } 38 } 39 40 func TestBackendReceiveUnexpectedEOF(t *testing.T) { 41 t.Parallel() 42 43 server := &interruptReader{} 44 server.push([]byte{'Q', 0, 0, 0, 6}) 45 46 backend := pgproto3.NewBackend(server, nil) 47 48 // Receive regular msg 49 msg, err := backend.Receive() 50 assert.Nil(t, msg) 51 assert.Equal(t, io.ErrUnexpectedEOF, err) 52 53 // Receive StartupMessage msg 54 dst := []byte{} 55 dst = pgio.AppendUint32(dst, 1000) // tell the backend we expect 1000 bytes to be read 56 dst = pgio.AppendUint32(dst, 1) // only send 1 byte 57 server.push(dst) 58 59 msg, err = backend.ReceiveStartupMessage() 60 assert.Nil(t, msg) 61 assert.Equal(t, io.ErrUnexpectedEOF, err) 62 } 63 64 func TestStartupMessage(t *testing.T) { 65 t.Parallel() 66 67 t.Run("valid StartupMessage", func(t *testing.T) { 68 want := &pgproto3.StartupMessage{ 69 ProtocolVersion: pgproto3.ProtocolVersionNumber, 70 Parameters: map[string]string{ 71 "username": "tester", 72 }, 73 } 74 dst, err := want.Encode([]byte{}) 75 require.NoError(t, err) 76 77 server := &interruptReader{} 78 server.push(dst) 79 80 backend := pgproto3.NewBackend(server, nil) 81 82 msg, err := backend.ReceiveStartupMessage() 83 require.NoError(t, err) 84 require.Equal(t, want, msg) 85 }) 86 87 t.Run("invalid packet length", func(t *testing.T) { 88 wantErr := "invalid length of startup packet" 89 tests := []struct { 90 name string 91 packetLen uint32 92 }{ 93 { 94 name: "large packet length", 95 // Since the StartupMessage contains the "Length of message contents 96 // in bytes, including self", the max startup packet length is actually 97 // 10000+4. Therefore, let's go past the limit with 10005 98 packetLen: 10005, 99 }, 100 { 101 name: "short packet length", 102 packetLen: 3, 103 }, 104 } 105 for _, tt := range tests { 106 t.Run(tt.name, func(t *testing.T) { 107 server := &interruptReader{} 108 dst := []byte{} 109 dst = pgio.AppendUint32(dst, tt.packetLen) 110 dst = pgio.AppendUint32(dst, pgproto3.ProtocolVersionNumber) 111 server.push(dst) 112 113 backend := pgproto3.NewBackend(server, nil) 114 115 msg, err := backend.ReceiveStartupMessage() 116 require.Error(t, err) 117 require.Nil(t, msg) 118 require.Contains(t, err.Error(), wantErr) 119 }) 120 } 121 }) 122 } 123 124 func TestBackendReceiveExceededMaxBodyLen(t *testing.T) { 125 t.Parallel() 126 127 server := &interruptReader{} 128 server.push([]byte{'Q', 0, 0, 10, 10}) 129 130 backend := pgproto3.NewBackend(server, nil) 131 132 // Set max body len to 5 133 backend.SetMaxBodyLen(5) 134 135 // Receive regular msg 136 msg, err := backend.Receive() 137 assert.Nil(t, msg) 138 var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr 139 assert.ErrorAs(t, err, &invalidBodyLenErr) 140 }