github.com/jackc/pgx/v5@v5.5.5/internal/pgmock/pgmock.go (about) 1 // Package pgmock provides the ability to mock a PostgreSQL server. 2 package pgmock 3 4 import ( 5 "fmt" 6 "io" 7 "reflect" 8 9 "github.com/jackc/pgx/v5/pgproto3" 10 ) 11 12 type Step interface { 13 Step(*pgproto3.Backend) error 14 } 15 16 type Script struct { 17 Steps []Step 18 } 19 20 func (s *Script) Run(backend *pgproto3.Backend) error { 21 for _, step := range s.Steps { 22 err := step.Step(backend) 23 if err != nil { 24 return err 25 } 26 } 27 28 return nil 29 } 30 31 func (s *Script) Step(backend *pgproto3.Backend) error { 32 return s.Run(backend) 33 } 34 35 type expectMessageStep struct { 36 want pgproto3.FrontendMessage 37 any bool 38 } 39 40 func (e *expectMessageStep) Step(backend *pgproto3.Backend) error { 41 msg, err := backend.Receive() 42 if err != nil { 43 return err 44 } 45 46 if e.any && reflect.TypeOf(msg) == reflect.TypeOf(e.want) { 47 return nil 48 } 49 50 if !reflect.DeepEqual(msg, e.want) { 51 return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) 52 } 53 54 return nil 55 } 56 57 type expectStartupMessageStep struct { 58 want *pgproto3.StartupMessage 59 any bool 60 } 61 62 func (e *expectStartupMessageStep) Step(backend *pgproto3.Backend) error { 63 msg, err := backend.ReceiveStartupMessage() 64 if err != nil { 65 return err 66 } 67 68 if e.any { 69 return nil 70 } 71 72 if !reflect.DeepEqual(msg, e.want) { 73 return fmt.Errorf("msg => %#v, e.want => %#v", msg, e.want) 74 } 75 76 return nil 77 } 78 79 func ExpectMessage(want pgproto3.FrontendMessage) Step { 80 return expectMessage(want, false) 81 } 82 83 func ExpectAnyMessage(want pgproto3.FrontendMessage) Step { 84 return expectMessage(want, true) 85 } 86 87 func expectMessage(want pgproto3.FrontendMessage, any bool) Step { 88 if want, ok := want.(*pgproto3.StartupMessage); ok { 89 return &expectStartupMessageStep{want: want, any: any} 90 } 91 92 return &expectMessageStep{want: want, any: any} 93 } 94 95 type sendMessageStep struct { 96 msg pgproto3.BackendMessage 97 } 98 99 func (e *sendMessageStep) Step(backend *pgproto3.Backend) error { 100 backend.Send(e.msg) 101 return backend.Flush() 102 } 103 104 func SendMessage(msg pgproto3.BackendMessage) Step { 105 return &sendMessageStep{msg: msg} 106 } 107 108 type waitForCloseMessageStep struct{} 109 110 func (e *waitForCloseMessageStep) Step(backend *pgproto3.Backend) error { 111 for { 112 msg, err := backend.Receive() 113 if err == io.EOF { 114 return nil 115 } else if err != nil { 116 return err 117 } 118 119 if _, ok := msg.(*pgproto3.Terminate); ok { 120 return nil 121 } 122 } 123 } 124 125 func WaitForClose() Step { 126 return &waitForCloseMessageStep{} 127 } 128 129 func AcceptUnauthenticatedConnRequestSteps() []Step { 130 return []Step{ 131 ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}), 132 SendMessage(&pgproto3.AuthenticationOk{}), 133 SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}), 134 SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), 135 } 136 }