goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/database/timeout.go (about)

     1  package database
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"gorm.io/gorm"
     8  	"goyave.dev/goyave/v5/util/errors"
     9  )
    10  
    11  const (
    12  	timeoutCallbackBeforeName = "goyave:timeout_before"
    13  	timeoutCallbackAfterName  = "goyave:timeout_after"
    14  )
    15  
    16  type timeoutContext struct {
    17  	context.Context
    18  
    19  	parentContext context.Context
    20  
    21  	// We store the pointer to the original statement
    22  	// so we can cancel the context only if the original
    23  	// statement is completely finished. This prevents
    24  	// sub-statements (such as preloads) to cancel the context
    25  	// when they are done, despite the parent statement not being
    26  	// executed yet.
    27  	statement *gorm.Statement
    28  
    29  	cancel context.CancelFunc
    30  }
    31  
    32  // TimeoutPlugin GORM plugin adding a default timeout to SQL queries if none is applied
    33  // on the statement already. It works by replacing the statement's context with a child
    34  // context having the configured timeout. The context is replaced in a "before" callback
    35  // on all GORM operations. In a "after" callback, the new context is canceled.
    36  //
    37  // The `ReadTimeout` is applied on the `Query` and `Raw` GORM callbacks. The `WriteTimeout`
    38  // is applied on the rest of the callbacks.
    39  //
    40  // Supports all GORM operations except `Scan()`.
    41  //
    42  // A timeout duration inferior or equal to 0 disables the plugin for the relevant operations.
    43  type TimeoutPlugin struct {
    44  	ReadTimeout  time.Duration
    45  	WriteTimeout time.Duration
    46  }
    47  
    48  // Name returns the name of the plugin
    49  func (p *TimeoutPlugin) Name() string {
    50  	return "goyave:timeout"
    51  }
    52  
    53  // Initialize registers the callbacks for all operations.
    54  func (p *TimeoutPlugin) Initialize(db *gorm.DB) error {
    55  	createCallback := db.Callback().Create()
    56  	if err := createCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil {
    57  		return errors.New(err)
    58  	}
    59  	if err := createCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
    60  		return errors.New(err)
    61  	}
    62  
    63  	queryCallback := db.Callback().Query()
    64  	if err := queryCallback.Before("*").Register(timeoutCallbackBeforeName, p.readTimeoutBefore); err != nil {
    65  		return errors.New(err)
    66  	}
    67  	if err := queryCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
    68  		return errors.New(err)
    69  	}
    70  
    71  	deleteCallback := db.Callback().Delete()
    72  	if err := deleteCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil {
    73  		return errors.New(err)
    74  	}
    75  	if err := deleteCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
    76  		return errors.New(err)
    77  	}
    78  
    79  	updateCallback := db.Callback().Update()
    80  	if err := updateCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil {
    81  		return errors.New(err)
    82  	}
    83  	if err := updateCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
    84  		return errors.New(err)
    85  	}
    86  
    87  	// Cannot use it with `Row()` because context is canceled before the call of `rows.Next()`, causing an error.
    88  	// rowCallback := db.Callback().Row()
    89  	// if err := rowCallback.Before("*").Register(timeoutCallbackBeforeName, p.readTimeoutBefore); err != nil {
    90  	// 	return errors.New(err)
    91  	// }
    92  	// if err := rowCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
    93  	// 	return errors.New(err)
    94  	// }
    95  
    96  	rawCallback := db.Callback().Raw()
    97  	if err := rawCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil {
    98  		return errors.New(err)
    99  	}
   100  	if err := rawCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil {
   101  		return errors.New(err)
   102  	}
   103  	return nil
   104  }
   105  
   106  func (p *TimeoutPlugin) readTimeoutBefore(db *gorm.DB) {
   107  	p.timeoutBefore(db, p.ReadTimeout)
   108  }
   109  
   110  func (p *TimeoutPlugin) writeTimeoutBefore(db *gorm.DB) {
   111  	p.timeoutBefore(db, p.WriteTimeout)
   112  }
   113  
   114  func (p *TimeoutPlugin) timeoutBefore(db *gorm.DB, timeout time.Duration) {
   115  	if timeout <= 0 || db.Statement.Context == nil {
   116  		return
   117  	}
   118  	if tc, ok := db.Statement.Context.(*timeoutContext); ok {
   119  		// The statement is re-used, replace the context with a new one
   120  		ctx, cancel := context.WithTimeout(tc.parentContext, timeout)
   121  		db.Statement.Context = &timeoutContext{
   122  			Context:       ctx,
   123  			parentContext: tc.parentContext,
   124  			statement:     db.Statement,
   125  			cancel:        cancel,
   126  		}
   127  		return
   128  	}
   129  	if _, hasDeadline := db.Statement.Context.Deadline(); hasDeadline {
   130  		return
   131  	}
   132  	ctx, cancel := context.WithTimeout(db.Statement.Context, timeout)
   133  	db.Statement.Context = &timeoutContext{
   134  		Context:       ctx,
   135  		parentContext: db.Statement.Context,
   136  		statement:     db.Statement,
   137  		cancel:        cancel,
   138  	}
   139  }
   140  
   141  func (p *TimeoutPlugin) timeoutAfter(db *gorm.DB) {
   142  	ctx, ok := db.Statement.Context.(*timeoutContext)
   143  	if !ok || ctx.cancel == nil || db.Statement != ctx.statement {
   144  		return
   145  	}
   146  	ctx.cancel()
   147  }