github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/connection.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"math/rand"
     9  	"strings"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/Accefy/pop/internal/defaults"
    14  	"github.com/Accefy/pop/internal/randx"
    15  	"github.com/gobuffalo/pop/v6/logging"
    16  	"go.opentelemetry.io/otel/attribute"
    17  )
    18  
    19  // Connections contains all available connections
    20  var Connections = map[string]*Connection{}
    21  
    22  // Connection represents all necessary details to talk with a datastore
    23  type Connection struct {
    24  	ID          string // identifier assigned by setID, e.g. "conn-…" or "tx-…"
    25  	Store       store // set by Open, reset to nil by Close; Open is a no-op while it is non-nil
    26  	Dialect     dialect // set by NewConnection; supplies the URL, details and locking
    27  	Elapsed     int64 // time spent inside timeFunc callbacks, in nanoseconds; updated with atomic.AddInt64
    28  	TX          *Tx // non-nil when this connection was created for a transaction
    29  	eager       bool // eager-mode flag; cleared by disableEager
    30  	eagerFields []string // reset to an empty slice by disableEager
    31  }
    32  
    33  func (c *Connection) String() string {
    34  	return c.URL()
    35  }
    36  
    37  // URL returns the datasource connection string
    38  func (c *Connection) URL() string {
    39  	return c.Dialect.URL()
    40  }
    41  
    42  // Context returns the connection's context set by "Context()" or context.TODO()
    43  // if no context is set.
    44  func (c *Connection) Context() context.Context {
    45  	if c, ok := c.Store.(interface{ Context() context.Context }); ok {
    46  		return c.Context()
    47  	}
    48  
    49  	return context.TODO()
    50  }
    51  
    52  // MigrationURL returns the datasource connection string used for running the migrations
    53  func (c *Connection) MigrationURL() string {
    54  	return c.Dialect.MigrationURL()
    55  }
    56  
    57  // MigrationTableName returns the name of the table to track migrations
    58  func (c *Connection) MigrationTableName() string {
    59  	return c.Dialect.Details().MigrationTableName()
    60  }
    61  
    62  // NewConnection creates a new connection, and sets it's `Dialect`
    63  // appropriately based on the `ConnectionDetails` passed into it.
    64  func NewConnection(deets *ConnectionDetails) (*Connection, error) {
    65  	err := deets.Finalize()
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	c := &Connection{}
    70  	c.setID()
    71  
    72  	if nc, ok := newConnection[deets.Dialect]; ok {
    73  		c.Dialect, err = nc(deets)
    74  		if err != nil {
    75  			return c, fmt.Errorf("could not create new connection: %w", err)
    76  		}
    77  		return c, nil
    78  	}
    79  	return nil, fmt.Errorf("could not found connection creator for %v", deets.Dialect)
    80  }
    81  
    82  // Connect takes the name of a connection, default is "development", and will
    83  // return that connection from the available `Connections`. If a connection with
    84  // that name can not be found an error will be returned. If a connection is
    85  // found, and it has yet to open a connection with its underlying datastore,
    86  // a connection to that store will be opened.
    87  func Connect(e string) (*Connection, error) {
    88  	if len(Connections) == 0 {
    89  		err := LoadConfigFile()
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  	}
    94  	e = defaults.String(e, "development")
    95  	c := Connections[e]
    96  	if c == nil {
    97  		return c, fmt.Errorf("could not find connection named %s", e)
    98  	}
    99  
   100  	if err := c.Open(); err != nil {
   101  		return c, fmt.Errorf("couldn't open connection for %s: %w", e, err)
   102  	}
   103  	return c, nil
   104  }
   105  
   106  // Open creates a new datasource connection
   107  func (c *Connection) Open(attributes ...attribute.KeyValue) error {
   108  	if c.Store != nil {
   109  		return nil
   110  	}
   111  	if c.Dialect == nil {
   112  		return errors.New("invalid connection instance")
   113  	}
   114  	details := c.Dialect.Details()
   115  
   116  	db, err := openPotentiallyInstrumentedConnectionWithOtel(c.Dialect, c.Dialect.URL(), attributes...)
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	db.SetMaxOpenConns(details.Pool)
   122  	if details.IdlePool != 0 {
   123  		db.SetMaxIdleConns(details.IdlePool)
   124  	}
   125  	if details.ConnMaxLifetime > 0 {
   126  		db.SetConnMaxLifetime(details.ConnMaxLifetime)
   127  	}
   128  	if details.ConnMaxIdleTime > 0 {
   129  		db.SetConnMaxIdleTime(details.ConnMaxIdleTime)
   130  	}
   131  	if details.Unsafe {
   132  		db = db.Unsafe()
   133  	}
   134  	c.Store = &dB{db}
   135  
   136  	if d, ok := c.Dialect.(afterOpenable); ok {
   137  		if err := d.AfterOpen(c); err != nil {
   138  			c.Store = nil
   139  			return fmt.Errorf("could not open database connection: %w", err)
   140  		}
   141  	}
   142  	return nil
   143  }
   144  
   145  // Close destroys an active datasource connection
   146  func (c *Connection) Close() error {
   147  	if err := c.Store.Close(); err != nil {
   148  		return fmt.Errorf("couldn't close connection: %w", err)
   149  	}
   150  	c.Store = nil
   151  	return nil
   152  }
   153  
   154  // Transaction will start a new transaction on the connection. If the inner function
   155  // returns an error then the transaction will be rolled back, otherwise the transaction
   156  // will automatically commit at the end.
   157  func (c *Connection) Transaction(fn func(tx *Connection) error) error {
   158  	return c.Dialect.Lock(func() (err error) {
   159  		var dberr error
   160  
   161  		cn, err := c.NewTransaction()
   162  		if err != nil {
   163  			return err
   164  		}
   165  		txlog(logging.SQL, cn, "BEGIN Transaction ---")
   166  
   167  		defer func() {
   168  			if ex := recover(); ex != nil {
   169  				txlog(logging.SQL, cn, "ROLLBACK Transaction (inner function panic) ---")
   170  				dberr = cn.TX.Rollback()
   171  				if dberr != nil {
   172  					txlog(logging.Error, cn, "database error while inner panic rollback: %w", dberr)
   173  				}
   174  				panic(ex)
   175  			}
   176  		}()
   177  
   178  		err = fn(cn)
   179  		if err != nil {
   180  			txlog(logging.SQL, cn, "ROLLBACK Transaction ---")
   181  			dberr = cn.TX.Rollback()
   182  		} else {
   183  			txlog(logging.SQL, cn, "END Transaction ---")
   184  			dberr = cn.TX.Commit()
   185  		}
   186  
   187  		if dberr != nil {
   188  			return fmt.Errorf("database error on committing or rolling back transaction: %w", dberr)
   189  		}
   190  
   191  		return err
   192  	})
   193  
   194  }
   195  
   196  // Rollback will open a new transaction and automatically rollback that transaction
   197  // when the inner function returns, regardless. This can be useful for tests, etc...
   198  func (c *Connection) Rollback(fn func(tx *Connection)) error {
   199  	// TODO: the name of the method could be changed to express it better.
   200  	cn, err := c.NewTransaction()
   201  	if err != nil {
   202  		return err
   203  	}
   204  	txlog(logging.SQL, cn, "BEGIN Transaction for Rollback ---")
   205  	fn(cn)
   206  	txlog(logging.SQL, cn, "ROLLBACK Transaction as planned ---")
   207  	return cn.TX.Rollback()
   208  }
   209  
   210  // NewTransaction starts a new transaction on the connection
   211  func (c *Connection) NewTransaction() (*Connection, error) {
   212  	return c.NewTransactionContextOptions(c.Context(), nil)
   213  }
   214  
   215  // NewTransactionContext starts a new transaction on the connection using the provided context
   216  func (c *Connection) NewTransactionContext(ctx context.Context) (*Connection, error) {
   217  	return c.NewTransactionContextOptions(ctx, nil)
   218  }
   219  
   220  // NewTransactionContextOptions starts a new transaction on the connection using the provided context and transaction options
   221  func (c *Connection) NewTransactionContextOptions(ctx context.Context, options *sql.TxOptions) (*Connection, error) {
   222  	var cn *Connection
   223  	if c.TX == nil {
   224  		tx, err := c.Store.TransactionContextOptions(ctx, options)
   225  		if err != nil {
   226  			return cn, fmt.Errorf("couldn't start a new transaction: %w", err)
   227  		}
   228  
   229  		cn = &Connection{
   230  			Store:   contextStore{store: tx, ctx: ctx},
   231  			Dialect: c.Dialect,
   232  			TX:      tx,
   233  		}
   234  		cn.setID()
   235  	} else {
   236  		cn = c
   237  	}
   238  	return cn, nil
   239  }
   240  
   241  // WithContext returns a copy of the connection, wrapped with a context.
   242  func (c *Connection) WithContext(ctx context.Context) *Connection {
   243  	cn := c.copy()
   244  	cn.Store = contextStore{
   245  		store: cn.Store,
   246  		ctx:   ctx,
   247  	}
   248  	return cn
   249  }
   250  
   251  func (c *Connection) copy() *Connection {
   252  	// TODO: checkme. it copies and creates a new Connection (and a new ID)
   253  	// with the same TX which could make confusions and complexity in usage.
   254  	// related PRs: #72/#73, #79/#80, and #497
   255  
   256  	cn := &Connection{
   257  		Store:   c.Store,
   258  		Dialect: c.Dialect,
   259  		TX:      c.TX,
   260  	}
   261  	cn.setID(c.ID) // ID of the source as a seed
   262  
   263  	return cn
   264  }
   265  
   266  // Q creates a new "empty" query for the current connection.
   267  func (c *Connection) Q() *Query {
   268  	return Q(c)
   269  }
   270  
   271  // disableEager disables eager mode for current connection.
   272  func (c *Connection) disableEager() {
   273  	// The check technically is not required, because (*Connection).Eager() creates a (shallow) copy.
   274  	// When not reusing eager connections, this should be safe.
   275  	// However, this write triggers the go race detector.
   276  	if c.eager {
   277  		c.eager = false
   278  		c.eagerFields = []string{}
   279  	}
   280  }
   281  
   282  // TruncateAll truncates all data from the datasource
   283  func (c *Connection) TruncateAll() error {
   284  	return c.Dialect.TruncateAll(c)
   285  }
   286  
   287  func (c *Connection) timeFunc(name string, fn func() error) error {
   288  	start := time.Now()
   289  	err := fn()
   290  	atomic.AddInt64(&c.Elapsed, int64(time.Since(start)))
   291  	if err != nil {
   292  		return err
   293  	}
   294  	return nil
   295  }
   296  
   297  // setID sets a unique ID for a Connection in a specific format indicating the
   298  // Connection type, TX.ID, and optionally a copy ID. It makes it easy to trace
   299  // related queries for a single request.
   300  //
   301  //  examples: "conn-7881415437117811350", "tx-4924907692359316530", "tx-831769923571164863-ytzxZa"
   302  func (c *Connection) setID(id ...string) {
   303  	if len(id) == 1 {
   304  		idElems := strings.Split(id[0], "-")
   305  		l := 2
   306  		if len(idElems) < 2 {
   307  			l = len(idElems)
   308  		}
   309  		prefix := strings.Join(idElems[0:l], "-")
   310  		body := randx.String(6)
   311  
   312  		c.ID = fmt.Sprintf("%s-%s", prefix, body)
   313  	} else {
   314  		prefix := "conn"
   315  		body := rand.Int()
   316  
   317  		if c.TX != nil {
   318  			prefix = "tx"
   319  			body = c.TX.ID
   320  		}
   321  
   322  		c.ID = fmt.Sprintf("%s-%d", prefix, body)
   323  	}
   324  }