github.com/thiagoyeds/go-cloud@v0.26.0/postgres/gcppostgres/gcppostgres.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package gcppostgres provides connections to managed PostgreSQL Cloud SQL instances.
    16  // See https://cloud.google.com/sql/docs/postgres/ for more information.
    17  //
    18  // URLs
    19  //
    20  // For postgres.Open, gcppostgres registers for the scheme "gcppostgres".
    21  // The default URL opener will create a connection using the default
    22  // credentials from the environment, as described in
    23  // https://cloud.google.com/docs/authentication/production.
    24  // To customize the URL opener, or for more details on the URL format,
    25  // see URLOpener.
    26  //
    27  // See https://gocloud.dev/concepts/urls/ for background information.
    28  package gcppostgres // import "gocloud.dev/postgres/gcppostgres"
    29  
    30  import (
    31  	"context"
    32  	"database/sql"
    33  	"database/sql/driver"
    34  	"errors"
    35  	"fmt"
    36  	"net"
    37  	"net/url"
    38  	"strings"
    39  	"sync"
    40  	"time"
    41  
    42  	"contrib.go.opencensus.io/integrations/ocsql"
    43  	"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
    44  	"github.com/lib/pq"
    45  	"gocloud.dev/gcp"
    46  	"gocloud.dev/gcp/cloudsql"
    47  	"gocloud.dev/postgres"
    48  )
    49  
    50  // Scheme is the URL scheme gcppostgres registers its URLOpener under on
    51  // postgres.DefaultMux.
    52  const Scheme = "gcppostgres"
    53  
    54  func init() {
    55  	postgres.DefaultURLMux().RegisterPostgres(Scheme, new(lazyCredsOpener))
    56  }
    57  
    58  // lazyCredsOpener obtains Application Default Credentials on the first call
    59  // to OpenPostgresURL.
    60  type lazyCredsOpener struct {
    61  	init   sync.Once
    62  	opener *URLOpener
    63  	err    error
    64  }
    65  
    66  func (o *lazyCredsOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
    67  	o.init.Do(func() {
    68  		creds, err := gcp.DefaultCredentials(ctx)
    69  		if err != nil {
    70  			o.err = err
    71  			return
    72  		}
    73  		client, err := gcp.NewHTTPClient(gcp.DefaultTransport(), creds.TokenSource)
    74  		if err != nil {
    75  			o.err = err
    76  			return
    77  		}
    78  		certSource := cloudsql.NewCertSource(client)
    79  		o.opener = &URLOpener{CertSource: certSource}
    80  	})
    81  	if o.err != nil {
    82  		return nil, fmt.Errorf("gcppostgres open %v: %v", u, o.err)
    83  	}
    84  	return o.opener.OpenPostgresURL(ctx, u)
    85  }
    86  
    87  // URLOpener opens GCP PostgreSQL URLs
    88  // like "gcppostgres://user:password@myproject/us-central1/instanceId/mydb".
    89  type URLOpener struct {
    90  	// CertSource specifies how the opener will obtain authentication information.
    91  	// CertSource must not be nil.
    92  	CertSource proxy.CertSource
    93  
    94  	// TraceOpts contains options for OpenCensus.
    95  	TraceOpts []ocsql.TraceOption
    96  }
    97  
    98  // OpenPostgresURL opens a new GCP database connection wrapped with OpenCensus instrumentation.
    99  func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
   100  	if uo.CertSource == nil {
   101  		return nil, fmt.Errorf("gcppostgres: URLOpener CertSource is nil")
   102  	}
   103  	instance, dbName, err := instanceFromURL(u)
   104  	if err != nil {
   105  		return nil, fmt.Errorf("gcppostgres: open %v: %v", u, err)
   106  	}
   107  
   108  	query := u.Query()
   109  	for k := range query {
   110  		// Only permit parameters that do not conflict with other behavior.
   111  		if k == "sslmode" || k == "sslcert" || k == "sslkey" || k == "sslrootcert" {
   112  			return nil, fmt.Errorf("gcppostgres: open: extra parameter %s not allowed", k)
   113  		}
   114  	}
   115  	query.Set("sslmode", "disable")
   116  
   117  	u2 := new(url.URL)
   118  	*u2 = *u
   119  	u2.Scheme = "postgres"
   120  	u2.Host = "cloudsql"
   121  	u2.Path = "/" + dbName
   122  	u2.RawQuery = query.Encode()
   123  	db := sql.OpenDB(connector{
   124  		client: &proxy.Client{
   125  			Port:  3307,
   126  			Certs: uo.CertSource,
   127  		},
   128  		instance:  instance,
   129  		pqConn:    u2.String(),
   130  		traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...),
   131  	})
   132  	return db, nil
   133  }
   134  
   135  func instanceFromURL(u *url.URL) (instance, db string, _ error) {
   136  	path := u.Host + u.Path // everything after scheme but before query or fragment
   137  	parts := strings.SplitN(path, "/", 4)
   138  	if len(parts) < 4 {
   139  		return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", path)
   140  	}
   141  	for _, part := range parts {
   142  		if part == "" {
   143  			return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", path)
   144  		}
   145  	}
   146  	return parts[0] + ":" + parts[1] + ":" + parts[2], parts[3], nil
   147  }
   148  
   149  type pqDriver struct {
   150  	client    *proxy.Client
   151  	instance  string
   152  	traceOpts []ocsql.TraceOption
   153  }
   154  
   155  func (d pqDriver) Open(name string) (driver.Conn, error) {
   156  	c, _ := d.OpenConnector(name)
   157  	return c.Connect(context.Background())
   158  }
   159  
   160  func (d pqDriver) OpenConnector(name string) (driver.Connector, error) {
   161  	return connector{d.client, d.instance, name, d.traceOpts}, nil
   162  }
   163  
   164  type connector struct {
   165  	client    *proxy.Client
   166  	instance  string
   167  	pqConn    string
   168  	traceOpts []ocsql.TraceOption
   169  }
   170  
   171  func (c connector) Connect(context.Context) (driver.Conn, error) {
   172  	conn, err := pq.DialOpen(dialer{c.client, c.instance}, c.pqConn)
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	return ocsql.WrapConn(conn, c.traceOpts...), nil
   177  }
   178  
   179  func (c connector) Driver() driver.Driver {
   180  	return pqDriver{c.client, c.instance, c.traceOpts}
   181  }
   182  
   183  type dialer struct {
   184  	client   *proxy.Client
   185  	instance string
   186  }
   187  
   188  func (d dialer) Dial(network, address string) (net.Conn, error) {
   189  	return d.client.Dial(d.instance)
   190  }
   191  
   192  func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
   193  	return nil, errors.New("gcppostgres: DialTimeout not supported")
   194  }