github.com/supabase/cli@v1.168.1/internal/testing/pgtest/mock.go (about)

     1  package pgtest
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/jackc/pgmock"
    12  	"github.com/jackc/pgproto3/v2"
    13  	"github.com/jackc/pgtype"
    14  	"github.com/jackc/pgx/v4"
    15  	"google.golang.org/grpc/test/bufconn"
    16  )
    17  
    18  type MockConn struct {
    19  	// Duplex server listener backed by in-memory buffer
    20  	server *bufconn.Listener
    21  
    22  	// Mock server requests and responses
    23  	script pgmock.Script
    24  
    25  	// Status parameters emitted by postgres on first connect
    26  	status map[string]string
    27  
    28  	// Channel for reporting all server error
    29  	errChan chan error
    30  }
    31  
    32  func (r *MockConn) getStartupMessage(config *pgx.ConnConfig) []pgmock.Step {
    33  	var steps []pgmock.Step
    34  	// Add auth message
    35  	params := map[string]string{"user": config.User}
    36  	if len(config.Database) > 0 {
    37  		params["database"] = config.Database
    38  	}
    39  	steps = append(
    40  		steps,
    41  		pgmock.ExpectMessage(&pgproto3.StartupMessage{
    42  			ProtocolVersion: pgproto3.ProtocolVersionNumber,
    43  			Parameters:      params,
    44  		}),
    45  		pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
    46  	)
    47  	// Add status message
    48  	r.status["session_authorization"] = config.User
    49  	for key, value := range r.status {
    50  		steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterStatus{Name: key, Value: value}))
    51  	}
    52  	// Add ready message
    53  	steps = append(
    54  		steps,
    55  		pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
    56  		pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
    57  	)
    58  	return steps
    59  }
    60  
    61  // Configures pgx to use the mock dialer.
    62  //
    63  // The mock dialer provides a full duplex net.Conn backed by an in-memory buffer.
    64  // It is implemented by grcp/test/bufconn package.
    65  func (r *MockConn) Intercept(config *pgx.ConnConfig) {
    66  	// Override config for test
    67  	config.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
    68  		return r.server.DialContext(ctx)
    69  	}
    70  	config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
    71  		return []string{"127.0.0.1"}, nil
    72  	}
    73  	config.TLSConfig = nil
    74  	// Add startup message
    75  	r.script.Steps = append(r.getStartupMessage(config), r.script.Steps...)
    76  }
    77  
    78  // Adds a simple query or prepared statement to the mock connection.
    79  func (r *MockConn) Query(sql string, args ...interface{}) *MockConn {
    80  	var oids []uint32
    81  	var params [][]byte
    82  	for _, v := range args {
    83  		if value, oid := r.encodeValueArg(v); oid > 0 {
    84  			params = append(params, value)
    85  			oids = append(oids, oid)
    86  		}
    87  	}
    88  	r.script.Steps = append(r.script.Steps, ExpectQuery(sql, params, oids))
    89  	return r
    90  }
    91  
    92  func (r *MockConn) encodeValueArg(v interface{}) (value []byte, oid uint32) {
    93  	if v == nil {
    94  		return nil, pgtype.TextArrayOID
    95  	}
    96  	dt, ok := ci.DataTypeForValue(v)
    97  	if !ok {
    98  		r.errChan <- fmt.Errorf("no suitable type for arg: %v", v)
    99  		return nil, 0
   100  	}
   101  	if err := dt.Value.Set(v); err != nil {
   102  		r.errChan <- fmt.Errorf("failed to set value: %w", err)
   103  		return nil, 0
   104  	}
   105  	var err error
   106  	switch dt.OID {
   107  	case pgtype.TextOID:
   108  		value, err = (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{})
   109  	default:
   110  		value, err = (dt.Value).(pgtype.BinaryEncoder).EncodeBinary(ci, []byte{})
   111  	}
   112  	if err != nil {
   113  		r.errChan <- fmt.Errorf("failed to encode arg: %w", err)
   114  		return nil, 0
   115  	}
   116  	return value, dt.OID
   117  }
   118  
   119  func getDataTypeSize(v interface{}) int16 {
   120  	t := reflect.TypeOf(v)
   121  	k := t.Kind()
   122  	if k < reflect.Int || k > reflect.Complex128 {
   123  		return -1
   124  	}
   125  	return int16(t.Size())
   126  }
   127  
   128  func (r *MockConn) lastQuery() *extendedQueryStep {
   129  	return r.script.Steps[len(r.script.Steps)-1].(*extendedQueryStep)
   130  }
   131  
   132  // Adds a server reply using binary or text protocol format.
   133  //
   134  // TODO: support prepared statements when using binary protocol
   135  func (r *MockConn) Reply(tag string, rows ...interface{}) *MockConn {
   136  	q := r.lastQuery()
   137  	// Add field description
   138  	if len(rows) > 0 {
   139  		var desc pgproto3.RowDescription
   140  		if arr, ok := rows[0].([]interface{}); ok {
   141  			for i, v := range arr {
   142  				name := fmt.Sprintf("c_%02d", i)
   143  				if fd := toFieldDescription(v); fd != nil {
   144  					fd.Name = []byte(name)
   145  					desc.Fields = append(desc.Fields, *fd)
   146  				} else {
   147  					r.errChan <- fmt.Errorf("failed to describe field: %s", name)
   148  				}
   149  			}
   150  		} else if t := reflect.TypeOf(rows[0]); t.Kind() == reflect.Struct {
   151  			s := reflect.ValueOf(rows[0])
   152  			for i := 0; i < s.NumField(); i++ {
   153  				name := t.Field(i).Name
   154  				v := s.Field(i).Interface()
   155  				if fd := toFieldDescription(v); fd != nil {
   156  					fd.Name = []byte(name)
   157  					desc.Fields = append(desc.Fields, *fd)
   158  				} else {
   159  					r.errChan <- fmt.Errorf("failed to describe field: %s", name)
   160  				}
   161  			}
   162  		} else {
   163  			r.errChan <- fmt.Errorf("reply type must be one of [array, struct], found: %s", t.Kind())
   164  		}
   165  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&desc))
   166  	} else {
   167  		// No data is optional, but we add for completeness
   168  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&pgproto3.NoData{}))
   169  	}
   170  	// Add row data
   171  	for _, data := range rows {
   172  		var dr pgproto3.DataRow
   173  		if arr, ok := data.([]interface{}); ok {
   174  			for _, v := range arr {
   175  				if value, oid := r.encodeValueArg(v); oid > 0 {
   176  					dr.Values = append(dr.Values, value)
   177  				}
   178  			}
   179  		} else if t := reflect.TypeOf(data); t.Kind() == reflect.Struct {
   180  			s := reflect.ValueOf(rows[0])
   181  			for i := 0; i < s.NumField(); i++ {
   182  				v := s.Field(i).Interface()
   183  				if value, oid := r.encodeValueArg(v); oid > 0 {
   184  					dr.Values = append(dr.Values, value)
   185  				}
   186  			}
   187  		} else {
   188  			r.errChan <- fmt.Errorf("invalid reply value: %v", data)
   189  		}
   190  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&dr))
   191  	}
   192  	// Add completion message
   193  	var complete pgproto3.BackendMessage
   194  	if tag == "" {
   195  		complete = &pgproto3.EmptyQueryResponse{}
   196  	} else {
   197  		complete = &pgproto3.CommandComplete{CommandTag: []byte(tag)}
   198  	}
   199  	q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(complete))
   200  	return r
   201  }
   202  
   203  func toFieldDescription(v interface{}) *pgproto3.FieldDescription {
   204  	if dt, ok := ci.DataTypeForValue(v); ok {
   205  		size := getDataTypeSize(v)
   206  		format := ci.ParamFormatCodeForOID(dt.OID)
   207  		return &pgproto3.FieldDescription{
   208  			TableOID:             17131,
   209  			TableAttributeNumber: 1,
   210  			DataTypeOID:          dt.OID,
   211  			DataTypeSize:         size,
   212  			TypeModifier:         -1,
   213  			Format:               format,
   214  		}
   215  	}
   216  	return nil
   217  }
   218  
   219  // Simulates an error reply from the server.
   220  //
   221  // TODO: simulate a notice reply
   222  func (r *MockConn) ReplyError(code, message string) *MockConn {
   223  	q := r.lastQuery()
   224  	q.reply.Steps = append(
   225  		q.reply.Steps,
   226  		pgmock.SendMessage(&pgproto3.ErrorResponse{
   227  			Severity:            "ERROR",
   228  			SeverityUnlocalized: "ERROR",
   229  			Code:                code,
   230  			Message:             message,
   231  		}),
   232  	)
   233  	return r
   234  }
   235  
   236  func (r *MockConn) Close(t *testing.T) {
   237  	for err := range r.errChan {
   238  		t.Errorf("pgmock error: %v", err)
   239  	}
   240  	if err := r.server.Close(); err != nil {
   241  		t.Fatalf("failed to close: %v", err)
   242  	}
   243  }
   244  
   245  func NewWithStatus(status map[string]string) *MockConn {
   246  	const bufSize = 1024 * 1024
   247  	mock := MockConn{
   248  		server:  bufconn.Listen(bufSize),
   249  		status:  status,
   250  		errChan: make(chan error, 10),
   251  	}
   252  	// Start server in background
   253  	const timeout = time.Millisecond * 450
   254  	go func() {
   255  		defer close(mock.errChan)
   256  		// Block until we've opened a TCP connection
   257  		conn, err := mock.server.Accept()
   258  		if err != nil {
   259  			mock.errChan <- err
   260  			return
   261  		}
   262  		defer conn.Close()
   263  		// Prevent server from hanging the test
   264  		err = conn.SetDeadline(time.Now().Add(timeout))
   265  		if err != nil {
   266  			mock.errChan <- err
   267  			return
   268  		}
   269  		// Always expect clients to terminate the request
   270  		mock.script.Steps = append(mock.script.Steps, ExpectTerminate())
   271  		err = mock.script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
   272  		if err != nil {
   273  			mock.errChan <- err
   274  			return
   275  		}
   276  	}()
   277  
   278  	return &mock
   279  }
   280  
   281  func NewConn() *MockConn {
   282  	return NewWithStatus(map[string]string{
   283  		"application_name":              "",
   284  		"client_encoding":               "UTF8",
   285  		"DateStyle":                     "ISO, MDY",
   286  		"default_transaction_read_only": "off",
   287  		"in_hot_standby":                "off",
   288  		"integer_datetimes":             "on",
   289  		"IntervalStyle":                 "postgres",
   290  		"is_superuser":                  "on",
   291  		"server_encoding":               "UTF8",
   292  		"server_version":                "14.3 (Debian 14.3-1.pgdg110+1)",
   293  		"standard_conforming_strings":   "on",
   294  		"TimeZone":                      "UTC",
   295  	})
   296  }