github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/testing/pgtest/mock.go (about)

     1  package pgtest
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/jackc/pgmock"
    12  	"github.com/jackc/pgproto3/v2"
    13  	"github.com/jackc/pgtype"
    14  	"github.com/jackc/pgx/v4"
    15  	"google.golang.org/grpc/test/bufconn"
    16  )
    17  
    18  var ci = pgtype.NewConnInfo()
    19  
    20  type MockConn struct {
    21  	// Duplex server listener backed by in-memory buffer
    22  	server *bufconn.Listener
    23  
    24  	// Mock server requests and responses
    25  	script pgmock.Script
    26  
    27  	// Status parameters emitted by postgres on first connect
    28  	status map[string]string
    29  
    30  	// Channel for reporting all server error
    31  	errChan chan error
    32  }
    33  
    34  func (r *MockConn) getStartupMessage(config *pgx.ConnConfig) []pgmock.Step {
    35  	var steps []pgmock.Step
    36  	// Add auth message
    37  	steps = append(
    38  		steps,
    39  		pgmock.ExpectMessage(&pgproto3.StartupMessage{
    40  			ProtocolVersion: pgproto3.ProtocolVersionNumber,
    41  			Parameters:      map[string]string{"database": config.Database, "user": config.User},
    42  		}),
    43  		pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
    44  	)
    45  	// Add status message
    46  	r.status["session_authorization"] = config.User
    47  	for key, value := range r.status {
    48  		steps = append(steps, pgmock.SendMessage(&pgproto3.ParameterStatus{Name: key, Value: value}))
    49  	}
    50  	// Add ready message
    51  	steps = append(
    52  		steps,
    53  		pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
    54  		pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
    55  	)
    56  	return steps
    57  }
    58  
    59  // Configures pgx to use the mock dialer.
    60  //
    61  // The mock dialer provides a full duplex net.Conn backed by an in-memory buffer.
    62  // It is implemented by grcp/test/bufconn package.
    63  func (r *MockConn) Intercept(config *pgx.ConnConfig) {
    64  	// Override config for test
    65  	config.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
    66  		return r.server.DialContext(ctx)
    67  	}
    68  	config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
    69  		return []string{"127.0.0.1"}, nil
    70  	}
    71  	config.TLSConfig = nil
    72  	// Add startup message
    73  	r.script.Steps = append(r.getStartupMessage(config), r.script.Steps...)
    74  }
    75  
    76  // Adds a simple query or prepared statement to the mock connection.
    77  func (r *MockConn) Query(sql string, args ...interface{}) *MockConn {
    78  	var oids []uint32
    79  	var params [][]byte
    80  	for _, v := range args {
    81  		if dt, ok := ci.DataTypeForValue(v); ok {
    82  			if err := dt.Value.Set(v); err != nil {
    83  				continue
    84  			}
    85  			value, err := (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{})
    86  			if err != nil {
    87  				continue
    88  			}
    89  			params = append(params, value)
    90  			oids = append(oids, dt.OID)
    91  		}
    92  	}
    93  	r.script.Steps = append(r.script.Steps, ExpectQuery(sql, params, oids))
    94  	return r
    95  }
    96  
    97  func getDataTypeSize(v interface{}) int16 {
    98  	t := reflect.TypeOf(v)
    99  	k := t.Kind()
   100  	if k < reflect.Int || k > reflect.Complex128 {
   101  		return -1
   102  	}
   103  	return int16(t.Size())
   104  }
   105  
   106  func (r *MockConn) lastQuery() *extendedQueryStep {
   107  	return r.script.Steps[len(r.script.Steps)-1].(*extendedQueryStep)
   108  }
   109  
   110  // Adds a server reply using text protocol format.
   111  //
   112  // TODO: support binary protocol
   113  func (r *MockConn) Reply(tag string, rows ...[]interface{}) *MockConn {
   114  	q := r.lastQuery()
   115  	// Add field description
   116  	if len(rows) > 0 {
   117  		var desc pgproto3.RowDescription
   118  		for i, v := range rows[0] {
   119  			name := fmt.Sprintf("c_%02d", i)
   120  			if dt, ok := ci.DataTypeForValue(v); ok {
   121  				size := getDataTypeSize(v)
   122  				desc.Fields = append(desc.Fields, pgproto3.FieldDescription{
   123  					Name:                 []byte(name),
   124  					TableOID:             17131,
   125  					TableAttributeNumber: 1,
   126  					DataTypeOID:          dt.OID,
   127  					DataTypeSize:         size,
   128  					TypeModifier:         -1,
   129  					Format:               pgtype.TextFormatCode,
   130  				})
   131  			}
   132  		}
   133  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&desc))
   134  	} else {
   135  		// No data is optional, but we add for completeness
   136  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&pgproto3.NoData{}))
   137  	}
   138  	// Add row data
   139  	for _, data := range rows {
   140  		var dr pgproto3.DataRow
   141  		for _, v := range data {
   142  			if dt, ok := ci.DataTypeForValue(v); ok {
   143  				if err := dt.Value.Set(v); err != nil {
   144  					continue
   145  				}
   146  				if value, err := (dt.Value).(pgtype.TextEncoder).EncodeText(ci, []byte{}); err == nil {
   147  					dr.Values = append(dr.Values, value)
   148  				}
   149  			}
   150  		}
   151  		q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(&dr))
   152  	}
   153  	// Add completion message
   154  	var complete pgproto3.BackendMessage
   155  	if tag == "" {
   156  		complete = &pgproto3.EmptyQueryResponse{}
   157  	} else {
   158  		complete = &pgproto3.CommandComplete{CommandTag: []byte(tag)}
   159  	}
   160  	q.reply.Steps = append(q.reply.Steps, pgmock.SendMessage(complete))
   161  	return r
   162  }
   163  
   164  // Simulates an error reply from the server.
   165  //
   166  // TODO: simulate a notice reply
   167  func (r *MockConn) ReplyError(code, message string) *MockConn {
   168  	q := r.lastQuery()
   169  	q.reply.Steps = append(
   170  		q.reply.Steps,
   171  		pgmock.SendMessage(&pgproto3.ErrorResponse{
   172  			Severity:            "ERROR",
   173  			SeverityUnlocalized: "ERROR",
   174  			Code:                code,
   175  			Message:             message,
   176  		}),
   177  	)
   178  	return r
   179  }
   180  
   181  func (r *MockConn) Close(t *testing.T) {
   182  	if err := <-r.errChan; err != nil {
   183  		t.Fatalf("failed to close %v", err)
   184  	}
   185  	if err := r.server.Close(); err != nil {
   186  		t.Fatalf("failed to close %v", err)
   187  	}
   188  }
   189  
   190  func NewWithStatus(status map[string]string) *MockConn {
   191  	const bufSize = 1024 * 1024
   192  	mock := MockConn{
   193  		server:  bufconn.Listen(bufSize),
   194  		status:  status,
   195  		errChan: make(chan error, 1),
   196  	}
   197  	// Start server in background
   198  	const timeout = time.Millisecond * 450
   199  	go func() {
   200  		defer close(mock.errChan)
   201  		// Block until we've opened a TCP connection
   202  		conn, err := mock.server.Accept()
   203  		if err != nil {
   204  			mock.errChan <- err
   205  			return
   206  		}
   207  		defer conn.Close()
   208  		// Prevent server from hanging the test
   209  		err = conn.SetDeadline(time.Now().Add(timeout))
   210  		if err != nil {
   211  			mock.errChan <- err
   212  			return
   213  		}
   214  		// Always expect clients to terminate the request
   215  		mock.script.Steps = append(mock.script.Steps, ExpectTerminate())
   216  		err = mock.script.Run(pgproto3.NewBackend(pgproto3.NewChunkReader(conn), conn))
   217  		if err != nil {
   218  			mock.errChan <- err
   219  			return
   220  		}
   221  	}()
   222  
   223  	return &mock
   224  }
   225  
   226  func NewConn() *MockConn {
   227  	return NewWithStatus(map[string]string{
   228  		"application_name":              "",
   229  		"client_encoding":               "UTF8",
   230  		"DateStyle":                     "ISO, MDY",
   231  		"default_transaction_read_only": "off",
   232  		"in_hot_standby":                "off",
   233  		"integer_datetimes":             "on",
   234  		"IntervalStyle":                 "postgres",
   235  		"is_superuser":                  "on",
   236  		"server_encoding":               "UTF8",
   237  		"server_version":                "14.3 (Debian 14.3-1.pgdg110+1)",
   238  		"standard_conforming_strings":   "on",
   239  		"TimeZone":                      "UTC",
   240  	})
   241  }