github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/testing/pgtest/step.go (about)

     1  package pgtest
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/jackc/pgmock"
     8  	"github.com/jackc/pgproto3/v2"
     9  	"github.com/jackc/pgtype"
    10  )
    11  
    12  type extendedQueryStep struct {
    13  	sql    string
    14  	params [][]byte
    15  	oids   []uint32
    16  	reply  pgmock.Script
    17  }
    18  
    19  func (e *extendedQueryStep) Step(backend *pgproto3.Backend) error {
    20  	msg, err := getFrontendMessage(backend)
    21  	if err != nil {
    22  		return err
    23  	}
    24  
    25  	// Handle prepared statements, name can be dynamic: lrupsc_5_0
    26  	if m, ok := msg.(*pgproto3.Parse); ok {
    27  		want := &pgproto3.Parse{Name: m.Name, Query: e.sql, ParameterOIDs: m.ParameterOIDs}
    28  		if !reflect.DeepEqual(m, want) {
    29  			return fmt.Errorf("msg => %#v, e.want => %#v", m, want)
    30  		}
    31  		// Anonymous ps falls through
    32  		if m.Name != "" {
    33  			script := pgmock.Script{Steps: []pgmock.Step{
    34  				pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: m.Name}),
    35  				pgmock.ExpectMessage(&pgproto3.Sync{}),
    36  				pgmock.SendMessage(&pgproto3.ParseComplete{}),
    37  				pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: e.oids}),
    38  				// Postgres responds pgproto3.RowDescription but it's optional for pgx
    39  				pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
    40  			}}
    41  			if err := script.Run(backend); err != nil {
    42  				return err
    43  			}
    44  		}
    45  		// Expect bind command next
    46  		msg, err = backend.Receive()
    47  		if err != nil {
    48  			return err
    49  		}
    50  	}
    51  
    52  	if m, ok := msg.(*pgproto3.Bind); ok {
    53  		var codes []int16
    54  		for range e.oids {
    55  			codes = append(codes, pgtype.TextFormatCode)
    56  		}
    57  		want := &pgproto3.Bind{
    58  			ParameterFormatCodes: codes,
    59  			Parameters:           e.params,
    60  			ResultFormatCodes:    []int16{},
    61  			DestinationPortal:    m.DestinationPortal,
    62  			PreparedStatement:    m.PreparedStatement,
    63  		}
    64  		if !reflect.DeepEqual(m, want) {
    65  			return fmt.Errorf("msg => %#v, e.want => %#v", msg, want)
    66  		}
    67  		e.reply.Steps = append([]pgmock.Step{
    68  			pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'P'}),
    69  			pgmock.ExpectMessage(&pgproto3.Execute{}),
    70  			pgmock.SendMessage(&pgproto3.ParseComplete{}),
    71  			pgmock.SendMessage(&pgproto3.BindComplete{}),
    72  		}, e.reply.Steps...)
    73  		return e.reply.Run(backend)
    74  	}
    75  
    76  	// Handle simple query
    77  	want := &pgproto3.Query{String: e.sql}
    78  	if m, ok := msg.(*pgproto3.Query); ok && reflect.DeepEqual(m, want) {
    79  		e.reply.Steps = append(e.reply.Steps, pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))
    80  		return e.reply.Run(backend)
    81  	}
    82  
    83  	return fmt.Errorf("msg => %#v, e.want => %#v", msg, want)
    84  }
    85  
    86  // Expects a SQL query in any form: simple, prepared, or anonymous.
    87  func ExpectQuery(sql string, params [][]byte, oids []uint32) pgmock.Step {
    88  	return &extendedQueryStep{sql: sql, params: params, oids: oids}
    89  }
    90  
    91  type terminateStep struct{}
    92  
    93  func (e *terminateStep) Step(backend *pgproto3.Backend) error {
    94  	msg, err := getFrontendMessage(backend)
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	// Handle simple query
   100  	if _, ok := msg.(*pgproto3.Terminate); ok {
   101  		return nil
   102  	}
   103  
   104  	return fmt.Errorf("msg => %#v, e.want => %#v", msg, &pgproto3.Terminate{})
   105  }
   106  
   107  func ExpectTerminate() pgmock.Step {
   108  	return &terminateStep{}
   109  }
   110  
   111  func getFrontendMessage(backend *pgproto3.Backend) (pgproto3.FrontendMessage, error) {
   112  	msg, err := backend.Receive()
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	// Sync signals end of batch statements
   118  	if _, ok := msg.(*pgproto3.Sync); ok {
   119  		reply := pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'})
   120  		if err := reply.Step(backend); err != nil {
   121  			return nil, err
   122  		}
   123  		msg, err = backend.Receive()
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  	}
   128  
   129  	return msg, nil
   130  }