github.com/thiagoyeds/go-cloud@v0.26.0/postgres/awspostgres/awspostgres.go (about) 1 // Copyright 2018 The Go Cloud Development Kit Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package awspostgres provides connections to AWS RDS PostgreSQL instances. 16 // 17 // URLs 18 // 19 // For postgres.Open, awspostgres registers for the scheme "awspostgres". 20 // The default URL opener will create a connection using the default 21 // credentials from the environment, as described in 22 // https://docs.aws.amazon.com/sdk-for-go/api/aws/session/. 23 // To customize the URL opener, or for more details on the URL format, 24 // see URLOpener. 25 // 26 // See https://gocloud.dev/concepts/urls/ for background information. 27 package awspostgres // import "gocloud.dev/postgres/awspostgres" 28 29 import ( 30 "context" 31 "crypto/tls" 32 "database/sql" 33 "database/sql/driver" 34 "fmt" 35 "io" 36 "net" 37 "net/url" 38 "time" 39 40 "contrib.go.opencensus.io/integrations/ocsql" 41 "github.com/lib/pq" 42 "gocloud.dev/aws/rds" 43 "gocloud.dev/postgres" 44 ) 45 46 // URLOpener opens RDS PostgreSQL URLs 47 // like "awspostgres://user:password@myinstance.borkxyzzy.us-west-1.rds.amazonaws.com:5432/mydb". 48 type URLOpener struct { 49 // CertSource specifies how the opener will obtain the RDS Certificate 50 // Authority. If nil, it will use the default *rds.CertFetcher. 51 CertSource rds.CertPoolProvider 52 // TraceOpts contains options for OpenCensus. 53 TraceOpts []ocsql.TraceOption 54 } 55 56 // Scheme is the URL scheme awspostgres registers its URLOpener under on 57 // postgres.DefaultMux. 58 const Scheme = "awspostgres" 59 60 func init() { 61 postgres.DefaultURLMux().RegisterPostgres(Scheme, &URLOpener{}) 62 } 63 64 // OpenPostgresURL opens a new RDS database connection wrapped with OpenCensus instrumentation. 65 func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) { 66 source := uo.CertSource 67 if source == nil { 68 source = new(rds.CertFetcher) 69 } 70 71 query := u.Query() 72 for k := range query { 73 // Forbid SSL-related parameters. 74 if k == "sslmode" || k == "sslcert" || k == "sslkey" || k == "sslrootcert" { 75 return nil, fmt.Errorf("awspostgres: open: parameter %q not allowed; sslmode must be disabled because the underlying dialer is already providing TLS", k) 76 } 77 } 78 // sslmode must be disabled because the underlying dialer is already providing TLS. 79 query.Set("sslmode", "disable") 80 81 u2 := new(url.URL) 82 *u2 = *u 83 u2.Scheme = "postgres" 84 u2.RawQuery = query.Encode() 85 db := sql.OpenDB(connector{ 86 provider: source, 87 pqConn: u2.String(), 88 traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...), 89 }) 90 return db, nil 91 } 92 93 type pqDriver struct { 94 provider rds.CertPoolProvider 95 traceOpts []ocsql.TraceOption 96 } 97 98 func (d pqDriver) Open(name string) (driver.Conn, error) { 99 c, _ := d.OpenConnector(name) 100 return c.Connect(context.Background()) 101 } 102 103 func (d pqDriver) OpenConnector(name string) (driver.Connector, error) { 104 return connector{d.provider, name + " sslmode=disable", d.traceOpts}, nil 105 } 106 107 type connector struct { 108 provider rds.CertPoolProvider 109 pqConn string 110 traceOpts []ocsql.TraceOption 111 } 112 113 func (c connector) Connect(context.Context) (driver.Conn, error) { 114 conn, err := pq.DialOpen(dialer{c.provider}, c.pqConn) 115 if err != nil { 116 return nil, err 117 } 118 return ocsql.WrapConn(conn, c.traceOpts...), nil 119 } 120 121 func (c connector) Driver() driver.Driver { 122 return pqDriver{c.provider, c.traceOpts} 123 } 124 125 type dialer struct { 126 provider rds.CertPoolProvider 127 } 128 129 func (d dialer) dial(ctx context.Context, network, address string) (net.Conn, error) { 130 host, _, err := net.SplitHostPort(address) 131 if err != nil { 132 return nil, fmt.Errorf("awspostgres: parse address: %v", err) 133 } 134 certPool, err := d.provider.RDSCertPool(ctx) 135 if err != nil { 136 return nil, err 137 } 138 conn, err := new(net.Dialer).DialContext(ctx, network, address) 139 if err != nil { 140 return nil, err 141 } 142 143 // Write the PostgreSQL SSLRequest message described in 144 // https://www.postgresql.org/docs/11/protocol-message-formats.html 145 // to upgrade to a TLS connection. 146 _, err = conn.Write([]byte{ 147 // Message length (Int32), including message length. 148 0x00, 0x00, 0x00, 0x08, 149 // Magic number: 80877103. 150 0x04, 0xd2, 0x16, 0x2f, 151 }) 152 if err != nil { 153 return nil, err 154 } 155 // Server must respond back with 'S'. 156 var readBuf [1]byte 157 if _, err := io.ReadFull(conn, readBuf[:]); err != nil { 158 return nil, err 159 } 160 if readBuf[0] != 'S' { 161 return nil, pq.ErrSSLNotSupported 162 } 163 164 // Begin TLS communication. 165 crypt := tls.Client(conn, &tls.Config{ 166 ServerName: host, 167 RootCAs: certPool, 168 Renegotiation: tls.RenegotiateFreelyAsClient, 169 }) 170 if err := crypt.Handshake(); err != nil { 171 return nil, err 172 } 173 return crypt, nil 174 } 175 176 func (d dialer) Dial(network, address string) (net.Conn, error) { 177 return d.dial(context.Background(), network, address) 178 } 179 180 func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 181 ctx, cancel := context.WithTimeout(context.Background(), timeout) 182 conn, err := d.dial(ctx, network, address) 183 cancel() 184 return conn, err 185 }