github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/testing/pgtest/step.go (about) 1 package pgtest 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/jackc/pgmock" 8 "github.com/jackc/pgproto3/v2" 9 "github.com/jackc/pgtype" 10 ) 11 12 type extendedQueryStep struct { 13 sql string 14 params [][]byte 15 oids []uint32 16 reply pgmock.Script 17 } 18 19 func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error { 20 msg, err := getFrontendMessage(backend) 21 if err != nil { 22 return err 23 } 24 25 // Handle prepared statements, name can be dynamic: lrupsc_5_0 26 if m, ok := msg.(*pgproto3.Parse); ok { 27 want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs} 28 if !reflect.DeepEqual(m, want) { 29 return fmt.Errorf("msg => %#v, e.want => %#v", m, want) 30 } 31 // Anonymous ps falls through 32 if m.Name != "" { 33 script := pgmock.Script{Steps: []pgmock.Step{ 34 pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}), 35 pgmock.ExpectMessage(&pgproto3.Sync{}), 36 pgmock.SendMessage(&pgproto3.ParseComplete{}), 37 pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}), 38 // Postgres responds pgproto3.RowDescription but it's optional for pgx 39 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), 40 }} 41 if err := script.Run(backend); err != nil { 42 return err 43 } 44 } 45 // Expect bind command next 46 msg, err = backend.Receive() 47 if err != nil { 48 return err 49 } 50 } 51 52 if m, ok := msg.(*pgproto3.Bind); ok { 53 var codes []int16 54 for range e.oids { 55 codes = append(codes, pgtype.TextFormatCode) 56 } 57 want := &pgproto3.Bind{ 58 ParameterFormatCodes: codes, 59 Parameters: e.params, 60 ResultFormatCodes: []int16{}, 61 DestinationPortal: m.DestinationPortal, 62 PreparedStatement: m.PreparedStatement, 63 } 64 if !reflect.DeepEqual(m, want) { 65 return fmt.Errorf("msg => %#v, e.want => %#v", msg, want) 66 } 67 e.reply.Steps = append([]pgmock.Step{ 68 pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}), 69 pgmock.ExpectMessage(&pgproto3.Execute{}), 70 pgmock.SendMessage(&pgproto3.ParseComplete{}), 71 pgmock.SendMessage(&pgproto3.BindComplete{}), 72 }, e.reply.Steps...) 73 return e.reply.Run(backend) 74 } 75 76 // Handle simple query 77 want := &pgproto3.Query{String: e.sql} 78 if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) { 79 e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})) 80 return e.reply.Run(backend) 81 } 82 83 return fmt.Errorf("msg => %#v, e.want => %#v", msg, want) 84 } 85 86 // Expects a SQL query in any form: simple, prepared, or anonymous. 87 func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step { 88 return &extendedQueryStep{sql: sql, params: params, oids: oids} 89 } 90 91 type terminateStep struct{} 92 93 func (e *terminateStep) Step(backend *pgproto3.Backend) error { 94 msg, err := getFrontendMessage(backend) 95 if err != nil { 96 return err 97 } 98 99 // Handle simple query 100 if _, ok := msg.(*pgproto3.Terminate); ok { 101 return nil 102 } 103 104 return fmt.Errorf("msg => %#v, e.want => %#v", msg, &pgproto3.Terminate{}) 105 } 106 107 func ExpectTerminate() pgmock.Step { 108 return &terminateStep{} 109 } 110 111 func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) { 112 msg, err := backend.Receive() 113 if err != nil { 114 return nil, err 115 } 116 117 // Sync signals end of batch statements 118 if _, ok := msg.(*pgproto3.Sync); ok { 119 reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}) 120 if err := reply.Step(backend); err != nil { 121 return nil, err 122 } 123 msg, err = backend.Receive() 124 if err != nil { 125 return nil, err 126 } 127 } 128 129 return msg, nil 130 }