github.com/supabase/cli@v1.168.1/internal/debug/postgres.go (about)

     1  package debug
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"os"
    11  
    12  	"github.com/jackc/pgproto3/v2"
    13  	"github.com/jackc/pgx/v4"
    14  	"google.golang.org/grpc/test/bufconn"
    15  )
    16  
    17  type Proxy struct {
    18  	dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
    19  	errChan     chan error
    20  }
    21  
    22  func NewProxy() Proxy {
    23  	dialer := net.Dialer{}
    24  	return Proxy{
    25  		dialContext: dialer.DialContext,
    26  		errChan:     make(chan error, 1),
    27  	}
    28  }
    29  
    30  func SetupPGX(config *pgx.ConnConfig) {
    31  	proxy := Proxy{
    32  		dialContext: config.DialFunc,
    33  		errChan:     make(chan error, 1),
    34  	}
    35  	config.DialFunc = proxy.DialFunc
    36  	config.TLSConfig = nil
    37  }
    38  
    39  func (p *Proxy) DialFunc(ctx context.Context, network, addr string) (net.Conn, error) {
    40  	serverConn, err := p.dialContext(ctx, network, addr)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	const bufSize = 1024 * 1024
    46  	ln := bufconn.Listen(bufSize)
    47  	go func() {
    48  		defer serverConn.Close()
    49  		clientConn, err := ln.Accept()
    50  		if err != nil {
    51  			// Unreachable code as bufconn never throws, but just in case
    52  			panic(err)
    53  		}
    54  		defer clientConn.Close()
    55  
    56  		backend := NewBackend(clientConn)
    57  		frontend := NewFrontend(serverConn)
    58  		go backend.forward(frontend, p.errChan)
    59  		go frontend.forward(backend, p.errChan)
    60  
    61  		for {
    62  			// Since pgx closes connection first, every EOF is seen as unexpected
    63  			if err := <-p.errChan; err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
    64  				panic(err)
    65  			}
    66  		}
    67  	}()
    68  
    69  	return ln.DialContext(ctx)
    70  }
    71  
    72  type Backend struct {
    73  	*pgproto3.Backend
    74  	logger *log.Logger
    75  }
    76  
    77  func NewBackend(clientConn net.Conn) Backend {
    78  	return Backend{
    79  		pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn),
    80  		log.New(os.Stderr, "PG Recv: ", log.LstdFlags|log.Lmsgprefix),
    81  	}
    82  }
    83  
    84  func (b *Backend) forward(frontend Frontend, errChan chan error) {
    85  	startupMessage, err := b.ReceiveStartupMessage()
    86  	if err != nil {
    87  		errChan <- err
    88  		return
    89  	}
    90  
    91  	buf, err := json.Marshal(startupMessage)
    92  	if err != nil {
    93  		errChan <- err
    94  		return
    95  	}
    96  	frontend.logger.Println(string(buf))
    97  
    98  	if err = frontend.Send(startupMessage); err != nil {
    99  		errChan <- err
   100  		return
   101  	}
   102  
   103  	for {
   104  		msg, err := b.Receive()
   105  		if err != nil {
   106  			errChan <- err
   107  			return
   108  		}
   109  
   110  		buf, err := json.Marshal(msg)
   111  		if err != nil {
   112  			errChan <- err
   113  			return
   114  		}
   115  		frontend.logger.Println(string(buf))
   116  
   117  		if err = frontend.Send(msg); err != nil {
   118  			errChan <- err
   119  			return
   120  		}
   121  	}
   122  }
   123  
   124  type Frontend struct {
   125  	*pgproto3.Frontend
   126  	logger *log.Logger
   127  }
   128  
   129  func NewFrontend(serverConn net.Conn) Frontend {
   130  	return Frontend{
   131  		pgproto3.NewFrontend(pgproto3.NewChunkReader(serverConn), serverConn),
   132  		log.New(os.Stderr, "PG Send: ", log.LstdFlags|log.Lmsgprefix),
   133  	}
   134  }
   135  
   136  func (f *Frontend) forward(backend Backend, errChan chan error) {
   137  	for {
   138  		msg, err := f.Receive()
   139  		if err != nil {
   140  			errChan <- err
   141  			return
   142  		}
   143  
   144  		buf, err := json.Marshal(msg)
   145  		if err != nil {
   146  			errChan <- err
   147  			return
   148  		}
   149  		backend.logger.Println(string(buf))
   150  
   151  		if _, ok := msg.(pgproto3.AuthenticationResponseMessage); ok {
   152  			// Set the authentication type so the next backend.Receive() will
   153  			// properly decode the appropriate 'p' message.
   154  			if err := backend.SetAuthType(f.GetAuthType()); err != nil {
   155  				errChan <- err
   156  				return
   157  			}
   158  		}
   159  
   160  		if err := backend.Send(msg); err != nil {
   161  			errChan <- err
   162  			return
   163  		}
   164  	}
   165  }