github.com/systematiccaos/gorm@v1.22.6/callbacks.go (about)

     1  package gorm
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"sort"
     9  	"time"
    10  
    11  	"github.com/systematiccaos/gorm/schema"
    12  	"github.com/systematiccaos/gorm/utils"
    13  )
    14  
    15  func initializeCallbacks(db *DB) *callbacks {
    16  	return &callbacks{
    17  		processors: map[string]*processor{
    18  			"create": {db: db},
    19  			"query":  {db: db},
    20  			"update": {db: db},
    21  			"delete": {db: db},
    22  			"row":    {db: db},
    23  			"raw":    {db: db},
    24  		},
    25  	}
    26  }
    27  
    28  // callbacks gorm callbacks manager
    29  type callbacks struct {
    30  	processors map[string]*processor
    31  }
    32  
    33  type processor struct {
    34  	db        *DB
    35  	Clauses   []string
    36  	fns       []func(*DB)
    37  	callbacks []*callback
    38  }
    39  
    40  type callback struct {
    41  	name      string
    42  	before    string
    43  	after     string
    44  	remove    bool
    45  	replace   bool
    46  	match     func(*DB) bool
    47  	handler   func(*DB)
    48  	processor *processor
    49  }
    50  
    51  func (cs *callbacks) Create() *processor {
    52  	return cs.processors["create"]
    53  }
    54  
    55  func (cs *callbacks) Query() *processor {
    56  	return cs.processors["query"]
    57  }
    58  
    59  func (cs *callbacks) Update() *processor {
    60  	return cs.processors["update"]
    61  }
    62  
    63  func (cs *callbacks) Delete() *processor {
    64  	return cs.processors["delete"]
    65  }
    66  
    67  func (cs *callbacks) Row() *processor {
    68  	return cs.processors["row"]
    69  }
    70  
    71  func (cs *callbacks) Raw() *processor {
    72  	return cs.processors["raw"]
    73  }
    74  
    75  func (p *processor) Execute(db *DB) *DB {
    76  	// call scopes
    77  	for len(db.Statement.scopes) > 0 {
    78  		scopes := db.Statement.scopes
    79  		db.Statement.scopes = nil
    80  		for _, scope := range scopes {
    81  			db = scope(db)
    82  		}
    83  	}
    84  
    85  	var (
    86  		curTime           = time.Now()
    87  		stmt              = db.Statement
    88  		resetBuildClauses bool
    89  	)
    90  
    91  	if len(stmt.BuildClauses) == 0 {
    92  		stmt.BuildClauses = p.Clauses
    93  		resetBuildClauses = true
    94  	}
    95  
    96  	// assign model values
    97  	if stmt.Model == nil {
    98  		stmt.Model = stmt.Dest
    99  	} else if stmt.Dest == nil {
   100  		stmt.Dest = stmt.Model
   101  	}
   102  
   103  	// parse model values
   104  	if stmt.Model != nil {
   105  		if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
   106  			if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
   107  				db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
   108  			} else {
   109  				db.AddError(err)
   110  			}
   111  		}
   112  	}
   113  
   114  	// assign stmt.ReflectValue
   115  	if stmt.Dest != nil {
   116  		stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
   117  		for stmt.ReflectValue.Kind() == reflect.Ptr {
   118  			if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
   119  				stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
   120  			}
   121  
   122  			stmt.ReflectValue = stmt.ReflectValue.Elem()
   123  		}
   124  		if !stmt.ReflectValue.IsValid() {
   125  			db.AddError(ErrInvalidValue)
   126  		}
   127  	}
   128  
   129  	for _, f := range p.fns {
   130  		f(db)
   131  	}
   132  
   133  	if stmt.SQL.Len() > 0 {
   134  		db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
   135  			return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
   136  		}, db.Error)
   137  	}
   138  
   139  	if !stmt.DB.DryRun {
   140  		stmt.SQL.Reset()
   141  		stmt.Vars = nil
   142  	}
   143  
   144  	if resetBuildClauses {
   145  		stmt.BuildClauses = nil
   146  	}
   147  
   148  	return db
   149  }
   150  
   151  func (p *processor) Get(name string) func(*DB) {
   152  	for i := len(p.callbacks) - 1; i >= 0; i-- {
   153  		if v := p.callbacks[i]; v.name == name && !v.remove {
   154  			return v.handler
   155  		}
   156  	}
   157  	return nil
   158  }
   159  
   160  func (p *processor) Before(name string) *callback {
   161  	return &callback{before: name, processor: p}
   162  }
   163  
   164  func (p *processor) After(name string) *callback {
   165  	return &callback{after: name, processor: p}
   166  }
   167  
   168  func (p *processor) Match(fc func(*DB) bool) *callback {
   169  	return &callback{match: fc, processor: p}
   170  }
   171  
   172  func (p *processor) Register(name string, fn func(*DB)) error {
   173  	return (&callback{processor: p}).Register(name, fn)
   174  }
   175  
   176  func (p *processor) Remove(name string) error {
   177  	return (&callback{processor: p}).Remove(name)
   178  }
   179  
   180  func (p *processor) Replace(name string, fn func(*DB)) error {
   181  	return (&callback{processor: p}).Replace(name, fn)
   182  }
   183  
   184  func (p *processor) compile() (err error) {
   185  	var callbacks []*callback
   186  	for _, callback := range p.callbacks {
   187  		if callback.match == nil || callback.match(p.db) {
   188  			callbacks = append(callbacks, callback)
   189  		}
   190  	}
   191  	p.callbacks = callbacks
   192  
   193  	if p.fns, err = sortCallbacks(p.callbacks); err != nil {
   194  		p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
   195  	}
   196  	return
   197  }
   198  
   199  func (c *callback) Before(name string) *callback {
   200  	c.before = name
   201  	return c
   202  }
   203  
   204  func (c *callback) After(name string) *callback {
   205  	c.after = name
   206  	return c
   207  }
   208  
   209  func (c *callback) Register(name string, fn func(*DB)) error {
   210  	c.name = name
   211  	c.handler = fn
   212  	c.processor.callbacks = append(c.processor.callbacks, c)
   213  	return c.processor.compile()
   214  }
   215  
   216  func (c *callback) Remove(name string) error {
   217  	c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
   218  	c.name = name
   219  	c.remove = true
   220  	c.processor.callbacks = append(c.processor.callbacks, c)
   221  	return c.processor.compile()
   222  }
   223  
   224  func (c *callback) Replace(name string, fn func(*DB)) error {
   225  	c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
   226  	c.name = name
   227  	c.handler = fn
   228  	c.replace = true
   229  	c.processor.callbacks = append(c.processor.callbacks, c)
   230  	return c.processor.compile()
   231  }
   232  
   233  // getRIndex get right index from string slice
   234  func getRIndex(strs []string, str string) int {
   235  	for i := len(strs) - 1; i >= 0; i-- {
   236  		if strs[i] == str {
   237  			return i
   238  		}
   239  	}
   240  	return -1
   241  }
   242  
   243  func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
   244  	var (
   245  		names, sorted []string
   246  		sortCallback  func(*callback) error
   247  	)
   248  	sort.Slice(cs, func(i, j int) bool {
   249  		return cs[j].before == "*" || cs[j].after == "*"
   250  	})
   251  
   252  	for _, c := range cs {
   253  		// show warning message the callback name already exists
   254  		if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
   255  			c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
   256  		}
   257  		names = append(names, c.name)
   258  	}
   259  
   260  	sortCallback = func(c *callback) error {
   261  		if c.before != "" { // if defined before callback
   262  			if c.before == "*" && len(sorted) > 0 {
   263  				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
   264  					sorted = append([]string{c.name}, sorted...)
   265  				}
   266  			} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
   267  				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
   268  					// if before callback already sorted, append current callback just after it
   269  					sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
   270  				} else if curIdx > sortedIdx {
   271  					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
   272  				}
   273  			} else if idx := getRIndex(names, c.before); idx != -1 {
   274  				// if before callback exists
   275  				cs[idx].after = c.name
   276  			}
   277  		}
   278  
   279  		if c.after != "" { // if defined after callback
   280  			if c.after == "*" && len(sorted) > 0 {
   281  				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
   282  					sorted = append(sorted, c.name)
   283  				}
   284  			} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
   285  				if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
   286  					// if after callback sorted, append current callback to last
   287  					sorted = append(sorted, c.name)
   288  				} else if curIdx < sortedIdx {
   289  					return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
   290  				}
   291  			} else if idx := getRIndex(names, c.after); idx != -1 {
   292  				// if after callback exists but haven't sorted
   293  				// set after callback's before callback to current callback
   294  				after := cs[idx]
   295  
   296  				if after.before == "" {
   297  					after.before = c.name
   298  				}
   299  
   300  				if err := sortCallback(after); err != nil {
   301  					return err
   302  				}
   303  
   304  				if err := sortCallback(c); err != nil {
   305  					return err
   306  				}
   307  			}
   308  		}
   309  
   310  		// if current callback haven't been sorted, append it to last
   311  		if getRIndex(sorted, c.name) == -1 {
   312  			sorted = append(sorted, c.name)
   313  		}
   314  
   315  		return nil
   316  	}
   317  
   318  	for _, c := range cs {
   319  		if err = sortCallback(c); err != nil {
   320  			return
   321  		}
   322  	}
   323  
   324  	for _, name := range sorted {
   325  		if idx := getRIndex(names, name); !cs[idx].remove {
   326  			fns = append(fns, cs[idx].handler)
   327  		}
   328  	}
   329  
   330  	return
   331  }