goyave.dev/goyave/v5@v5.0.0-rc9.0.20240517145003-d3f977d0b9f3/database/timeout.go (about) 1 package database 2 3 import ( 4 "context" 5 "time" 6 7 "gorm.io/gorm" 8 "goyave.dev/goyave/v5/util/errors" 9 ) 10 11 const ( 12 timeoutCallbackBeforeName = "goyave:timeout_before" 13 timeoutCallbackAfterName = "goyave:timeout_after" 14 ) 15 16 type timeoutContext struct { 17 context.Context 18 19 parentContext context.Context 20 21 // We store the pointer to the original statement 22 // so we can cancel the context only if the original 23 // statement is completely finished. This prevents 24 // sub-statements (such as preloads) to cancel the context 25 // when they are done, despite the parent statement not being 26 // executed yet. 27 statement *gorm.Statement 28 29 cancel context.CancelFunc 30 } 31 32 // TimeoutPlugin GORM plugin adding a default timeout to SQL queries if none is applied 33 // on the statement already. It works by replacing the statement's context with a child 34 // context having the configured timeout. The context is replaced in a "before" callback 35 // on all GORM operations. In a "after" callback, the new context is canceled. 36 // 37 // The `ReadTimeout` is applied on the `Query` and `Raw` GORM callbacks. The `WriteTimeout` 38 // is applied on the rest of the callbacks. 39 // 40 // Supports all GORM operations except `Scan()`. 41 // 42 // A timeout duration inferior or equal to 0 disables the plugin for the relevant operations. 43 type TimeoutPlugin struct { 44 ReadTimeout time.Duration 45 WriteTimeout time.Duration 46 } 47 48 // Name returns the name of the plugin 49 func (p *TimeoutPlugin) Name() string { 50 return "goyave:timeout" 51 } 52 53 // Initialize registers the callbacks for all operations. 54 func (p *TimeoutPlugin) Initialize(db *gorm.DB) error { 55 createCallback := db.Callback().Create() 56 if err := createCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil { 57 return errors.New(err) 58 } 59 if err := createCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 60 return errors.New(err) 61 } 62 63 queryCallback := db.Callback().Query() 64 if err := queryCallback.Before("*").Register(timeoutCallbackBeforeName, p.readTimeoutBefore); err != nil { 65 return errors.New(err) 66 } 67 if err := queryCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 68 return errors.New(err) 69 } 70 71 deleteCallback := db.Callback().Delete() 72 if err := deleteCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil { 73 return errors.New(err) 74 } 75 if err := deleteCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 76 return errors.New(err) 77 } 78 79 updateCallback := db.Callback().Update() 80 if err := updateCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil { 81 return errors.New(err) 82 } 83 if err := updateCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 84 return errors.New(err) 85 } 86 87 // Cannot use it with `Row()` because context is canceled before the call of `rows.Next()`, causing an error. 88 // rowCallback := db.Callback().Row() 89 // if err := rowCallback.Before("*").Register(timeoutCallbackBeforeName, p.readTimeoutBefore); err != nil { 90 // return errors.New(err) 91 // } 92 // if err := rowCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 93 // return errors.New(err) 94 // } 95 96 rawCallback := db.Callback().Raw() 97 if err := rawCallback.Before("*").Register(timeoutCallbackBeforeName, p.writeTimeoutBefore); err != nil { 98 return errors.New(err) 99 } 100 if err := rawCallback.After("*").Register(timeoutCallbackAfterName, p.timeoutAfter); err != nil { 101 return errors.New(err) 102 } 103 return nil 104 } 105 106 func (p *TimeoutPlugin) readTimeoutBefore(db *gorm.DB) { 107 p.timeoutBefore(db, p.ReadTimeout) 108 } 109 110 func (p *TimeoutPlugin) writeTimeoutBefore(db *gorm.DB) { 111 p.timeoutBefore(db, p.WriteTimeout) 112 } 113 114 func (p *TimeoutPlugin) timeoutBefore(db *gorm.DB, timeout time.Duration) { 115 if timeout <= 0 || db.Statement.Context == nil { 116 return 117 } 118 if tc, ok := db.Statement.Context.(*timeoutContext); ok { 119 // The statement is re-used, replace the context with a new one 120 ctx, cancel := context.WithTimeout(tc.parentContext, timeout) 121 db.Statement.Context = &timeoutContext{ 122 Context: ctx, 123 parentContext: tc.parentContext, 124 statement: db.Statement, 125 cancel: cancel, 126 } 127 return 128 } 129 if _, hasDeadline := db.Statement.Context.Deadline(); hasDeadline { 130 return 131 } 132 ctx, cancel := context.WithTimeout(db.Statement.Context, timeout) 133 db.Statement.Context = &timeoutContext{ 134 Context: ctx, 135 parentContext: db.Statement.Context, 136 statement: db.Statement, 137 cancel: cancel, 138 } 139 } 140 141 func (p *TimeoutPlugin) timeoutAfter(db *gorm.DB) { 142 ctx, ok := db.Statement.Context.(*timeoutContext) 143 if !ok || ctx.cancel == nil || db.Statement != ctx.statement { 144 return 145 } 146 ctx.cancel() 147 }