github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/testing/pgtest/mock.go (about) 1 package pgtest 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "reflect" 8 "testing" 9 "time" 10 11 "github.com/jackc/pgmock" 12 "github.com/jackc/pgproto3/v2" 13 "github.com/jackc/pgtype" 14 "github.com/jackc/pgx/v4" 15 "google.golang.org/grpc/test/bufconn" 16 ) 17 18 var ci = pgtype.NewConnInfo() 19 20 type MockConn struct { 21 // Duplex server listener backed by in-memory buffer 22 server *bufconn.Listener 23 24 // Mock server requests and responses 25 script pgmock.Script 26 27 // Status parameters emitted by postgres on first connect 28 status map[string]string 29 30 // Channel for reporting all server error 31 errChan chan error 32 } 33 34 func (r *MockConn) getStartupMessage(config *pgx.ConnConfig) []pgmock.Step { 35 var steps []pgmock.Step 36 // Add auth message 37 steps = append( 38 steps, 39 pgmock.ExpectMessage(&pgproto3.StartupMessage{ 40 ProtocolVersion: pgproto3.ProtocolVersionNumber, 41 Parameters: map[string]string{"database": config.Database, "user": config.User}, 42 }), 43 pgmock.SendMessage(&pgproto3.AuthenticationOk{}), 44 ) 45 // Add status message 46 r.status["session_authorization"] = config.User 47 for key, value := range r.status { 48 steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterStatus{Name: key, Value: value})) 49 } 50 // Add ready message 51 steps = append( 52 steps, 53 pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), 54 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), 55 ) 56 return steps 57 } 58 59 // Configures pgx to use the mock dialer. 60 // 61 // The mock dialer provides a full duplex net.Conn backed by an in-memory buffer. 62 // It is implemented by grcp/test/bufconn package. 63 func (r *MockConn) Intercept(config *pgx.ConnConfig) { 64 // Override config for test 65 config.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { 66 return r.server.DialContext(ctx) 67 } 68 config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { 69 return []string{"127.0.0.1"}, nil 70 } 71 config.TLSConfig = nil 72 // Add startup message 73 r.script.Steps = append(r.getStartupMessage(config), r.script.Steps...) 74 } 75 76 // Adds a simple query or prepared statement to the mock connection. 77 func (r *MockConn) Query(sql string, args ...interface{}) *MockConn { 78 var oids []uint32 79 var params [][]byte 80 for _, v := range args { 81 if dt, ok := ci.DataTypeForValue(v); ok { 82 if err := dt.Value.Set(v); err != nil { 83 continue 84 } 85 value, err := (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{}) 86 if err != nil { 87 continue 88 } 89 params = append(params, value) 90 oids = append(oids, dt.OID) 91 } 92 } 93 r.script.Steps = append(r.script.Steps, ExpectQuery(sql, params, oids)) 94 return r 95 } 96 97 func getDataTypeSize(v interface{}) int16 { 98 t := reflect.TypeOf(v) 99 k := t.Kind() 100 if k < reflect.Int || k > reflect.Complex128 { 101 return -1 102 } 103 return int16(t.Size()) 104 } 105 106 func (r *MockConn) lastQuery() *extendedQueryStep { 107 return r.script.Steps[len(r.script.Steps)-1].(*extendedQueryStep) 108 } 109 110 // Adds a server reply using text protocol format. 111 // 112 // TODO: support binary protocol 113 func (r *MockConn) Reply(tag string, rows ...[]interface{}) *MockConn { 114 q := r.lastQuery() 115 // Add field description 116 if len(rows) > 0 { 117 var desc pgproto3.RowDescription 118 for i, v := range rows[0] { 119 name := fmt.Sprintf("c_%02d", i) 120 if dt, ok := ci.DataTypeForValue(v); ok { 121 size := getDataTypeSize(v) 122 desc.Fields = append(desc.Fields, pgproto3.FieldDescription{ 123 Name: []byte(name), 124 TableOID: 17131, 125 TableAttributeNumber: 1, 126 DataTypeOID: dt.OID, 127 DataTypeSize: size, 128 TypeModifier: -1, 129 Format: pgtype.TextFormatCode, 130 }) 131 } 132 } 133 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&desc)) 134 } else { 135 // No data is optional, but we add for completeness 136 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&pgproto3.NoData{})) 137 } 138 // Add row data 139 for _, data := range rows { 140 var dr pgproto3.DataRow 141 for _, v := range data { 142 if dt, ok := ci.DataTypeForValue(v); ok { 143 if err := dt.Value.Set(v); err != nil { 144 continue 145 } 146 if value, err := (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{}); err == nil { 147 dr.Values = append(dr.Values, value) 148 } 149 } 150 } 151 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&dr)) 152 } 153 // Add completion message 154 var complete pgproto3.BackendMessage 155 if tag == "" { 156 complete = &pgproto3.EmptyQueryResponse{} 157 } else { 158 complete = &pgproto3.CommandComplete{CommandTag: []byte(tag)} 159 } 160 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(complete)) 161 return r 162 } 163 164 // Simulates an error reply from the server. 165 // 166 // TODO: simulate a notice reply 167 func (r *MockConn) ReplyError(code, message string) *MockConn { 168 q := r.lastQuery() 169 q.reply.Steps = append( 170 q.reply.Steps, 171 pgmock.SendMessage(&pgproto3.ErrorResponse{ 172 Severity: "ERROR", 173 SeverityUnlocalized: "ERROR", 174 Code: code, 175 Message: message, 176 }), 177 ) 178 return r 179 } 180 181 func (r *MockConn) Close(t *testing.T) { 182 if err := <-r.errChan; err != nil { 183 t.Fatalf("failed to close %v", err) 184 } 185 if err := r.server.Close(); err != nil { 186 t.Fatalf("failed to close %v", err) 187 } 188 } 189 190 func NewWithStatus(status map[string]string) *MockConn { 191 const bufSize = 1024 * 1024 192 mock := MockConn{ 193 server: bufconn.Listen(bufSize), 194 status: status, 195 errChan: make(chan error, 1), 196 } 197 // Start server in background 198 const timeout = time.Millisecond * 450 199 go func() { 200 defer close(mock.errChan) 201 // Block until we've opened a TCP connection 202 conn, err := mock.server.Accept() 203 if err != nil { 204 mock.errChan <- err 205 return 206 } 207 defer conn.Close() 208 // Prevent server from hanging the test 209 err = conn.SetDeadline(time.Now().Add(timeout)) 210 if err != nil { 211 mock.errChan <- err 212 return 213 } 214 // Always expect clients to terminate the request 215 mock.script.Steps = append(mock.script.Steps, ExpectTerminate()) 216 err = mock.script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) 217 if err != nil { 218 mock.errChan <- err 219 return 220 } 221 }() 222 223 return &mock 224 } 225 226 func NewConn() *MockConn { 227 return NewWithStatus(map[string]string{ 228 "application_name": "", 229 "client_encoding": "UTF8", 230 "DateStyle": "ISO, MDY", 231 "default_transaction_read_only": "off", 232 "in_hot_standby": "off", 233 "integer_datetimes": "on", 234 "IntervalStyle": "postgres", 235 "is_superuser": "on", 236 "server_encoding": "UTF8", 237 "server_version": "14.3 (Debian 14.3-1.pgdg110+1)", 238 "standard_conforming_strings": "on", 239 "TimeZone": "UTC", 240 }) 241 }