github.com/thiagoyeds/go-cloud@v0.26.0/postgres/awspostgres/awspostgres.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awspostgres provides connections to AWS RDS PostgreSQL instances.
    16  //
    17  // URLs
    18  //
    19  // For postgres.Open, awspostgres registers for the scheme "awspostgres".
    20  // The default URL opener will create a connection using the default
    21  // credentials from the environment, as described in
    22  // https://docs.aws.amazon.com/sdk-for-go/api/aws/session/.
    23  // To customize the URL opener, or for more details on the URL format,
    24  // see URLOpener.
    25  //
    26  // See https://gocloud.dev/concepts/urls/ for background information.
    27  package awspostgres // import "gocloud.dev/postgres/awspostgres"
    28  
import (
	"context"
	"crypto/tls"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
	"net"
	"net/url"
	"strings"
	"time"

	"contrib.go.opencensus.io/integrations/ocsql"
	"github.com/lib/pq"
	"gocloud.dev/aws/rds"
	"gocloud.dev/postgres"
)
    45  
    46  // URLOpener opens RDS PostgreSQL URLs
    47  // like "awspostgres://user:password@myinstance.borkxyzzy.us-west-1.rds.amazonaws.com:5432/mydb".
    48  type URLOpener struct {
    49  	// CertSource specifies how the opener will obtain the RDS Certificate
    50  	// Authority. If nil, it will use the default *rds.CertFetcher.
    51  	CertSource rds.CertPoolProvider
    52  	// TraceOpts contains options for OpenCensus.
    53  	TraceOpts []ocsql.TraceOption
    54  }
    55  
    56  // Scheme is the URL scheme awspostgres registers its URLOpener under on
    57  // postgres.DefaultMux.
    58  const Scheme = "awspostgres"
    59  
    60  func init() {
    61  	postgres.DefaultURLMux().RegisterPostgres(Scheme, &URLOpener{})
    62  }
    63  
    64  // OpenPostgresURL opens a new RDS database connection wrapped with OpenCensus instrumentation.
    65  func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
    66  	source := uo.CertSource
    67  	if source == nil {
    68  		source = new(rds.CertFetcher)
    69  	}
    70  
    71  	query := u.Query()
    72  	for k := range query {
    73  		// Forbid SSL-related parameters.
    74  		if k == "sslmode" || k == "sslcert" || k == "sslkey" || k == "sslrootcert" {
    75  			return nil, fmt.Errorf("awspostgres: open: parameter %q not allowed; sslmode must be disabled because the underlying dialer is already providing TLS", k)
    76  		}
    77  	}
    78  	// sslmode must be disabled because the underlying dialer is already providing TLS.
    79  	query.Set("sslmode", "disable")
    80  
    81  	u2 := new(url.URL)
    82  	*u2 = *u
    83  	u2.Scheme = "postgres"
    84  	u2.RawQuery = query.Encode()
    85  	db := sql.OpenDB(connector{
    86  		provider:  source,
    87  		pqConn:    u2.String(),
    88  		traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...),
    89  	})
    90  	return db, nil
    91  }
    92  
    93  type pqDriver struct {
    94  	provider  rds.CertPoolProvider
    95  	traceOpts []ocsql.TraceOption
    96  }
    97  
    98  func (d pqDriver) Open(name string) (driver.Conn, error) {
    99  	c, _ := d.OpenConnector(name)
   100  	return c.Connect(context.Background())
   101  }
   102  
   103  func (d pqDriver) OpenConnector(name string) (driver.Connector, error) {
   104  	return connector{d.provider, name + " sslmode=disable", d.traceOpts}, nil
   105  }
   106  
   107  type connector struct {
   108  	provider  rds.CertPoolProvider
   109  	pqConn    string
   110  	traceOpts []ocsql.TraceOption
   111  }
   112  
   113  func (c connector) Connect(context.Context) (driver.Conn, error) {
   114  	conn, err := pq.DialOpen(dialer{c.provider}, c.pqConn)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return ocsql.WrapConn(conn, c.traceOpts...), nil
   119  }
   120  
   121  func (c connector) Driver() driver.Driver {
   122  	return pqDriver{c.provider, c.traceOpts}
   123  }
   124  
   125  type dialer struct {
   126  	provider rds.CertPoolProvider
   127  }
   128  
   129  func (d dialer) dial(ctx context.Context, network, address string) (net.Conn, error) {
   130  	host, _, err := net.SplitHostPort(address)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("awspostgres: parse address: %v", err)
   133  	}
   134  	certPool, err := d.provider.RDSCertPool(ctx)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	conn, err := new(net.Dialer).DialContext(ctx, network, address)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	// Write the PostgreSQL SSLRequest message described in
   144  	// https://www.postgresql.org/docs/11/protocol-message-formats.html
   145  	// to upgrade to a TLS connection.
   146  	_, err = conn.Write([]byte{
   147  		// Message length (Int32), including message length.
   148  		0x00, 0x00, 0x00, 0x08,
   149  		// Magic number: 80877103.
   150  		0x04, 0xd2, 0x16, 0x2f,
   151  	})
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	// Server must respond back with 'S'.
   156  	var readBuf [1]byte
   157  	if _, err := io.ReadFull(conn, readBuf[:]); err != nil {
   158  		return nil, err
   159  	}
   160  	if readBuf[0] != 'S' {
   161  		return nil, pq.ErrSSLNotSupported
   162  	}
   163  
   164  	// Begin TLS communication.
   165  	crypt := tls.Client(conn, &tls.Config{
   166  		ServerName:    host,
   167  		RootCAs:       certPool,
   168  		Renegotiation: tls.RenegotiateFreelyAsClient,
   169  	})
   170  	if err := crypt.Handshake(); err != nil {
   171  		return nil, err
   172  	}
   173  	return crypt, nil
   174  }
   175  
   176  func (d dialer) Dial(network, address string) (net.Conn, error) {
   177  	return d.dial(context.Background(), network, address)
   178  }
   179  
   180  func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
   181  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   182  	conn, err := d.dial(ctx, network, address)
   183  	cancel()
   184  	return conn, err
   185  }