github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/mysql/awsmysql/awsmysql.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awsmysql provides connections to AWS RDS MySQL instances.
    16  //
    17  // # URLs
    18  //
    19  // For mysql.Open, awsmysql registers for the scheme "awsmysql".
    20  // The default URL opener will create a connection using the default
    21  // credentials from the environment, as described in
    22  // https://docs.aws.amazon.com/sdk-for-go/api/aws/session/.
    23  // To customize the URL opener, or for more details on the URL format,
    24  // see URLOpener.
    25  //
    26  // See https://gocloud.dev/concepts/urls/ for background information.
    27  package awsmysql // import "gocloud.dev/mysql/awsmysql"
    28  
    29  import (
    30  	"context"
    31  	"crypto/tls"
    32  	"database/sql"
    33  	"database/sql/driver"
    34  	"fmt"
    35  	"net/url"
    36  	"sync/atomic"
    37  
    38  	"contrib.go.opencensus.io/integrations/ocsql"
    39  	"github.com/go-sql-driver/mysql"
    40  	"github.com/google/wire"
    41  	"gocloud.dev/aws/rds"
    42  	gcmysql "gocloud.dev/mysql"
    43  )
    44  
// Set is a Wire provider set that provides a URLOpener whose CertSource
// field is injected by Wire. It also includes rds.CertFetcherSet, which
// presumably supplies the rds.CertPoolProvider used for that field
// (see the gocloud.dev/aws/rds package to confirm).
var Set = wire.NewSet(
	wire.Struct(new(URLOpener), "CertSource"),
	rds.CertFetcherSet,
)
    51  
// URLOpener opens RDS MySQL URLs
// like "awsmysql://user:password@myinstance.borkxyzzy.us-west-1.rds.amazonaws.com:3306/mydb".
//
// The zero value is usable: OpenMySQLURL falls back to a default
// *rds.CertFetcher when CertSource is nil.
type URLOpener struct {
	// CertSource specifies how the opener will obtain the RDS Certificate
	// Authority. If nil, it will use the default *rds.CertFetcher.
	CertSource rds.CertPoolProvider
	// TraceOpts contains options for OpenCensus. The slice is copied when a
	// URL is opened, so later modifications do not affect existing databases.
	TraceOpts []ocsql.TraceOption
}
    61  
// Scheme is the URL scheme awsmysql registers its URLOpener under on
// the default URL mux of the gocloud.dev/mysql package.
const Scheme = "awsmysql"
    65  
    66  func init() {
    67  	gcmysql.DefaultURLMux().RegisterMySQL(Scheme, &URLOpener{})
    68  }
    69  
    70  // OpenMySQLURL opens a new RDS database connection wrapped with OpenCensus instrumentation.
    71  func (uo *URLOpener) OpenMySQLURL(_ context.Context, u *url.URL) (*sql.DB, error) {
    72  	source := uo.CertSource
    73  	if source == nil {
    74  		source = new(rds.CertFetcher)
    75  	}
    76  	if u.Host == "" {
    77  		return nil, fmt.Errorf("open RDS: empty endpoint")
    78  	}
    79  
    80  	cfg, err := gcmysql.ConfigFromURL(u)
    81  	if err != nil {
    82  		return nil, err
    83  	}
    84  	c := &connector{
    85  		dsn: cfg.FormatDSN(),
    86  		// Make a copy of TraceOpts to avoid caller modifying.
    87  		traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...),
    88  		provider:  source,
    89  
    90  		sem:   make(chan struct{}, 1),
    91  		ready: make(chan struct{}),
    92  	}
    93  	c.sem <- struct{}{}
    94  	return sql.OpenDB(c), nil
    95  }
    96  
// connector is the driver.Connector handed to sql.OpenDB. On the first
// successful Connect it obtains the RDS certificate pool, registers a TLS
// configuration with the mysql driver and rewrites dsn to reference it;
// every later Connect reuses the rewritten dsn.
type connector struct {
	// traceOpts are the OpenCensus options applied when wrapping the driver.
	traceOpts []ocsql.TraceOption

	sem      chan struct{} // receive to acquire, send to release
	provider CertPoolProvider

	ready chan struct{} // closed after resolving dsn
	// dsn is the mysql data source name. It is rewritten once, while sem is
	// held and before ready is closed, and only read after that.
	dsn   string
}
   106  
   107  func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
   108  	select {
   109  	case <-c.sem:
   110  		certPool, err := c.provider.RDSCertPool(ctx)
   111  		if err != nil {
   112  			c.sem <- struct{}{} // release
   113  			return nil, fmt.Errorf("connect RDS: %v", err)
   114  		}
   115  		// TODO(light): Avoid global registry once https://github.com/go-sql-driver/mysql/issues/771 is fixed.
   116  		tlsConfigName := fmt.Sprintf(
   117  			"gocloud.dev/mysql/awsmysql/%d",
   118  			atomic.AddUint32(&tlsConfigCounter, 1),
   119  		)
   120  		err = mysql.RegisterTLSConfig(tlsConfigName, &tls.Config{
   121  			RootCAs: certPool,
   122  		})
   123  		if err != nil {
   124  			c.sem <- struct{}{} // release
   125  			return nil, fmt.Errorf("connect RDS: register TLS: %v", err)
   126  		}
   127  		cfg, _ := mysql.ParseDSN(c.dsn)
   128  		cfg.TLSConfig = tlsConfigName
   129  		c.dsn = cfg.FormatDSN()
   130  		close(c.ready)
   131  		// Don't release sem: make it block forever, so this case won't be run again.
   132  	case <-c.ready:
   133  		// Already succeeded.
   134  	case <-ctx.Done():
   135  		return nil, fmt.Errorf("connect RDS: waiting for certificates: %v", ctx.Err())
   136  	}
   137  	return c.Driver().Open(c.dsn)
   138  }
   139  
   140  func (c *connector) Driver() driver.Driver {
   141  	return ocsql.Wrap(mysql.MySQLDriver{}, c.traceOpts...)
   142  }
   143  
// tlsConfigCounter is incremented atomically to build a unique name for each
// TLS configuration registered with the mysql driver's global registry.
var tlsConfigCounter uint32
   145  
// A CertPoolProvider obtains a certificate pool that contains the RDS CA certificate.
//
// It is an alias for rds.CertPoolProvider.
type CertPoolProvider = rds.CertPoolProvider
   148  
// CertFetcher pulls the RDS CA certificates from Amazon's servers. The zero
// value will fetch certificates using the default HTTP client.
//
// It is an alias for rds.CertFetcher.
type CertFetcher = rds.CertFetcher