github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmp-protocol/pgconnect/pgconnect.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 // Package pgconnect provides a way to get byte encodings from a simple query. 12 package pgconnect 13 14 import ( 15 "context" 16 "net" 17 "reflect" 18 19 "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" 20 "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" 21 "github.com/cockroachdb/errors" 22 "github.com/jackc/pgx/pgproto3" 23 ) 24 25 // Connect connects to the postgres-compatible server at addr with specified 26 // user. input must specify a SELECT query (including the "SELECT") that 27 // returns one row and one column. code is the format code. The resulting 28 // row bytes are returned. 29 func Connect( 30 ctx context.Context, input, addr, user string, code pgwirebase.FormatCode, 31 ) ([]byte, error) { 32 ctx, cancel := context.WithCancel(ctx) 33 defer cancel() 34 35 var d net.Dialer 36 conn, err := d.DialContext(ctx, "tcp", addr) 37 if err != nil { 38 return nil, errors.Wrap(err, "dail") 39 } 40 defer conn.Close() 41 42 fe, err := pgproto3.NewFrontend(conn, conn) 43 if err != nil { 44 return nil, errors.Wrap(err, "new frontend") 45 } 46 47 send := make(chan pgproto3.FrontendMessage) 48 recv := make(chan pgproto3.BackendMessage) 49 var res []byte 50 // Use go routines to divide up work in order to improve debugging. These 51 // aren't strictly necessary, but they make it easy to print when messages 52 // are received. 53 g := ctxgroup.WithContext(ctx) 54 // The send chan sends messages to the server. 55 g.GoCtx(func(ctx context.Context) error { 56 defer close(send) 57 for { 58 select { 59 case <-ctx.Done(): 60 return ctx.Err() 61 case msg := <-send: 62 err := fe.Send(msg) 63 if err != nil { 64 return errors.Wrap(err, "send") 65 } 66 } 67 } 68 }) 69 // The recv go routine receives messages from the server and puts them on 70 // the recv chan. It makes a copy of them when it does since the next message 71 // received will otherwise use the same pointer. 72 g.GoCtx(func(ctx context.Context) error { 73 defer close(recv) 74 for { 75 msg, err := fe.Receive() 76 if err != nil { 77 return errors.Wrap(err, "receive") 78 } 79 80 // Make a deep copy since the receiver uses a pointer. 81 x := reflect.ValueOf(msg) 82 starX := x.Elem() 83 y := reflect.New(starX.Type()) 84 starY := y.Elem() 85 starY.Set(starX) 86 dup := y.Interface().(pgproto3.BackendMessage) 87 88 select { 89 case <-ctx.Done(): 90 return ctx.Err() 91 case recv <- dup: 92 } 93 } 94 }) 95 // The main go routine executing the logic. 96 g.GoCtx(func(ctx context.Context) error { 97 send <- &pgproto3.StartupMessage{ 98 ProtocolVersion: 196608, // Version 3.0 99 Parameters: map[string]string{ 100 "user": user, 101 }, 102 } 103 { 104 r := <-recv 105 if msg, ok := r.(*pgproto3.Authentication); !ok || msg.Type != 0 { 106 return errors.Errorf("unexpected: %#v\n", r) 107 } 108 } 109 WaitConnLoop: 110 for { 111 msg := <-recv 112 switch msg.(type) { 113 case *pgproto3.ReadyForQuery: 114 break WaitConnLoop 115 } 116 } 117 send <- &pgproto3.Parse{ 118 Query: input, 119 } 120 send <- &pgproto3.Describe{ 121 ObjectType: 'S', 122 } 123 send <- &pgproto3.Sync{} 124 r := <-recv 125 if _, ok := r.(*pgproto3.ParseComplete); !ok { 126 return errors.Errorf("unexpected: %#v", r) 127 } 128 send <- &pgproto3.Bind{ 129 ResultFormatCodes: []int16{int16(code)}, 130 } 131 send <- &pgproto3.Execute{} 132 send <- &pgproto3.Sync{} 133 WaitExecuteLoop: 134 for { 135 msg := <-recv 136 switch msg := msg.(type) { 137 case *pgproto3.DataRow: 138 if res != nil { 139 return errors.New("already saw a row") 140 } 141 if len(msg.Values) != 1 { 142 return errors.Errorf("unexpected: %#v\n", msg) 143 } 144 res = msg.Values[0] 145 case *pgproto3.CommandComplete, 146 *pgproto3.EmptyQueryResponse, 147 *pgproto3.ErrorResponse: 148 break WaitExecuteLoop 149 } 150 } 151 // Stop the other go routines. 152 cancel() 153 return nil 154 }) 155 err = g.Wait() 156 // If res is set, we don't care about any errors. 157 if res != nil { 158 return res, nil 159 } 160 if err == nil { 161 return nil, errors.New("unexpected") 162 } 163 return nil, err 164 }