github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_local/local_db_client.go (about)

     1  package db_local
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  
     8  	"github.com/jackc/pgx/v5/pgconn"
     9  	"github.com/spf13/viper"
    10  	"github.com/turbot/steampipe/pkg/constants"
    11  	"github.com/turbot/steampipe/pkg/db/db_client"
    12  	"github.com/turbot/steampipe/pkg/db/db_common"
    13  	"github.com/turbot/steampipe/pkg/error_helpers"
    14  	pb "github.com/turbot/steampipe/pkg/pluginmanager_service/grpc/proto"
    15  	"github.com/turbot/steampipe/pkg/utils"
    16  )
    17  
    18  // LocalDbClient wraps over DbClient
    19  type LocalDbClient struct {
    20  	db_client.DbClient
    21  	notificationListener *db_common.NotificationListener
    22  	invoker              constants.Invoker
    23  }
    24  
    25  // GetLocalClient starts service if needed and creates a new LocalDbClient
    26  func GetLocalClient(ctx context.Context, invoker constants.Invoker, onConnectionCallback db_client.DbConnectionCallback, opts ...db_client.ClientOption) (*LocalDbClient, error_helpers.ErrorAndWarnings) {
    27  	utils.LogTime("db.GetLocalClient start")
    28  	defer utils.LogTime("db.GetLocalClient end")
    29  
    30  	log.Printf("[INFO] GetLocalClient")
    31  	defer log.Printf("[INFO] GetLocalClient complete")
    32  
    33  	listenAddresses := StartListenType(ListenTypeLocal).ToListenAddresses()
    34  	port := viper.GetInt(constants.ArgDatabasePort)
    35  	log.Println(fmt.Sprintf("[TRACE] GetLocalClient - listenAddresses=%s, port=%d", listenAddresses, port))
    36  	// start db if necessary
    37  	if err := EnsureDBInstalled(ctx); err != nil {
    38  		return nil, error_helpers.NewErrorsAndWarning(err)
    39  	}
    40  
    41  	log.Printf("[INFO] StartServices")
    42  	startResult := StartServices(ctx, listenAddresses, port, invoker)
    43  	if startResult.Error != nil {
    44  		return nil, startResult.ErrorAndWarnings
    45  	}
    46  
    47  	log.Printf("[INFO] newLocalClient")
    48  	client, err := newLocalClient(ctx, invoker, onConnectionCallback, opts...)
    49  	if err != nil {
    50  		ShutdownService(ctx, invoker)
    51  		startResult.Error = err
    52  	}
    53  
    54  	// after creating the client, refresh connections
    55  	// NOTE: we cannot do this until after creating the client to ensure we do not miss notifications
    56  	if startResult.Status == ServiceStarted {
    57  		// ask the plugin manager to refresh connections
    58  		// this is executed asyncronously by the plugin manager
    59  		// we ignore this error, since RefreshConnections is async and all errors will flow through
    60  		// the notification system
    61  		// we do not expect any I/O errors on this since the PluginManager is running in the same box
    62  		_, _ = startResult.PluginManager.RefreshConnections(&pb.RefreshConnectionsRequest{})
    63  	}
    64  
    65  	return client, startResult.ErrorAndWarnings
    66  }
    67  
    68  // newLocalClient verifies that the local database instance is running and returns a LocalDbClient to interact with it
    69  // (This FAILS if local service is not running - use GetLocalClient to start service first)
    70  func newLocalClient(ctx context.Context, invoker constants.Invoker, onConnectionCallback db_client.DbConnectionCallback, opts ...db_client.ClientOption) (*LocalDbClient, error) {
    71  	utils.LogTime("db.newLocalClient start")
    72  	defer utils.LogTime("db.newLocalClient end")
    73  
    74  	connString, err := getLocalSteampipeConnectionString(nil)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	dbClient, err := db_client.NewDbClient(ctx, connString, onConnectionCallback, opts...)
    79  	if err != nil {
    80  		log.Printf("[TRACE] error getting local client %s", err.Error())
    81  		return nil, err
    82  	}
    83  
    84  	client := &LocalDbClient{DbClient: *dbClient, invoker: invoker}
    85  	log.Printf("[INFO] created local client %p", client)
    86  
    87  	if err := client.initNotificationListener(ctx); err != nil {
    88  		client.Close(ctx)
    89  		return nil, err
    90  	}
    91  
    92  	return client, nil
    93  }
    94  
    95  func (c *LocalDbClient) initNotificationListener(ctx context.Context) error {
    96  	// get a connection for the notification cache
    97  	conn, err := c.AcquireManagementConnection(ctx)
    98  	if err != nil {
    99  		c.Close(ctx)
   100  		return err
   101  	}
   102  	// hijack from the pool  as we will be keeping open for the lifetime of this run
   103  	// notification cache will manage the lifecycle of the connection
   104  	notificationConnection := conn.Hijack()
   105  	listener, err := db_common.NewNotificationListener(ctx, notificationConnection)
   106  	if err != nil {
   107  		return err
   108  	}
   109  	c.notificationListener = listener
   110  
   111  	return nil
   112  }
   113  
   114  // Close implements Client
   115  // close the connection to the database and shuts down the db service if we are the last connection
   116  func (c *LocalDbClient) Close(ctx context.Context) error {
   117  	if c.notificationListener != nil {
   118  		c.notificationListener.Stop(ctx)
   119  	}
   120  
   121  	if err := c.DbClient.Close(ctx); err != nil {
   122  		return err
   123  	}
   124  	log.Printf("[TRACE] local client close complete")
   125  
   126  	log.Printf("[TRACE] shutdown local service %v", c.invoker)
   127  	ShutdownService(ctx, c.invoker)
   128  	return nil
   129  }
   130  
   131  func (c *LocalDbClient) RegisterNotificationListener(f func(notification *pgconn.Notification)) {
   132  	c.notificationListener.RegisterListener(f)
   133  }