github.com/thiagoyeds/go-cloud@v0.26.0/postgres/gcppostgres/gcppostgres.go (about) 1 // Copyright 2018 The Go Cloud Development Kit Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package gcppostgres provides connections to managed PostgreSQL Cloud SQL instances. 16 // See https://cloud.google.com/sql/docs/postgres/ for more information. 17 // 18 // URLs 19 // 20 // For postgres.Open, gcppostgres registers for the scheme "gcppostgres". 21 // The default URL opener will create a connection using the default 22 // credentials from the environment, as described in 23 // https://cloud.google.com/docs/authentication/production. 24 // To customize the URL opener, or for more details on the URL format, 25 // see URLOpener. 26 // 27 // See https://gocloud.dev/concepts/urls/ for background information. 28 package gcppostgres // import "gocloud.dev/postgres/gcppostgres" 29 30 import ( 31 "context" 32 "database/sql" 33 "database/sql/driver" 34 "errors" 35 "fmt" 36 "net" 37 "net/url" 38 "strings" 39 "sync" 40 "time" 41 42 "contrib.go.opencensus.io/integrations/ocsql" 43 "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy" 44 "github.com/lib/pq" 45 "gocloud.dev/gcp" 46 "gocloud.dev/gcp/cloudsql" 47 "gocloud.dev/postgres" 48 ) 49 50 // Scheme is the URL scheme gcppostgres registers its URLOpener under on 51 // postgres.DefaultURLMux(). 
52 const Scheme = "gcppostgres" 53 54 func init() { 55 postgres.DefaultURLMux().RegisterPostgres(Scheme, new(lazyCredsOpener)) 56 } 57 58 // lazyCredsOpener obtains Application Default Credentials on the first call 59 // to OpenPostgresURL. 60 type lazyCredsOpener struct { 61 init sync.Once 62 opener *URLOpener 63 err error 64 } 65 66 func (o *lazyCredsOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) { 67 o.init.Do(func() { 68 creds, err := gcp.DefaultCredentials(ctx) 69 if err != nil { 70 o.err = err 71 return 72 } 73 client, err := gcp.NewHTTPClient(gcp.DefaultTransport(), creds.TokenSource) 74 if err != nil { 75 o.err = err 76 return 77 } 78 certSource := cloudsql.NewCertSource(client) 79 o.opener = &URLOpener{CertSource: certSource} 80 }) 81 if o.err != nil { 82 return nil, fmt.Errorf("gcppostgres open %v: %v", u, o.err) 83 } 84 return o.opener.OpenPostgresURL(ctx, u) 85 } 86 87 // URLOpener opens GCP PostgreSQL URLs 88 // like "gcppostgres://user:password@myproject/us-central1/instanceId/mydb". 89 type URLOpener struct { 90 // CertSource specifies how the opener will obtain authentication information. 91 // CertSource must not be nil. 92 CertSource proxy.CertSource 93 94 // TraceOpts contains options for OpenCensus. 95 TraceOpts []ocsql.TraceOption 96 } 97 98 // OpenPostgresURL opens a new GCP database connection wrapped with OpenCensus instrumentation. 99 func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) { 100 if uo.CertSource == nil { 101 return nil, fmt.Errorf("gcppostgres: URLOpener CertSource is nil") 102 } 103 instance, dbName, err := instanceFromURL(u) 104 if err != nil { 105 return nil, fmt.Errorf("gcppostgres: open %v: %v", u, err) 106 } 107 108 query := u.Query() 109 for k := range query { 110 // Only permit parameters that do not conflict with other behavior. 
111 if k == "sslmode" || k == "sslcert" || k == "sslkey" || k == "sslrootcert" { 112 return nil, fmt.Errorf("gcppostgres: open: extra parameter %s not allowed", k) 113 } 114 } 115 query.Set("sslmode", "disable") 116 117 u2 := new(url.URL) 118 *u2 = *u 119 u2.Scheme = "postgres" 120 u2.Host = "cloudsql" 121 u2.Path = "/" + dbName 122 u2.RawQuery = query.Encode() 123 db := sql.OpenDB(connector{ 124 client: &proxy.Client{ 125 Port: 3307, 126 Certs: uo.CertSource, 127 }, 128 instance: instance, 129 pqConn: u2.String(), 130 traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...), 131 }) 132 return db, nil 133 } 134 135 func instanceFromURL(u *url.URL) (instance, db string, _ error) { 136 path := u.Host + u.Path // everything after scheme but before query or fragment 137 parts := strings.SplitN(path, "/", 4) 138 if len(parts) < 4 { 139 return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", path) 140 } 141 for _, part := range parts { 142 if part == "" { 143 return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", path) 144 } 145 } 146 return parts[0] + ":" + parts[1] + ":" + parts[2], parts[3], nil 147 } 148 149 type pqDriver struct { 150 client *proxy.Client 151 instance string 152 traceOpts []ocsql.TraceOption 153 } 154 155 func (d pqDriver) Open(name string) (driver.Conn, error) { 156 c, _ := d.OpenConnector(name) 157 return c.Connect(context.Background()) 158 } 159 160 func (d pqDriver) OpenConnector(name string) (driver.Connector, error) { 161 return connector{d.client, d.instance, name, d.traceOpts}, nil 162 } 163 164 type connector struct { 165 client *proxy.Client 166 instance string 167 pqConn string 168 traceOpts []ocsql.TraceOption 169 } 170 171 func (c connector) Connect(context.Context) (driver.Conn, error) { 172 conn, err := pq.DialOpen(dialer{c.client, c.instance}, c.pqConn) 173 if err != nil { 174 return nil, err 175 } 176 return ocsql.WrapConn(conn, c.traceOpts...), nil 177 } 178 179 func (c 
connector) Driver() driver.Driver { 180 return pqDriver{c.client, c.instance, c.traceOpts} 181 } 182 183 type dialer struct { 184 client *proxy.Client 185 instance string 186 } 187 188 func (d dialer) Dial(network, address string) (net.Conn, error) { 189 return d.client.Dial(d.instance) 190 } 191 192 func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 193 return nil, errors.New("gcppostgres: DialTimeout not supported") 194 }