github.com/kubeshop/testkube@v1.17.23/pkg/repository/storage/mongo.go (about)

     1  package storage
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"time"
    13  
    14  	"go.mongodb.org/mongo-driver/mongo"
    15  	"go.mongodb.org/mongo-driver/mongo/options"
    16  
    17  	"github.com/kubeshop/testkube/pkg/log"
    18  )
    19  
// MongoSSLConfig contains the configurations necessary for an SSL connection.
// GetMongoDatabase forwards these values to the MongoDB driver's
// options.BuildTLSConfig when connecting to a plain MongoDB server.
type MongoSSLConfig struct {
	// SSLClientCertificateKeyFile specifies a path to the client certificate and private key, which must be concatenated into one file.
	SSLClientCertificateKeyFile string
	// SSLClientCertificateKeyFilePassword specifies the password to decrypt the client private key file.
	SSLClientCertificateKeyFilePassword string
	// SSLCertificateAuthoritiyFile specifies the path to a file holding a single certificate authority or a bundle of them.
	// The identifier misspells "Authority"; it is exported, so it is left as is to keep callers compiling.
	SSLCertificateAuthoritiyFile string
}
    29  
// Database type identifiers accepted by GetMongoDatabase, plus the location of
// the CA bundle used for DocumentDB TLS connections.
//
//   - TypeMongoDB selects a plain MongoDB server (an empty type is treated the same way).
//   - TypeDocDB selects the DocDB flavour, whose TLS setup downloads a CA bundle.
//   - DocDBcaFileURI is the URL that GetDocDBcaFile downloads the CA bundle from.
//     NOTE(review): confirm this URL still serves the bundle the target clusters need.
const (
	TypeMongoDB    = "mongo"
	TypeDocDB      = "docdb"
	DocDBcaFileURI = "https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem"
)
    35  
    36  // GetMongoDatabase returns a valid database connection to the configured MongoDB database
    37  func GetMongoDatabase(dsn, name, dbType string, allowTLS bool, certConfig *MongoSSLConfig) (db *mongo.Database, err error) {
    38  	if dbType != "" && dbType != TypeMongoDB && dbType != TypeDocDB {
    39  		return nil, fmt.Errorf("unsupported database type %s", dbType)
    40  	}
    41  	var mongoOptions *tls.Config
    42  
    43  	if (dbType == TypeMongoDB || dbType == "") && certConfig != nil {
    44  		mongoOptions, err = options.BuildTLSConfig(map[string]interface{}{
    45  			"sslClientCertificateKeyFile":     certConfig.SSLClientCertificateKeyFile,
    46  			"sslClientCertificateKeyPassword": certConfig.SSLClientCertificateKeyFilePassword,
    47  			"sslCertificateAuthorityFile":     certConfig.SSLCertificateAuthoritiyFile,
    48  		})
    49  		if err != nil {
    50  			return nil, fmt.Errorf("could not build SSL config: %w", err)
    51  		}
    52  	}
    53  	if dbType == TypeDocDB && allowTLS {
    54  		mongoOptions, err = getDocDBTLSConfig()
    55  		if err != nil {
    56  			return nil, fmt.Errorf("could not get DocDB: %w", err)
    57  		}
    58  	}
    59  
    60  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    61  	defer cancel()
    62  	client, err := mongo.Connect(ctx, options.Client().SetTLSConfig(mongoOptions).ApplyURI(dsn))
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return client.Database(name), nil
    68  }
    69  
    70  func getDocDBTLSConfig() (*tls.Config, error) {
    71  	caFilePath, err := GetDocDBcaFile()
    72  	if err != nil {
    73  		return nil, fmt.Errorf("could not get CA file: %w", err)
    74  	}
    75  	defer func() {
    76  		err := deleteDocDBTLSConfigFile(caFilePath)
    77  		if err != nil {
    78  			log.DefaultLogger.Warnf("could not remove AWS DocumentDB CA file %s: %v", caFilePath, err)
    79  		}
    80  	}()
    81  
    82  	tlsConfig := new(tls.Config)
    83  	certs, err := os.ReadFile(caFilePath)
    84  	if err != nil {
    85  		return nil, fmt.Errorf("could not read CA file: %s", err)
    86  	}
    87  
    88  	tlsConfig.RootCAs = x509.NewCertPool()
    89  	ok := tlsConfig.RootCAs.AppendCertsFromPEM(certs)
    90  
    91  	if !ok {
    92  		return nil, errors.New("failed parsing pem file")
    93  	}
    94  
    95  	return tlsConfig, nil
    96  }
    97  
    98  // GetDocDBcaFile will fetch the file located at DocDBcaFileURI into a local file
    99  // Due to size limitations we cannot use Kubernetes secrets like we use for MongoDB TLS configs
   100  func GetDocDBcaFile() (string, error) {
   101  	// Get the data
   102  	resp, err := http.Get(DocDBcaFileURI)
   103  	if err != nil {
   104  		return "", fmt.Errorf("could not fetch file from %s: %w", DocDBcaFileURI, err)
   105  	}
   106  	defer resp.Body.Close()
   107  
   108  	out, err := os.CreateTemp("", "rds-combined-ca-bundle.pem")
   109  	if err != nil {
   110  		return "", fmt.Errorf("could not create file %s: %w", out.Name(), err)
   111  	}
   112  	defer out.Close()
   113  
   114  	_, err = io.Copy(out, resp.Body)
   115  	if err != nil {
   116  		return "", fmt.Errorf("could not write file %s: %w", out.Name(), err)
   117  	}
   118  	return out.Name(), nil
   119  }
   120  
   121  // deleteDocDBTLSConfigFile deletes the downloaded CA file
   122  func deleteDocDBTLSConfigFile(docDBcaPath string) error {
   123  	err := os.Remove(docDBcaPath)
   124  	if err != nil {
   125  		return fmt.Errorf("could not delete file %s: %w", docDBcaPath, err)
   126  	}
   127  	return nil
   128  }