github.com/supabase/cli@v1.168.1/internal/testing/pgtest/mock.go (about) 1 package pgtest 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "reflect" 8 "testing" 9 "time" 10 11 "github.com/jackc/pgmock" 12 "github.com/jackc/pgproto3/v2" 13 "github.com/jackc/pgtype" 14 "github.com/jackc/pgx/v4" 15 "google.golang.org/grpc/test/bufconn" 16 ) 17 18 type MockConn struct { 19 // Duplex server listener backed by in-memory buffer 20 server *bufconn.Listener 21 22 // Mock server requests and responses 23 script pgmock.Script 24 25 // Status parameters emitted by postgres on first connect 26 status map[string]string 27 28 // Channel for reporting all server error 29 errChan chan error 30 } 31 32 func (r *MockConn) getStartupMessage(config *pgx.ConnConfig) []pgmock.Step { 33 var steps []pgmock.Step 34 // Add auth message 35 params := map[string]string{"user": config.User} 36 if len(config.Database) > 0 { 37 params["database"] = config.Database 38 } 39 steps = append( 40 steps, 41 pgmock.ExpectMessage(&pgproto3.StartupMessage{ 42 ProtocolVersion: pgproto3.ProtocolVersionNumber, 43 Parameters: params, 44 }), 45 pgmock.SendMessage(&pgproto3.AuthenticationOk{}), 46 ) 47 // Add status message 48 r.status["session_authorization"] = config.User 49 for key, value := range r.status { 50 steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterStatus{Name: key, Value: value})) 51 } 52 // Add ready message 53 steps = append( 54 steps, 55 pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), 56 pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), 57 ) 58 return steps 59 } 60 61 // Configures pgx to use the mock dialer. 62 // 63 // The mock dialer provides a full duplex net.Conn backed by an in-memory buffer. 64 // It is implemented by grcp/test/bufconn package. 65 func (r *MockConn) Intercept(config *pgx.ConnConfig) { 66 // Override config for test 67 config.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { 68 return r.server.DialContext(ctx) 69 } 70 config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { 71 return []string{"127.0.0.1"}, nil 72 } 73 config.TLSConfig = nil 74 // Add startup message 75 r.script.Steps = append(r.getStartupMessage(config), r.script.Steps...) 76 } 77 78 // Adds a simple query or prepared statement to the mock connection. 79 func (r *MockConn) Query(sql string, args ...interface{}) *MockConn { 80 var oids []uint32 81 var params [][]byte 82 for _, v := range args { 83 if value, oid := r.encodeValueArg(v); oid > 0 { 84 params = append(params, value) 85 oids = append(oids, oid) 86 } 87 } 88 r.script.Steps = append(r.script.Steps, ExpectQuery(sql, params, oids)) 89 return r 90 } 91 92 func (r *MockConn) encodeValueArg(v interface{}) (value []byte, oid uint32) { 93 if v == nil { 94 return nil, pgtype.TextArrayOID 95 } 96 dt, ok := ci.DataTypeForValue(v) 97 if !ok { 98 r.errChan <- fmt.Errorf("no suitable type for arg: %v", v) 99 return nil, 0 100 } 101 if err := dt.Value.Set(v); err != nil { 102 r.errChan <- fmt.Errorf("failed to set value: %w", err) 103 return nil, 0 104 } 105 var err error 106 switch dt.OID { 107 case pgtype.TextOID: 108 value, err = (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{}) 109 default: 110 value, err = (dt.Value).(pgtype.BinaryEncoder).EncodeBinary(ci, []byte{}) 111 } 112 if err != nil { 113 r.errChan <- fmt.Errorf("failed to encode arg: %w", err) 114 return nil, 0 115 } 116 return value, dt.OID 117 } 118 119 func getDataTypeSize(v interface{}) int16 { 120 t := reflect.TypeOf(v) 121 k := t.Kind() 122 if k < reflect.Int || k > reflect.Complex128 { 123 return -1 124 } 125 return int16(t.Size()) 126 } 127 128 func (r *MockConn) lastQuery() *extendedQueryStep { 129 return r.script.Steps[len(r.script.Steps)-1].(*extendedQueryStep) 130 } 131 132 // Adds a server reply using binary or text protocol format. 133 // 134 // TODO: support prepared statements when using binary protocol 135 func (r *MockConn) Reply(tag string, rows ...interface{}) *MockConn { 136 q := r.lastQuery() 137 // Add field description 138 if len(rows) > 0 { 139 var desc pgproto3.RowDescription 140 if arr, ok := rows[0].([]interface{}); ok { 141 for i, v := range arr { 142 name := fmt.Sprintf("c_%02d", i) 143 if fd := toFieldDescription(v); fd != nil { 144 fd.Name = []byte(name) 145 desc.Fields = append(desc.Fields, *fd) 146 } else { 147 r.errChan <- fmt.Errorf("failed to describe field: %s", name) 148 } 149 } 150 } else if t := reflect.TypeOf(rows[0]); t.Kind() == reflect.Struct { 151 s := reflect.ValueOf(rows[0]) 152 for i := 0; i < s.NumField(); i++ { 153 name := t.Field(i).Name 154 v := s.Field(i).Interface() 155 if fd := toFieldDescription(v); fd != nil { 156 fd.Name = []byte(name) 157 desc.Fields = append(desc.Fields, *fd) 158 } else { 159 r.errChan <- fmt.Errorf("failed to describe field: %s", name) 160 } 161 } 162 } else { 163 r.errChan <- fmt.Errorf("reply type must be one of [array, struct], found: %s", t.Kind()) 164 } 165 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&desc)) 166 } else { 167 // No data is optional, but we add for completeness 168 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&pgproto3.NoData{})) 169 } 170 // Add row data 171 for _, data := range rows { 172 var dr pgproto3.DataRow 173 if arr, ok := data.([]interface{}); ok { 174 for _, v := range arr { 175 if value, oid := r.encodeValueArg(v); oid > 0 { 176 dr.Values = append(dr.Values, value) 177 } 178 } 179 } else if t := reflect.TypeOf(data); t.Kind() == reflect.Struct { 180 s := reflect.ValueOf(rows[0]) 181 for i := 0; i < s.NumField(); i++ { 182 v := s.Field(i).Interface() 183 if value, oid := r.encodeValueArg(v); oid > 0 { 184 dr.Values = append(dr.Values, value) 185 } 186 } 187 } else { 188 r.errChan <- fmt.Errorf("invalid reply value: %v", data) 189 } 190 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&dr)) 191 } 192 // Add completion message 193 var complete pgproto3.BackendMessage 194 if tag == "" { 195 complete = &pgproto3.EmptyQueryResponse{} 196 } else { 197 complete = &pgproto3.CommandComplete{CommandTag: []byte(tag)} 198 } 199 q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(complete)) 200 return r 201 } 202 203 func toFieldDescription(v interface{}) *pgproto3.FieldDescription { 204 if dt, ok := ci.DataTypeForValue(v); ok { 205 size := getDataTypeSize(v) 206 format := ci.ParamFormatCodeForOID(dt.OID) 207 return &pgproto3.FieldDescription{ 208 TableOID: 17131, 209 TableAttributeNumber: 1, 210 DataTypeOID: dt.OID, 211 DataTypeSize: size, 212 TypeModifier: -1, 213 Format: format, 214 } 215 } 216 return nil 217 } 218 219 // Simulates an error reply from the server. 220 // 221 // TODO: simulate a notice reply 222 func (r *MockConn) ReplyError(code, message string) *MockConn { 223 q := r.lastQuery() 224 q.reply.Steps = append( 225 q.reply.Steps, 226 pgmock.SendMessage(&pgproto3.ErrorResponse{ 227 Severity: "ERROR", 228 SeverityUnlocalized: "ERROR", 229 Code: code, 230 Message: message, 231 }), 232 ) 233 return r 234 } 235 236 func (r *MockConn) Close(t *testing.T) { 237 for err := range r.errChan { 238 t.Errorf("pgmock error: %v", err) 239 } 240 if err := r.server.Close(); err != nil { 241 t.Fatalf("failed to close: %v", err) 242 } 243 } 244 245 func NewWithStatus(status map[string]string) *MockConn { 246 const bufSize = 1024 * 1024 247 mock := MockConn{ 248 server: bufconn.Listen(bufSize), 249 status: status, 250 errChan: make(chan error, 10), 251 } 252 // Start server in background 253 const timeout = time.Millisecond * 450 254 go func() { 255 defer close(mock.errChan) 256 // Block until we've opened a TCP connection 257 conn, err := mock.server.Accept() 258 if err != nil { 259 mock.errChan <- err 260 return 261 } 262 defer conn.Close() 263 // Prevent server from hanging the test 264 err = conn.SetDeadline(time.Now().Add(timeout)) 265 if err != nil { 266 mock.errChan <- err 267 return 268 } 269 // Always expect clients to terminate the request 270 mock.script.Steps = append(mock.script.Steps, ExpectTerminate()) 271 err = mock.script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn)) 272 if err != nil { 273 mock.errChan <- err 274 return 275 } 276 }() 277 278 return &mock 279 } 280 281 func NewConn() *MockConn { 282 return NewWithStatus(map[string]string{ 283 "application_name": "", 284 "client_encoding": "UTF8", 285 "DateStyle": "ISO, MDY", 286 "default_transaction_read_only": "off", 287 "in_hot_standby": "off", 288 "integer_datetimes": "on", 289 "IntervalStyle": "postgres", 290 "is_superuser": "on", 291 "server_encoding": "UTF8", 292 "server_version": "14.3 (Debian 14.3-1.pgdg110+1)", 293 "standard_conforming_strings": "on", 294 "TimeZone": "UTC", 295 }) 296 }