github.com/systematiccaos/gorm@v1.22.6/callbacks.go (about) 1 package gorm 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "sort" 9 "time" 10 11 "github.com/systematiccaos/gorm/schema" 12 "github.com/systematiccaos/gorm/utils" 13 ) 14 15 func initializeCallbacks(db *DB) *callbacks { 16 return &callbacks{ 17 processors: map[string]*processor{ 18 "create": {db: db}, 19 "query": {db: db}, 20 "update": {db: db}, 21 "delete": {db: db}, 22 "row": {db: db}, 23 "raw": {db: db}, 24 }, 25 } 26 } 27 28 // callbacks gorm callbacks manager 29 type callbacks struct { 30 processors map[string]*processor 31 } 32 33 type processor struct { 34 db *DB 35 Clauses []string 36 fns []func(*DB) 37 callbacks []*callback 38 } 39 40 type callback struct { 41 name string 42 before string 43 after string 44 remove bool 45 replace bool 46 match func(*DB) bool 47 handler func(*DB) 48 processor *processor 49 } 50 51 func (cs *callbacks) Create() *processor { 52 return cs.processors["create"] 53 } 54 55 func (cs *callbacks) Query() *processor { 56 return cs.processors["query"] 57 } 58 59 func (cs *callbacks) Update() *processor { 60 return cs.processors["update"] 61 } 62 63 func (cs *callbacks) Delete() *processor { 64 return cs.processors["delete"] 65 } 66 67 func (cs *callbacks) Row() *processor { 68 return cs.processors["row"] 69 } 70 71 func (cs *callbacks) Raw() *processor { 72 return cs.processors["raw"] 73 } 74 75 func (p *processor) Execute(db *DB) *DB { 76 // call scopes 77 for len(db.Statement.scopes) > 0 { 78 scopes := db.Statement.scopes 79 db.Statement.scopes = nil 80 for _, scope := range scopes { 81 db = scope(db) 82 } 83 } 84 85 var ( 86 curTime = time.Now() 87 stmt = db.Statement 88 resetBuildClauses bool 89 ) 90 91 if len(stmt.BuildClauses) == 0 { 92 stmt.BuildClauses = p.Clauses 93 resetBuildClauses = true 94 } 95 96 // assign model values 97 if stmt.Model == nil { 98 stmt.Model = stmt.Dest 99 } else if stmt.Dest == nil { 100 stmt.Dest = stmt.Model 101 } 102 103 // parse model values 104 if stmt.Model != nil { 105 if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { 106 if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { 107 db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) 108 } else { 109 db.AddError(err) 110 } 111 } 112 } 113 114 // assign stmt.ReflectValue 115 if stmt.Dest != nil { 116 stmt.ReflectValue = reflect.ValueOf(stmt.Dest) 117 for stmt.ReflectValue.Kind() == reflect.Ptr { 118 if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { 119 stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) 120 } 121 122 stmt.ReflectValue = stmt.ReflectValue.Elem() 123 } 124 if !stmt.ReflectValue.IsValid() { 125 db.AddError(ErrInvalidValue) 126 } 127 } 128 129 for _, f := range p.fns { 130 f(db) 131 } 132 133 if stmt.SQL.Len() > 0 { 134 db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { 135 return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected 136 }, db.Error) 137 } 138 139 if !stmt.DB.DryRun { 140 stmt.SQL.Reset() 141 stmt.Vars = nil 142 } 143 144 if resetBuildClauses { 145 stmt.BuildClauses = nil 146 } 147 148 return db 149 } 150 151 func (p *processor) Get(name string) func(*DB) { 152 for i := len(p.callbacks) - 1; i >= 0; i-- { 153 if v := p.callbacks[i]; v.name == name && !v.remove { 154 return v.handler 155 } 156 } 157 return nil 158 } 159 160 func (p *processor) Before(name string) *callback { 161 return &callback{before: name, processor: p} 162 } 163 164 func (p *processor) After(name string) *callback { 165 return &callback{after: name, processor: p} 166 } 167 168 func (p *processor) Match(fc func(*DB) bool) *callback { 169 return &callback{match: fc, processor: p} 170 } 171 172 func (p *processor) Register(name string, fn func(*DB)) error { 173 return (&callback{processor: p}).Register(name, fn) 174 } 175 176 func (p *processor) Remove(name string) error { 177 return (&callback{processor: p}).Remove(name) 178 } 179 180 func (p *processor) Replace(name string, fn func(*DB)) error { 181 return (&callback{processor: p}).Replace(name, fn) 182 } 183 184 func (p *processor) compile() (err error) { 185 var callbacks []*callback 186 for _, callback := range p.callbacks { 187 if callback.match == nil || callback.match(p.db) { 188 callbacks = append(callbacks, callback) 189 } 190 } 191 p.callbacks = callbacks 192 193 if p.fns, err = sortCallbacks(p.callbacks); err != nil { 194 p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) 195 } 196 return 197 } 198 199 func (c *callback) Before(name string) *callback { 200 c.before = name 201 return c 202 } 203 204 func (c *callback) After(name string) *callback { 205 c.after = name 206 return c 207 } 208 209 func (c *callback) Register(name string, fn func(*DB)) error { 210 c.name = name 211 c.handler = fn 212 c.processor.callbacks = append(c.processor.callbacks, c) 213 return c.processor.compile() 214 } 215 216 func (c *callback) Remove(name string) error { 217 c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum()) 218 c.name = name 219 c.remove = true 220 c.processor.callbacks = append(c.processor.callbacks, c) 221 return c.processor.compile() 222 } 223 224 func (c *callback) Replace(name string, fn func(*DB)) error { 225 c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) 226 c.name = name 227 c.handler = fn 228 c.replace = true 229 c.processor.callbacks = append(c.processor.callbacks, c) 230 return c.processor.compile() 231 } 232 233 // getRIndex get right index from string slice 234 func getRIndex(strs []string, str string) int { 235 for i := len(strs) - 1; i >= 0; i-- { 236 if strs[i] == str { 237 return i 238 } 239 } 240 return -1 241 } 242 243 func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { 244 var ( 245 names, sorted []string 246 sortCallback func(*callback) error 247 ) 248 sort.Slice(cs, func(i, j int) bool { 249 return cs[j].before == "*" || cs[j].after == "*" 250 }) 251 252 for _, c := range cs { 253 // show warning message the callback name already exists 254 if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { 255 c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) 256 } 257 names = append(names, c.name) 258 } 259 260 sortCallback = func(c *callback) error { 261 if c.before != "" { // if defined before callback 262 if c.before == "*" && len(sorted) > 0 { 263 if curIdx := getRIndex(sorted, c.name); curIdx == -1 { 264 sorted = append([]string{c.name}, sorted...) 265 } 266 } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { 267 if curIdx := getRIndex(sorted, c.name); curIdx == -1 { 268 // if before callback already sorted, append current callback just after it 269 sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) 270 } else if curIdx > sortedIdx { 271 return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) 272 } 273 } else if idx := getRIndex(names, c.before); idx != -1 { 274 // if before callback exists 275 cs[idx].after = c.name 276 } 277 } 278 279 if c.after != "" { // if defined after callback 280 if c.after == "*" && len(sorted) > 0 { 281 if curIdx := getRIndex(sorted, c.name); curIdx == -1 { 282 sorted = append(sorted, c.name) 283 } 284 } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { 285 if curIdx := getRIndex(sorted, c.name); curIdx == -1 { 286 // if after callback sorted, append current callback to last 287 sorted = append(sorted, c.name) 288 } else if curIdx < sortedIdx { 289 return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) 290 } 291 } else if idx := getRIndex(names, c.after); idx != -1 { 292 // if after callback exists but haven't sorted 293 // set after callback's before callback to current callback 294 after := cs[idx] 295 296 if after.before == "" { 297 after.before = c.name 298 } 299 300 if err := sortCallback(after); err != nil { 301 return err 302 } 303 304 if err := sortCallback(c); err != nil { 305 return err 306 } 307 } 308 } 309 310 // if current callback haven't been sorted, append it to last 311 if getRIndex(sorted, c.name) == -1 { 312 sorted = append(sorted, c.name) 313 } 314 315 return nil 316 } 317 318 for _, c := range cs { 319 if err = sortCallback(c); err != nil { 320 return 321 } 322 } 323 324 for _, name := range sorted { 325 if idx := getRIndex(names, name); !cs[idx].remove { 326 fns = append(fns, cs[idx].handler) 327 } 328 } 329 330 return 331 }