github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/caches/caches.go (about) 1 package caches 2 3 import ( 4 "fmt" 5 "github.com/auxten/postgresql-parser/pkg/sql/parser" 6 "github.com/auxten/postgresql-parser/pkg/sql/sem/tree" 7 "github.com/auxten/postgresql-parser/pkg/walk" 8 "github.com/samber/lo" 9 "github.com/unionj-cloud/go-doudou/v2/toolkit/stringutils" 10 "github.com/xwb1989/sqlparser" 11 "gorm.io/driver/mysql" 12 "gorm.io/driver/postgres" 13 "gorm.io/gorm" 14 "gorm.io/gorm/callbacks" 15 "strings" 16 "sync" 17 ) 18 19 type Caches struct { 20 Conf *Config 21 queue *sync.Map 22 } 23 24 type Config struct { 25 Easer bool 26 Cacher Cacher 27 } 28 29 func (c *Caches) Name() string { 30 return "gorm:caches" 31 } 32 33 func (c *Caches) Initialize(db *gorm.DB) error { 34 if c.Conf == nil { 35 c.Conf = &Config{ 36 Easer: false, 37 Cacher: nil, 38 } 39 } 40 41 if c.Conf.Easer { 42 c.queue = &sync.Map{} 43 } 44 45 callback := db.Callback().Query().Get("gorm:query") 46 47 err := db.Callback().Query().Replace("gorm:query", c.Query(callback)) 48 if err != nil { 49 return err 50 } 51 52 err = db.Callback().Create().After("gorm:after_create").Register("cache:after_create", c.AfterWrite) 53 if err != nil { 54 return err 55 } 56 57 err = db.Callback().Delete().After("gorm:after_delete").Register("cache:after_delete", c.AfterWrite) 58 if err != nil { 59 return err 60 } 61 62 err = db.Callback().Update().After("gorm:after_update").Register("cache:after_update", c.AfterWrite) 63 if err != nil { 64 return err 65 } 66 67 err = db.Callback().Raw().After("gorm:raw").Register("cache:after_raw", c.AfterWrite) 68 if err != nil { 69 return err 70 } 71 72 return nil 73 } 74 75 func (c *Caches) Query(callback func(*gorm.DB)) func(*gorm.DB) { 76 return func(db *gorm.DB) { 77 if c.Conf.Easer == false && c.Conf.Cacher == nil { 78 callback(db) 79 return 80 } 81 82 identifier := buildIdentifier(db) 83 if stringutils.ContainsI(identifier, "INSERT INTO") { 84 callback(db) 85 c.AfterWrite(db) 86 return 87 } 88 89 if db.DryRun { 90 return 91 } 92 93 if res, ok := c.checkCache(identifier); ok { 94 res.replaceOn(db) 95 return 96 } 97 98 c.ease(db, identifier, callback) 99 if db.Error != nil { 100 return 101 } 102 103 c.storeInCache(db, identifier) 104 if db.Error != nil { 105 return 106 } 107 } 108 } 109 110 func (c *Caches) AfterWrite(db *gorm.DB) { 111 if c.Conf.Easer == false && c.Conf.Cacher == nil { 112 return 113 } 114 115 callbacks.BuildQuerySQL(db) 116 117 tables := getTables(db) 118 if len(tables) == 0 { 119 return 120 } else if len(tables) == 1 { 121 c.deleteCache(db, tables[0]) 122 } else { 123 c.deleteCache(db, tables[0], tables[1:]...) 124 } 125 126 if db.Error != nil { 127 return 128 } 129 } 130 131 func (c *Caches) ease(db *gorm.DB, identifier string, callback func(*gorm.DB)) { 132 if c.Conf.Easer == false { 133 //if true { 134 callback(db) 135 return 136 } 137 138 res := ease(&queryTask{ 139 id: identifier, 140 db: db, 141 queryCb: callback, 142 }, c.queue).(*queryTask) 143 144 if db.Error != nil { 145 return 146 } 147 148 if res.db.Statement.Dest == db.Statement.Dest { 149 return 150 } 151 152 q := Query{ 153 Dest: res.db.Statement.Dest, 154 RowsAffected: res.db.Statement.RowsAffected, 155 } 156 q.replaceOn(db) 157 } 158 159 func (c *Caches) checkCache(identifier string) (res *Query, ok bool) { 160 if c.Conf.Cacher != nil { 161 if res = c.Conf.Cacher.Get(identifier); res != nil { 162 return res, true 163 } 164 } 165 return nil, false 166 } 167 168 func getTables(db *gorm.DB) []string { 169 callbacks.BuildQuerySQL(db) 170 switch db.Dialector.(type) { 171 case *mysql.Dialector: 172 return getTablesMysql(db) 173 case *postgres.Dialector: 174 return getTablesPostgres(db) 175 } 176 return nil 177 } 178 179 func getTablesMysql(db *gorm.DB) []string { 180 stmt, err := sqlparser.Parse(db.Statement.SQL.String()) 181 if err != nil { 182 fmt.Println("Error: " + err.Error()) 183 } 184 tableNames := make([]string, 0) 185 _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { 186 switch node := node.(type) { 187 case sqlparser.TableName: 188 tableNames = append(tableNames, node.Name.CompliantName()) 189 } 190 return true, nil 191 }, stmt) 192 tableNames = lo.Filter(tableNames, func(x string, index int) bool { 193 return stringutils.IsNotEmpty(x) 194 }) 195 tableNames = lo.Uniq(tableNames) 196 return tableNames 197 } 198 199 func getTablesPostgres(db *gorm.DB) []string { 200 tableNames := make([]string, 0) 201 sql := db.Statement.SQL.String() 202 w := &walk.AstWalker{ 203 Fn: func(ctx interface{}, node interface{}) (stop bool) { 204 //log.Printf("%T", node) 205 switch expr := node.(type) { 206 case *tree.TableName: 207 var sb strings.Builder 208 fmtCtx := tree.NewFmtCtx(tree.FmtSimple) 209 expr.TableNamePrefix.Format(fmtCtx) 210 sb.WriteString(fmtCtx.String()) 211 212 if sb.Len() > 0 { 213 sb.WriteString(".") 214 } 215 216 fmtCtx = tree.NewFmtCtx(tree.FmtSimple) 217 expr.TableName.Format(fmtCtx) 218 sb.WriteString(fmtCtx.String()) 219 220 tableNames = append(tableNames, sb.String()) 221 case *tree.Insert: 222 fmtCtx := tree.NewFmtCtx(tree.FmtSimple) 223 expr.Table.Format(fmtCtx) 224 tableName := fmtCtx.String() 225 tableNames = append(tableNames, tableName) 226 case *tree.Update: 227 fmtCtx := tree.NewFmtCtx(tree.FmtSimple) 228 expr.Table.Format(fmtCtx) 229 tableName := fmtCtx.String() 230 tableNames = append(tableNames, tableName) 231 case *tree.Delete: 232 fmtCtx := tree.NewFmtCtx(tree.FmtSimple) 233 expr.Table.Format(fmtCtx) 234 tableName := fmtCtx.String() 235 tableNames = append(tableNames, tableName) 236 } 237 return false 238 }, 239 } 240 stmts, err := parser.Parse(sql) 241 if err != nil { 242 return nil 243 } 244 _, err = w.Walk(stmts, nil) 245 if err != nil { 246 return nil 247 } 248 return tableNames 249 } 250 251 func (c *Caches) storeInCache(db *gorm.DB, identifier string) { 252 if c.Conf.Cacher != nil { 253 if _, ok := db.Statement.Dest.(map[string]interface{}); ok { 254 fmt.Println(db.Statement.Dest) 255 } 256 err := c.Conf.Cacher.Store(identifier, &Query{ 257 Tags: getTables(db), 258 Dest: db.Statement.Dest, 259 RowsAffected: db.Statement.RowsAffected, 260 }) 261 if err != nil { 262 _ = db.AddError(err) 263 } 264 } 265 } 266 267 func (c *Caches) deleteCache(db *gorm.DB, tag string, tags ...string) { 268 if c.Conf.Cacher != nil { 269 err := c.Conf.Cacher.Delete(tag, tags...) 270 if err != nil { 271 _ = db.AddError(err) 272 } 273 } 274 }