github.com/turbot/steampipe@v1.7.0-rc.0.0.20240517123944-7cef272d4458/pkg/db/db_client/db_client_session.go (about)

     1  package db_client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/jackc/pgx/v5/pgxpool"
     9  	"github.com/spf13/viper"
    10  	"github.com/turbot/steampipe/pkg/constants"
    11  	"github.com/turbot/steampipe/pkg/db/db_common"
    12  )
    13  
    14  func (c *DbClient) AcquireManagementConnection(ctx context.Context) (*pgxpool.Conn, error) {
    15  	return c.managementPool.Acquire(ctx)
    16  }
    17  
    18  func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common.AcquireSessionResult) {
    19  	sessionResult = &db_common.AcquireSessionResult{}
    20  
    21  	defer func() {
    22  		if sessionResult != nil && sessionResult.Session != nil {
    23  			// fail safe - if there is no database connection, ensure we return an error
    24  			// NOTE: this should not be necessary but an occasional crash is occurring with a nil connection
    25  			if sessionResult.Session.Connection == nil && sessionResult.Error == nil {
    26  				sessionResult.Error = fmt.Errorf("nil database connection being returned from AcquireSession but no error was raised")
    27  			}
    28  		}
    29  	}()
    30  
    31  	// get a database connection and query its backend pid
    32  	// note - this will retry if the connection is bad
    33  	databaseConnection, err := c.userPool.Acquire(ctx)
    34  	if err != nil {
    35  		sessionResult.Error = err
    36  		return sessionResult
    37  	}
    38  	backendPid := databaseConnection.Conn().PgConn().PID()
    39  
    40  	c.sessionsMutex.Lock()
    41  	session, found := c.sessions[backendPid]
    42  	if !found {
    43  		session = db_common.NewDBSession(backendPid)
    44  		c.sessions[backendPid] = session
    45  	}
    46  	// we get a new *sql.Conn everytime. USE IT!
    47  	session.Connection = databaseConnection
    48  	sessionResult.Session = session
    49  	c.sessionsMutex.Unlock()
    50  
    51  	// make sure that we close the acquired session, in case of error
    52  	defer func() {
    53  		if sessionResult.Error != nil && databaseConnection != nil {
    54  			sessionResult.Session = nil
    55  			databaseConnection.Release()
    56  		}
    57  	}()
    58  
    59  	// if this is connected to a local service (localhost) and if the server cache
    60  	// is disabled, override the client setting to always disable
    61  	//
    62  	// this is a temporary workaround to make sure
    63  	// that we turn off caching for plugins compiled with SDK pre-V5
    64  	if c.isLocalService && !viper.GetBool(constants.ArgServiceCacheEnabled) {
    65  		if err := db_common.SetCacheEnabled(ctx, false, databaseConnection.Conn()); err != nil {
    66  			sessionResult.Error = err
    67  			return sessionResult
    68  		}
    69  	} else {
    70  		if viper.IsSet(constants.ArgClientCacheEnabled) {
    71  			if err := db_common.SetCacheEnabled(ctx, viper.GetBool(constants.ArgClientCacheEnabled), databaseConnection.Conn()); err != nil {
    72  				sessionResult.Error = err
    73  				return sessionResult
    74  			}
    75  		}
    76  	}
    77  
    78  	if viper.IsSet(constants.ArgCacheTtl) {
    79  		ttl := time.Duration(viper.GetInt(constants.ArgCacheTtl)) * time.Second
    80  		if err := db_common.SetCacheTtl(ctx, ttl, databaseConnection.Conn()); err != nil {
    81  			sessionResult.Error = err
    82  			return sessionResult
    83  		}
    84  	}
    85  
    86  	// update required session search path if needed
    87  	err = c.ensureSessionSearchPath(ctx, session)
    88  	if err != nil {
    89  		sessionResult.Error = err
    90  		return sessionResult
    91  	}
    92  
    93  	sessionResult.Error = ctx.Err()
    94  	return sessionResult
    95  }