github.com/supabase/cli@v1.168.1/internal/testing/pgtest/step.go (about) 1 package pgtest 2 3 import ( 4 "reflect" 5 6 "github.com/go-errors/errors" 7 "github.com/jackc/pgmock" 8 "github.com/jackc/pgproto3/v2" 9 "github.com/jackc/pgtype" 10 ) 11 12 var ci = pgtype.NewConnInfo() 13 14 type extendedQueryStep struct { 15 sql string 16 params [][]byte 17 oids []uint32 18 reply pgmock.Script 19 } 20 21 func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error { 22 msg, err := getFrontendMessage(backend) 23 if err != nil { 24 return err 25 } 26 27 // Handle prepared statements, name can be dynamic: lrupsc_5_0 28 if m, ok := msg.(*pgproto3.Parse); ok { 29 want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs} 30 if !reflect.DeepEqual(m, want) { 31 return errors.Errorf("msg => %#v, e.want => %#v", m, want) 32 } 33 // Anonymous ps falls through 34 if m.Name != "" { 35 script := pgmock.Script{Steps: []pgmock.Step{ 36 pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}), 37 pgmock.ExpectMessage(&pgproto3.Sync{}), 38 pgmock.SendMessage(&pgproto3.ParseComplete{}), 39 pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}), 40 // Postgres responds pgproto3.RowDescription but it's optional for pgx 41 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), 42 }} 43 if err := script.Run(backend); err != nil { 44 return err 45 } 46 } 47 // Expect bind command next 48 msg, err = backend.Receive() 49 if err != nil { 50 return err 51 } 52 } 53 54 if m, ok := msg.(*pgproto3.Bind); ok { 55 var codes []int16 56 for _, oid := range e.oids { 57 codes = append(codes, ci.ParamFormatCodeForOID(oid)) 58 } 59 want := &pgproto3.Bind{ 60 ParameterFormatCodes: codes, 61 Parameters: e.params, 62 ResultFormatCodes: []int16{}, 63 DestinationPortal: m.DestinationPortal, 64 PreparedStatement: m.PreparedStatement, 65 } 66 if !reflect.DeepEqual(m, want) { 67 return errors.Errorf("msg => %#v, e.want => %#v", msg, want) 68 } 69 e.reply.Steps = append([]pgmock.Step{ 70 pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}), 71 pgmock.ExpectMessage(&pgproto3.Execute{}), 72 pgmock.SendMessage(&pgproto3.ParseComplete{}), 73 pgmock.SendMessage(&pgproto3.BindComplete{}), 74 }, e.reply.Steps...) 75 return e.reply.Run(backend) 76 } 77 78 // Handle simple query 79 want := &pgproto3.Query{String: e.sql} 80 if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) { 81 e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) 82 return e.reply.Run(backend) 83 } 84 85 return errors.Errorf("msg => %#v, e.want => %#v", msg, want) 86 } 87 88 // Expects a SQL query in any form: simple, prepared, or anonymous. 89 func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step { 90 return &extendedQueryStep{sql: sql, params: params, oids: oids} 91 } 92 93 type terminateStep struct{} 94 95 func (e *terminateStep) Step(backend *pgproto3.Backend) error { 96 msg, err := getFrontendMessage(backend) 97 if err != nil { 98 return err 99 } 100 101 // Handle simple query 102 if _, ok := msg.(*pgproto3.Terminate); ok { 103 return nil 104 } 105 106 return errors.Errorf("msg => %#v, e.want => %#v", msg, &pgproto3.Terminate{}) 107 } 108 109 func ExpectTerminate() pgmock.Step { 110 return &terminateStep{} 111 } 112 113 func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) { 114 msg, err := backend.Receive() 115 if err != nil { 116 return nil, err 117 } 118 119 // Sync signals end of batch statements 120 if _, ok := msg.(*pgproto3.Sync); ok { 121 reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}) 122 if err := reply.Step(backend); err != nil { 123 return nil, err 124 } 125 msg, err = backend.Receive() 126 if err != nil { 127 return nil, err 128 } 129 } 130 131 return msg, nil 132 }