github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/cmp-protocol/pgconnect/pgconnect.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // Package pgconnect provides a way to get byte encodings from a simple query.
    12  package pgconnect
    13  
    14  import (
    15  	"context"
    16  	"net"
    17  	"reflect"
    18  
    19  	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase"
    20  	"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
    21  	"github.com/cockroachdb/errors"
    22  	"github.com/jackc/pgx/pgproto3"
    23  )
    24  
    25  // Connect connects to the postgres-compatible server at addr with specified
    26  // user. input must specify a SELECT query (including the "SELECT") that
    27  // returns one row and one column. code is the format code. The resulting
    28  // row bytes are returned.
    29  func Connect(
    30  	ctx context.Context, input, addr, user string, code pgwirebase.FormatCode,
    31  ) ([]byte, error) {
    32  	ctx, cancel := context.WithCancel(ctx)
    33  	defer cancel()
    34  
    35  	var d net.Dialer
    36  	conn, err := d.DialContext(ctx, "tcp", addr)
    37  	if err != nil {
    38  		return nil, errors.Wrap(err, "dail")
    39  	}
    40  	defer conn.Close()
    41  
    42  	fe, err := pgproto3.NewFrontend(conn, conn)
    43  	if err != nil {
    44  		return nil, errors.Wrap(err, "new frontend")
    45  	}
    46  
    47  	send := make(chan pgproto3.FrontendMessage)
    48  	recv := make(chan pgproto3.BackendMessage)
    49  	var res []byte
    50  	// Use go routines to divide up work in order to improve debugging. These
    51  	// aren't strictly necessary, but they make it easy to print when messages
    52  	// are received.
    53  	g := ctxgroup.WithContext(ctx)
    54  	// The send chan sends messages to the server.
    55  	g.GoCtx(func(ctx context.Context) error {
    56  		defer close(send)
    57  		for {
    58  			select {
    59  			case <-ctx.Done():
    60  				return ctx.Err()
    61  			case msg := <-send:
    62  				err := fe.Send(msg)
    63  				if err != nil {
    64  					return errors.Wrap(err, "send")
    65  				}
    66  			}
    67  		}
    68  	})
    69  	// The recv go routine receives messages from the server and puts them on
    70  	// the recv chan. It makes a copy of them when it does since the next message
    71  	// received will otherwise use the same pointer.
    72  	g.GoCtx(func(ctx context.Context) error {
    73  		defer close(recv)
    74  		for {
    75  			msg, err := fe.Receive()
    76  			if err != nil {
    77  				return errors.Wrap(err, "receive")
    78  			}
    79  
    80  			// Make a deep copy since the receiver uses a pointer.
    81  			x := reflect.ValueOf(msg)
    82  			starX := x.Elem()
    83  			y := reflect.New(starX.Type())
    84  			starY := y.Elem()
    85  			starY.Set(starX)
    86  			dup := y.Interface().(pgproto3.BackendMessage)
    87  
    88  			select {
    89  			case <-ctx.Done():
    90  				return ctx.Err()
    91  			case recv <- dup:
    92  			}
    93  		}
    94  	})
    95  	// The main go routine executing the logic.
    96  	g.GoCtx(func(ctx context.Context) error {
    97  		send <- &pgproto3.StartupMessage{
    98  			ProtocolVersion: 196608, // Version 3.0
    99  			Parameters: map[string]string{
   100  				"user": user,
   101  			},
   102  		}
   103  		{
   104  			r := <-recv
   105  			if msg, ok := r.(*pgproto3.Authentication); !ok || msg.Type != 0 {
   106  				return errors.Errorf("unexpected: %#v\n", r)
   107  			}
   108  		}
   109  	WaitConnLoop:
   110  		for {
   111  			msg := <-recv
   112  			switch msg.(type) {
   113  			case *pgproto3.ReadyForQuery:
   114  				break WaitConnLoop
   115  			}
   116  		}
   117  		send <- &pgproto3.Parse{
   118  			Query: input,
   119  		}
   120  		send <- &pgproto3.Describe{
   121  			ObjectType: 'S',
   122  		}
   123  		send <- &pgproto3.Sync{}
   124  		r := <-recv
   125  		if _, ok := r.(*pgproto3.ParseComplete); !ok {
   126  			return errors.Errorf("unexpected: %#v", r)
   127  		}
   128  		send <- &pgproto3.Bind{
   129  			ResultFormatCodes: []int16{int16(code)},
   130  		}
   131  		send <- &pgproto3.Execute{}
   132  		send <- &pgproto3.Sync{}
   133  	WaitExecuteLoop:
   134  		for {
   135  			msg := <-recv
   136  			switch msg := msg.(type) {
   137  			case *pgproto3.DataRow:
   138  				if res != nil {
   139  					return errors.New("already saw a row")
   140  				}
   141  				if len(msg.Values) != 1 {
   142  					return errors.Errorf("unexpected: %#v\n", msg)
   143  				}
   144  				res = msg.Values[0]
   145  			case *pgproto3.CommandComplete,
   146  				*pgproto3.EmptyQueryResponse,
   147  				*pgproto3.ErrorResponse:
   148  				break WaitExecuteLoop
   149  			}
   150  		}
   151  		// Stop the other go routines.
   152  		cancel()
   153  		return nil
   154  	})
   155  	err = g.Wait()
   156  	// If res is set, we don't care about any errors.
   157  	if res != nil {
   158  		return res, nil
   159  	}
   160  	if err == nil {
   161  		return nil, errors.New("unexpected")
   162  	}
   163  	return nil, err
   164  }