github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/caches/caches.go (about)

     1  package caches
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/auxten/postgresql-parser/pkg/sql/parser"
     6  	"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
     7  	"github.com/auxten/postgresql-parser/pkg/walk"
     8  	"github.com/samber/lo"
     9  	"github.com/unionj-cloud/go-doudou/v2/toolkit/stringutils"
    10  	"github.com/xwb1989/sqlparser"
    11  	"gorm.io/driver/mysql"
    12  	"gorm.io/driver/postgres"
    13  	"gorm.io/gorm"
    14  	"gorm.io/gorm/callbacks"
    15  	"strings"
    16  	"sync"
    17  )
    18  
    19  type Caches struct {
    20  	Conf  *Config
    21  	queue *sync.Map
    22  }
    23  
    24  type Config struct {
    25  	Easer  bool
    26  	Cacher Cacher
    27  }
    28  
    29  func (c *Caches) Name() string {
    30  	return "gorm:caches"
    31  }
    32  
    33  func (c *Caches) Initialize(db *gorm.DB) error {
    34  	if c.Conf == nil {
    35  		c.Conf = &Config{
    36  			Easer:  false,
    37  			Cacher: nil,
    38  		}
    39  	}
    40  
    41  	if c.Conf.Easer {
    42  		c.queue = &sync.Map{}
    43  	}
    44  
    45  	callback := db.Callback().Query().Get("gorm:query")
    46  
    47  	err := db.Callback().Query().Replace("gorm:query", c.Query(callback))
    48  	if err != nil {
    49  		return err
    50  	}
    51  
    52  	err = db.Callback().Create().After("gorm:after_create").Register("cache:after_create", c.AfterWrite)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	err = db.Callback().Delete().After("gorm:after_delete").Register("cache:after_delete", c.AfterWrite)
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	err = db.Callback().Update().After("gorm:after_update").Register("cache:after_update", c.AfterWrite)
    63  	if err != nil {
    64  		return err
    65  	}
    66  
    67  	err = db.Callback().Raw().After("gorm:raw").Register("cache:after_raw", c.AfterWrite)
    68  	if err != nil {
    69  		return err
    70  	}
    71  
    72  	return nil
    73  }
    74  
    75  func (c *Caches) Query(callback func(*gorm.DB)) func(*gorm.DB) {
    76  	return func(db *gorm.DB) {
    77  		if c.Conf.Easer == false && c.Conf.Cacher == nil {
    78  			callback(db)
    79  			return
    80  		}
    81  
    82  		identifier := buildIdentifier(db)
    83  		if stringutils.ContainsI(identifier, "INSERT INTO") {
    84  			callback(db)
    85  			c.AfterWrite(db)
    86  			return
    87  		}
    88  
    89  		if db.DryRun {
    90  			return
    91  		}
    92  
    93  		if res, ok := c.checkCache(identifier); ok {
    94  			res.replaceOn(db)
    95  			return
    96  		}
    97  
    98  		c.ease(db, identifier, callback)
    99  		if db.Error != nil {
   100  			return
   101  		}
   102  
   103  		c.storeInCache(db, identifier)
   104  		if db.Error != nil {
   105  			return
   106  		}
   107  	}
   108  }
   109  
   110  func (c *Caches) AfterWrite(db *gorm.DB) {
   111  	if c.Conf.Easer == false && c.Conf.Cacher == nil {
   112  		return
   113  	}
   114  
   115  	callbacks.BuildQuerySQL(db)
   116  
   117  	tables := getTables(db)
   118  	if len(tables) == 0 {
   119  		return
   120  	} else if len(tables) == 1 {
   121  		c.deleteCache(db, tables[0])
   122  	} else {
   123  		c.deleteCache(db, tables[0], tables[1:]...)
   124  	}
   125  
   126  	if db.Error != nil {
   127  		return
   128  	}
   129  }
   130  
   131  func (c *Caches) ease(db *gorm.DB, identifier string, callback func(*gorm.DB)) {
   132  	if c.Conf.Easer == false {
   133  		//if true {
   134  		callback(db)
   135  		return
   136  	}
   137  
   138  	res := ease(&queryTask{
   139  		id:      identifier,
   140  		db:      db,
   141  		queryCb: callback,
   142  	}, c.queue).(*queryTask)
   143  
   144  	if db.Error != nil {
   145  		return
   146  	}
   147  
   148  	if res.db.Statement.Dest == db.Statement.Dest {
   149  		return
   150  	}
   151  
   152  	q := Query{
   153  		Dest:         res.db.Statement.Dest,
   154  		RowsAffected: res.db.Statement.RowsAffected,
   155  	}
   156  	q.replaceOn(db)
   157  }
   158  
   159  func (c *Caches) checkCache(identifier string) (res *Query, ok bool) {
   160  	if c.Conf.Cacher != nil {
   161  		if res = c.Conf.Cacher.Get(identifier); res != nil {
   162  			return res, true
   163  		}
   164  	}
   165  	return nil, false
   166  }
   167  
   168  func getTables(db *gorm.DB) []string {
   169  	callbacks.BuildQuerySQL(db)
   170  	switch db.Dialector.(type) {
   171  	case *mysql.Dialector:
   172  		return getTablesMysql(db)
   173  	case *postgres.Dialector:
   174  		return getTablesPostgres(db)
   175  	}
   176  	return nil
   177  }
   178  
   179  func getTablesMysql(db *gorm.DB) []string {
   180  	stmt, err := sqlparser.Parse(db.Statement.SQL.String())
   181  	if err != nil {
   182  		fmt.Println("Error: " + err.Error())
   183  	}
   184  	tableNames := make([]string, 0)
   185  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   186  		switch node := node.(type) {
   187  		case sqlparser.TableName:
   188  			tableNames = append(tableNames, node.Name.CompliantName())
   189  		}
   190  		return true, nil
   191  	}, stmt)
   192  	tableNames = lo.Filter(tableNames, func(x string, index int) bool {
   193  		return stringutils.IsNotEmpty(x)
   194  	})
   195  	tableNames = lo.Uniq(tableNames)
   196  	return tableNames
   197  }
   198  
   199  func getTablesPostgres(db *gorm.DB) []string {
   200  	tableNames := make([]string, 0)
   201  	sql := db.Statement.SQL.String()
   202  	w := &walk.AstWalker{
   203  		Fn: func(ctx interface{}, node interface{}) (stop bool) {
   204  			//log.Printf("%T", node)
   205  			switch expr := node.(type) {
   206  			case *tree.TableName:
   207  				var sb strings.Builder
   208  				fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
   209  				expr.TableNamePrefix.Format(fmtCtx)
   210  				sb.WriteString(fmtCtx.String())
   211  
   212  				if sb.Len() > 0 {
   213  					sb.WriteString(".")
   214  				}
   215  
   216  				fmtCtx = tree.NewFmtCtx(tree.FmtSimple)
   217  				expr.TableName.Format(fmtCtx)
   218  				sb.WriteString(fmtCtx.String())
   219  
   220  				tableNames = append(tableNames, sb.String())
   221  			case *tree.Insert:
   222  				fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
   223  				expr.Table.Format(fmtCtx)
   224  				tableName := fmtCtx.String()
   225  				tableNames = append(tableNames, tableName)
   226  			case *tree.Update:
   227  				fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
   228  				expr.Table.Format(fmtCtx)
   229  				tableName := fmtCtx.String()
   230  				tableNames = append(tableNames, tableName)
   231  			case *tree.Delete:
   232  				fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
   233  				expr.Table.Format(fmtCtx)
   234  				tableName := fmtCtx.String()
   235  				tableNames = append(tableNames, tableName)
   236  			}
   237  			return false
   238  		},
   239  	}
   240  	stmts, err := parser.Parse(sql)
   241  	if err != nil {
   242  		return nil
   243  	}
   244  	_, err = w.Walk(stmts, nil)
   245  	if err != nil {
   246  		return nil
   247  	}
   248  	return tableNames
   249  }
   250  
   251  func (c *Caches) storeInCache(db *gorm.DB, identifier string) {
   252  	if c.Conf.Cacher != nil {
   253  		if _, ok := db.Statement.Dest.(map[string]interface{}); ok {
   254  			fmt.Println(db.Statement.Dest)
   255  		}
   256  		err := c.Conf.Cacher.Store(identifier, &Query{
   257  			Tags:         getTables(db),
   258  			Dest:         db.Statement.Dest,
   259  			RowsAffected: db.Statement.RowsAffected,
   260  		})
   261  		if err != nil {
   262  			_ = db.AddError(err)
   263  		}
   264  	}
   265  }
   266  
   267  func (c *Caches) deleteCache(db *gorm.DB, tag string, tags ...string) {
   268  	if c.Conf.Cacher != nil {
   269  		err := c.Conf.Cacher.Delete(tag, tags...)
   270  		if err != nil {
   271  			_ = db.AddError(err)
   272  		}
   273  	}
   274  }