github.com/dkishere/pop/v6@v6.103.1/connection.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync/atomic"
     8  	"time"
     9  
    10  	"github.com/dkishere/pop/v6/internal/defaults"
    11  	"github.com/dkishere/pop/v6/internal/randx"
    12  )
    13  
    14  // Connections contains all available connections
    15  var Connections = map[string]*Connection{}
    16  
// Connection represents all necessary details to talk with a datastore.
//
// A Connection built by NewConnection has no Store until Open succeeds;
// copies made by WithContext/NewTransaction share or wrap the parent's Store.
type Connection struct {
	// ID is a random 30-character identifier generated whenever a Connection
	// value is created (NewConnection, NewTransaction, copy).
	ID          string
	// Store is the handle queries run against: nil before Open, the opened
	// database after Open, or the transaction (possibly context-wrapped) for a
	// Connection returned by NewTransaction.
	Store       store
	// Dialect supplies the database-specific behavior (URLs, details, locking,
	// truncation).
	Dialect     dialect
	// Elapsed accumulates, in nanoseconds, the time spent inside timeFunc
	// calls; it is updated with atomic.AddInt64.
	Elapsed     int64
	// TX is the active transaction, or nil when this Connection is not part of
	// a transaction.
	TX          *Tx
	// eager and eagerFields hold the eager-loading state; disableEager resets
	// them.
	eager       bool
	eagerFields []string
}
    27  
    28  func (c *Connection) String() string {
    29  	return c.URL()
    30  }
    31  
    32  // URL returns the datasource connection string
    33  func (c *Connection) URL() string {
    34  	return c.Dialect.URL()
    35  }
    36  
    37  // Context returns the connection's context set by "Context()" or context.TODO()
    38  // if no context is set.
    39  func (c *Connection) Context() context.Context {
    40  	if c, ok := c.Store.(interface{ Context() context.Context }); ok {
    41  		return c.Context()
    42  	}
    43  
    44  	return context.TODO()
    45  }
    46  
    47  // MigrationURL returns the datasource connection string used for running the migrations
    48  func (c *Connection) MigrationURL() string {
    49  	return c.Dialect.MigrationURL()
    50  }
    51  
    52  // MigrationTableName returns the name of the table to track migrations
    53  func (c *Connection) MigrationTableName() string {
    54  	return c.Dialect.Details().MigrationTableName()
    55  }
    56  
    57  // NewConnection creates a new connection, and sets it's `Dialect`
    58  // appropriately based on the `ConnectionDetails` passed into it.
    59  func NewConnection(deets *ConnectionDetails) (*Connection, error) {
    60  	err := deets.Finalize()
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	c := &Connection{
    65  		ID: randx.String(30),
    66  	}
    67  
    68  	if nc, ok := newConnection[deets.Dialect]; ok {
    69  		c.Dialect, err = nc(deets)
    70  		if err != nil {
    71  			return c, fmt.Errorf("could not create new connection: %w", err)
    72  		}
    73  		return c, nil
    74  	}
    75  	return nil, fmt.Errorf("could not found connection creator for %v", deets.Dialect)
    76  }
    77  
    78  // Connect takes the name of a connection, default is "development", and will
    79  // return that connection from the available `Connections`. If a connection with
    80  // that name can not be found an error will be returned. If a connection is
    81  // found, and it has yet to open a connection with its underlying datastore,
    82  // a connection to that store will be opened.
    83  func Connect(e string) (*Connection, error) {
    84  	if len(Connections) == 0 {
    85  		err := LoadConfigFile()
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  	}
    90  	e = defaults.String(e, "development")
    91  	c := Connections[e]
    92  	if c == nil {
    93  		return c, fmt.Errorf("could not find connection named %s", e)
    94  	}
    95  
    96  	if err := c.Open(); err != nil {
    97  		return c, fmt.Errorf("couldn't open connection for %s: %w", e, err)
    98  	}
    99  	return c, nil
   100  }
   101  
   102  // Open creates a new datasource connection
   103  func (c *Connection) Open() error {
   104  	if c.Store != nil {
   105  		return nil
   106  	}
   107  	if c.Dialect == nil {
   108  		return errors.New("invalid connection instance")
   109  	}
   110  	details := c.Dialect.Details()
   111  
   112  	db, err := openPotentiallyInstrumentedConnection(c.Dialect, c.Dialect.URL())
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	db.SetMaxOpenConns(details.Pool)
   118  	if details.IdlePool != 0 {
   119  		db.SetMaxIdleConns(details.IdlePool)
   120  	}
   121  	if details.ConnMaxLifetime > 0 {
   122  		db.SetConnMaxLifetime(details.ConnMaxLifetime)
   123  	}
   124  	if details.ConnMaxIdleTime > 0 {
   125  		db.SetConnMaxIdleTime(details.ConnMaxIdleTime)
   126  	}
   127  	if details.Unsafe {
   128  		db = db.Unsafe()
   129  	}
   130  	c.Store = &dB{db}
   131  
   132  	if d, ok := c.Dialect.(afterOpenable); ok {
   133  		if err := d.AfterOpen(c); err != nil {
   134  			c.Store = nil
   135  			return fmt.Errorf("could not open database connection: %w", err)
   136  		}
   137  	}
   138  	return nil
   139  }
   140  
   141  // Close destroys an active datasource connection
   142  func (c *Connection) Close() error {
   143  	if err := c.Store.Close(); err != nil {
   144  		return fmt.Errorf("couldn't close connection: %w", err)
   145  	}
   146  	return nil
   147  }
   148  
   149  // Transaction will start a new transaction on the connection. If the inner function
   150  // returns an error then the transaction will be rolled back, otherwise the transaction
   151  // will automatically commit at the end.
   152  func (c *Connection) Transaction(fn func(tx *Connection) error) error {
   153  	return c.Dialect.Lock(func() error {
   154  		var dberr error
   155  		cn, err := c.NewTransaction()
   156  		if err != nil {
   157  			return err
   158  		}
   159  		err = fn(cn)
   160  		if err != nil {
   161  			dberr = cn.TX.Rollback()
   162  		} else {
   163  			dberr = cn.TX.Commit()
   164  		}
   165  
   166  		if dberr != nil {
   167  			return fmt.Errorf("error committing or rolling back transaction: %w", dberr)
   168  		}
   169  
   170  		return err
   171  	})
   172  
   173  }
   174  
   175  // Rollback will open a new transaction and automatically rollback that transaction
   176  // when the inner function returns, regardless. This can be useful for tests, etc...
   177  func (c *Connection) Rollback(fn func(tx *Connection)) error {
   178  	cn, err := c.NewTransaction()
   179  	if err != nil {
   180  		return err
   181  	}
   182  	fn(cn)
   183  	return cn.TX.Rollback()
   184  }
   185  
   186  // NewTransaction starts a new transaction on the connection
   187  func (c *Connection) NewTransaction() (*Connection, error) {
   188  	var cn *Connection
   189  	if c.TX == nil {
   190  		tx, err := c.Store.Transaction()
   191  		if err != nil {
   192  			return cn, fmt.Errorf("couldn't start a new transaction: %w", err)
   193  		}
   194  		var store store = tx
   195  
   196  		// Rewrap the store if it was a context store
   197  		if cs, ok := c.Store.(contextStore); ok {
   198  			store = contextStore{store: store, ctx: cs.ctx}
   199  		}
   200  		cn = &Connection{
   201  			ID:      randx.String(30),
   202  			Store:   store,
   203  			Dialect: c.Dialect,
   204  			TX:      tx,
   205  		}
   206  	} else {
   207  		cn = c
   208  	}
   209  	return cn, nil
   210  }
   211  
   212  // WithContext returns a copy of the connection, wrapped with a context.
   213  func (c *Connection) WithContext(ctx context.Context) *Connection {
   214  	cn := c.copy()
   215  	cn.Store = contextStore{
   216  		store: cn.Store,
   217  		ctx:   ctx,
   218  	}
   219  	return cn
   220  }
   221  
   222  func (c *Connection) copy() *Connection {
   223  	return &Connection{
   224  		ID:      randx.String(30),
   225  		Store:   c.Store,
   226  		Dialect: c.Dialect,
   227  		TX:      c.TX,
   228  	}
   229  }
   230  
   231  // Q creates a new "empty" query for the current connection.
   232  func (c *Connection) Q() *Query {
   233  	return Q(c)
   234  }
   235  
   236  // disableEager disables eager mode for current connection.
   237  func (c *Connection) disableEager() {
   238  	c.eager = false
   239  	c.eagerFields = []string{}
   240  }
   241  
   242  // TruncateAll truncates all data from the datasource
   243  func (c *Connection) TruncateAll() error {
   244  	return c.Dialect.TruncateAll(c)
   245  }
   246  
   247  func (c *Connection) timeFunc(name string, fn func() error) error {
   248  	start := time.Now()
   249  	err := fn()
   250  	atomic.AddInt64(&c.Elapsed, int64(time.Since(start)))
   251  	if err != nil {
   252  		return err
   253  	}
   254  	return nil
   255  }