github.com/friesencr/pop/v6@v6.1.6/connection.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"math/rand"
     9  	"strings"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/friesencr/pop/v6/internal/defaults"
    14  	"github.com/friesencr/pop/v6/internal/randx"
    15  	"github.com/friesencr/pop/v6/logging"
    16  )
    17  
    18  // Connections contains all available connections
    19  var Connections = map[string]*Connection{}
    20  
    21  // Connection represents all necessary details to talk with a datastore
    22  type Connection struct {
    23  	ID          string
    24  	Store       store
    25  	Dialect     dialect
    26  	Elapsed     int64
    27  	TX          *Tx
    28  	eager       bool
    29  	eagerFields []string
    30  }
    31  
    32  func (c *Connection) String() string {
    33  	return c.URL()
    34  }
    35  
    36  // URL returns the datasource connection string
    37  func (c *Connection) URL() string {
    38  	return c.Dialect.URL()
    39  }
    40  
    41  // Context returns the connection's context set by "Context()" or context.TODO()
    42  // if no context is set.
    43  func (c *Connection) Context() context.Context {
    44  	if c, ok := c.Store.(interface{ Context() context.Context }); ok {
    45  		return c.Context()
    46  	}
    47  
    48  	return context.TODO()
    49  }
    50  
    51  // MigrationURL returns the datasource connection string used for running the migrations
    52  func (c *Connection) MigrationURL() string {
    53  	return c.Dialect.MigrationURL()
    54  }
    55  
    56  // MigrationTableName returns the name of the table to track migrations
    57  func (c *Connection) MigrationTableName() string {
    58  	return c.Dialect.Details().MigrationTableName()
    59  }
    60  
    61  // NewConnection creates a new connection, and sets it's `Dialect`
    62  // appropriately based on the `ConnectionDetails` passed into it.
    63  func NewConnection(deets *ConnectionDetails) (*Connection, error) {
    64  	err := deets.Finalize()
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	c := &Connection{}
    69  	c.setID()
    70  
    71  	if nc, ok := newConnection[deets.Dialect]; ok {
    72  		c.Dialect, err = nc(deets)
    73  		if err != nil {
    74  			return c, fmt.Errorf("could not create new connection: %w", err)
    75  		}
    76  		return c, nil
    77  	}
    78  	return nil, fmt.Errorf("could not found connection creator for %v", deets.Dialect)
    79  }
    80  
    81  // Connect takes the name of a connection, default is "development", and will
    82  // return that connection from the available `Connections`. If a connection with
    83  // that name can not be found an error will be returned. If a connection is
    84  // found, and it has yet to open a connection with its underlying datastore,
    85  // a connection to that store will be opened.
    86  func Connect(e string) (*Connection, error) {
    87  	if len(Connections) == 0 {
    88  		err := LoadConfigFile()
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  	}
    93  	e = defaults.String(e, "development")
    94  	c := Connections[e]
    95  	if c == nil {
    96  		return c, fmt.Errorf("could not find connection named %s", e)
    97  	}
    98  
    99  	if err := c.Open(); err != nil {
   100  		return c, fmt.Errorf("couldn't open connection for %s: %w", e, err)
   101  	}
   102  	return c, nil
   103  }
   104  
   105  // Open creates a new datasource connection
   106  func (c *Connection) Open() error {
   107  	if c.Store != nil {
   108  		return nil
   109  	}
   110  	if c.Dialect == nil {
   111  		return errors.New("invalid connection instance")
   112  	}
   113  	details := c.Dialect.Details()
   114  
   115  	db, err := openPotentiallyInstrumentedConnection(c.Dialect, c.Dialect.URL())
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	db.SetMaxOpenConns(details.Pool)
   121  	if details.IdlePool != 0 {
   122  		db.SetMaxIdleConns(details.IdlePool)
   123  	}
   124  	if details.ConnMaxLifetime > 0 {
   125  		db.SetConnMaxLifetime(details.ConnMaxLifetime)
   126  	}
   127  	if details.ConnMaxIdleTime > 0 {
   128  		db.SetConnMaxIdleTime(details.ConnMaxIdleTime)
   129  	}
   130  	if details.Unsafe {
   131  		db = db.Unsafe()
   132  	}
   133  	c.Store = &dB{db}
   134  
   135  	if d, ok := c.Dialect.(afterOpenable); ok {
   136  		if err := d.AfterOpen(c); err != nil {
   137  			c.Store = nil
   138  			return fmt.Errorf("could not open database connection: %w", err)
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  // Close destroys an active datasource connection
   145  func (c *Connection) Close() error {
   146  	if err := c.Store.Close(); err != nil {
   147  		return fmt.Errorf("couldn't close connection: %w", err)
   148  	}
   149  	c.Store = nil
   150  	return nil
   151  }
   152  
   153  // Transaction will start a new transaction on the connection. If the inner function
   154  // returns an error then the transaction will be rolled back, otherwise the transaction
   155  // will automatically commit at the end.
   156  func (c *Connection) Transaction(fn func(tx *Connection) error) error {
   157  	return c.Dialect.Lock(func() (err error) {
   158  		var dberr error
   159  
   160  		cn, err := c.NewTransaction()
   161  		if err != nil {
   162  			return err
   163  		}
   164  		txlog(logging.SQL, cn, "BEGIN Transaction ---")
   165  
   166  		defer func() {
   167  			if ex := recover(); ex != nil {
   168  				txlog(logging.SQL, cn, "ROLLBACK Transaction (inner function panic) ---")
   169  				dberr = cn.TX.Rollback()
   170  				if dberr != nil {
   171  					txlog(logging.Error, cn, "database error while inner panic rollback: %w", dberr)
   172  				}
   173  				panic(ex)
   174  			}
   175  		}()
   176  
   177  		err = fn(cn)
   178  		if err != nil {
   179  			txlog(logging.SQL, cn, "ROLLBACK Transaction ---")
   180  			dberr = cn.TX.Rollback()
   181  		} else {
   182  			txlog(logging.SQL, cn, "END Transaction ---")
   183  			dberr = cn.TX.Commit()
   184  		}
   185  
   186  		if dberr != nil {
   187  			return fmt.Errorf("database error on committing or rolling back transaction: %w", dberr)
   188  		}
   189  
   190  		return err
   191  	})
   192  
   193  }
   194  
   195  // Rollback will open a new transaction and automatically rollback that transaction
   196  // when the inner function returns, regardless. This can be useful for tests, etc...
   197  func (c *Connection) Rollback(fn func(tx *Connection)) error {
   198  	// TODO: the name of the method could be changed to express it better.
   199  	cn, err := c.NewTransaction()
   200  	if err != nil {
   201  		return err
   202  	}
   203  	txlog(logging.SQL, cn, "BEGIN Transaction for Rollback ---")
   204  	fn(cn)
   205  	txlog(logging.SQL, cn, "ROLLBACK Transaction as planned ---")
   206  	return cn.TX.Rollback()
   207  }
   208  
   209  // NewTransaction starts a new transaction on the connection
   210  func (c *Connection) NewTransaction() (*Connection, error) {
   211  	return c.NewTransactionContextOptions(c.Context(), nil)
   212  }
   213  
   214  // NewTransactionContext starts a new transaction on the connection using the provided context
   215  func (c *Connection) NewTransactionContext(ctx context.Context) (*Connection, error) {
   216  	return c.NewTransactionContextOptions(ctx, nil)
   217  }
   218  
   219  // NewTransactionContextOptions starts a new transaction on the connection using the provided context and transaction options
   220  func (c *Connection) NewTransactionContextOptions(ctx context.Context, options *sql.TxOptions) (*Connection, error) {
   221  	var cn *Connection
   222  	if c.TX == nil {
   223  		tx, err := c.Store.TransactionContextOptions(ctx, options)
   224  		if err != nil {
   225  			return cn, fmt.Errorf("couldn't start a new transaction: %w", err)
   226  		}
   227  
   228  		cn = &Connection{
   229  			Store:   contextStore{store: tx, ctx: ctx},
   230  			Dialect: c.Dialect,
   231  			TX:      tx,
   232  		}
   233  		cn.setID()
   234  	} else {
   235  		cn = c
   236  	}
   237  	return cn, nil
   238  }
   239  
   240  // WithContext returns a copy of the connection, wrapped with a context.
   241  func (c *Connection) WithContext(ctx context.Context) *Connection {
   242  	cn := c.copy()
   243  	cn.Store = contextStore{
   244  		store: cn.Store,
   245  		ctx:   ctx,
   246  	}
   247  	return cn
   248  }
   249  
   250  func (c *Connection) copy() *Connection {
   251  	// TODO: checkme. it copies and creates a new Connection (and a new ID)
   252  	// with the same TX which could make confusions and complexity in usage.
   253  	// related PRs: #72/#73, #79/#80, and #497
   254  
   255  	cn := &Connection{
   256  		Store:   c.Store,
   257  		Dialect: c.Dialect,
   258  		TX:      c.TX,
   259  	}
   260  	cn.setID(c.ID) // ID of the source as a seed
   261  
   262  	return cn
   263  }
   264  
   265  // Q creates a new "empty" query for the current connection.
   266  func (c *Connection) Q() *Query {
   267  	return Q(c)
   268  }
   269  
   270  // disableEager disables eager mode for current connection.
   271  func (c *Connection) disableEager() {
   272  	// The check technically is not required, because (*Connection).Eager() creates a (shallow) copy.
   273  	// When not reusing eager connections, this should be safe.
   274  	// However, this write triggers the go race detector.
   275  	if c.eager {
   276  		c.eager = false
   277  		c.eagerFields = []string{}
   278  	}
   279  }
   280  
   281  // TruncateAll truncates all data from the datasource
   282  func (c *Connection) TruncateAll() error {
   283  	return c.Dialect.TruncateAll(c)
   284  }
   285  
   286  func (c *Connection) timeFunc(name string, fn func() error) error {
   287  	start := time.Now()
   288  	err := fn()
   289  	atomic.AddInt64(&c.Elapsed, int64(time.Since(start)))
   290  	if err != nil {
   291  		return err
   292  	}
   293  	return nil
   294  }
   295  
   296  // setID sets a unique ID for a Connection in a specific format indicating the
   297  // Connection type, TX.ID, and optionally a copy ID. It makes it easy to trace
   298  // related queries for a single request.
   299  //
   300  //	examples: "conn-7881415437117811350", "tx-4924907692359316530", "tx-831769923571164863-ytzxZa"
   301  func (c *Connection) setID(id ...string) {
   302  	if len(id) == 1 {
   303  		idElems := strings.Split(id[0], "-")
   304  		l := 2
   305  		if len(idElems) < 2 {
   306  			l = len(idElems)
   307  		}
   308  		prefix := strings.Join(idElems[0:l], "-")
   309  		body := randx.String(6)
   310  
   311  		c.ID = fmt.Sprintf("%s-%s", prefix, body)
   312  	} else {
   313  		prefix := "conn"
   314  		body := rand.Int()
   315  
   316  		if c.TX != nil {
   317  			prefix = "tx"
   318  			body = c.TX.ID
   319  		}
   320  
   321  		c.ID = fmt.Sprintf("%s-%d", prefix, body)
   322  	}
   323  }
   324  
   325  func noTxWrapper(noTx bool, c *Connection, fn func(tx *Connection) error) error {
   326  	if noTx {
   327  		return c.Dialect.Lock(func() error {
   328  			return fn(c)
   329  		})
   330  	} else {
   331  		return c.Transaction(fn)
   332  	}
   333  }