github.com/supabase/cli@v1.168.1/internal/testing/pgtest/step.go (about)

     1  package pgtest
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/go-errors/errors"
     7  	"github.com/jackc/pgmock"
     8  	"github.com/jackc/pgproto3/v2"
     9  	"github.com/jackc/pgtype"
    10  )
    11  
// ci is the pgtype connection info shared by the steps in this file. It is
// used to derive the expected parameter format code for each parameter OID
// when validating Bind messages.
var ci = pgtype.NewConnInfo()
    13  
// extendedQueryStep is a pgmock.Step that expects one SQL statement, sent
// either as a simple Query or as a named/anonymous prepared statement, and
// then plays back a canned reply.
type extendedQueryStep struct {
	sql    string        // exact query text expected in a Parse or Query message
	params [][]byte      // encoded parameter values expected in the Bind message
	oids   []uint32      // parameter type OIDs: set expected format codes and the ParameterDescription reply
	reply  pgmock.Script // steps played back once the query has been accepted
}
    20  
    21  func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
    22  	msg, err := getFrontendMessage(backend)
    23  	if err != nil {
    24  		return err
    25  	}
    26  
    27  	// Handle prepared statements, name can be dynamic: lrupsc_5_0
    28  	if m, ok := msg.(*pgproto3.Parse); ok {
    29  		want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs}
    30  		if !reflect.DeepEqual(m, want) {
    31  			return errors.Errorf("msg => %#v, e.want => %#v", m, want)
    32  		}
    33  		// Anonymous ps falls through
    34  		if m.Name != "" {
    35  			script := pgmock.Script{Steps: []pgmock.Step{
    36  				pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}),
    37  				pgmock.ExpectMessage(&pgproto3.Sync{}),
    38  				pgmock.SendMessage(&pgproto3.ParseComplete{}),
    39  				pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}),
    40  				// Postgres responds pgproto3.RowDescription but it's optional for pgx
    41  				pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
    42  			}}
    43  			if err := script.Run(backend); err != nil {
    44  				return err
    45  			}
    46  		}
    47  		// Expect bind command next
    48  		msg, err = backend.Receive()
    49  		if err != nil {
    50  			return err
    51  		}
    52  	}
    53  
    54  	if m, ok := msg.(*pgproto3.Bind); ok {
    55  		var codes []int16
    56  		for _, oid := range e.oids {
    57  			codes = append(codes, ci.ParamFormatCodeForOID(oid))
    58  		}
    59  		want := &pgproto3.Bind{
    60  			ParameterFormatCodes: codes,
    61  			Parameters:           e.params,
    62  			ResultFormatCodes:    []int16{},
    63  			DestinationPortal:    m.DestinationPortal,
    64  			PreparedStatement:    m.PreparedStatement,
    65  		}
    66  		if !reflect.DeepEqual(m, want) {
    67  			return errors.Errorf("msg => %#v, e.want => %#v", msg, want)
    68  		}
    69  		e.reply.Steps = append([]pgmock.Step{
    70  			pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}),
    71  			pgmock.ExpectMessage(&pgproto3.Execute{}),
    72  			pgmock.SendMessage(&pgproto3.ParseComplete{}),
    73  			pgmock.SendMessage(&pgproto3.BindComplete{}),
    74  		}, e.reply.Steps...)
    75  		return e.reply.Run(backend)
    76  	}
    77  
    78  	// Handle simple query
    79  	want := &pgproto3.Query{String: e.sql}
    80  	if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
    81  		e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
    82  		return e.reply.Run(backend)
    83  	}
    84  
    85  	return errors.Errorf("msg => %#v, e.want => %#v", msg, want)
    86  }
    87  
    88  // Expects a SQL query in any form: simple, prepared, or anonymous.
    89  func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step {
    90  	return &extendedQueryStep{sql: sql, params: params, oids: oids}
    91  }
    92  
// terminateStep is a pgmock.Step that expects the client to send a Terminate
// message.
type terminateStep struct{}
    94  
    95  func (e *terminateStep) Step(backend *pgproto3.Backend) error {
    96  	msg, err := getFrontendMessage(backend)
    97  	if err != nil {
    98  		return err
    99  	}
   100  
   101  	// Handle simple query
   102  	if _, ok := msg.(*pgproto3.Terminate); ok {
   103  		return nil
   104  	}
   105  
   106  	return errors.Errorf("msg => %#v, e.want => %#v", msg, &pgproto3.Terminate{})
   107  }
   108  
   109  func ExpectTerminate() pgmock.Step {
   110  	return &terminateStep{}
   111  }
   112  
   113  func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) {
   114  	msg, err := backend.Receive()
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	// Sync signals end of batch statements
   120  	if _, ok := msg.(*pgproto3.Sync); ok {
   121  		reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})
   122  		if err := reply.Step(backend); err != nil {
   123  			return nil, err
   124  		}
   125  		msg, err = backend.Receive()
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  	}
   130  
   131  	return msg, nil
   132  }