github.com/sereiner/library@v0.0.0-20200518095232-1fa3e640cc5f/db/tpl/tpl.go (about) 1 package tpl 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/sereiner/library/concurrent/cmap" 8 ) 9 10 const ( 11 cOra = "ora" 12 cOracle = "oracle" 13 cSqlite = "sqlite" 14 ) 15 16 var ( 17 tpls map[string]ITPLContext 18 tplCaches cmap.ConcurrentMap 19 ) 20 21 //ITPLContext 模板上下文 22 type ITPLContext interface { 23 GetSQLContext(tpl string, input map[string]interface{}) (query string, args []interface{}) 24 GetSPContext(tpl string, input map[string]interface{}) (query string, args []interface{}) 25 Replace(sql string, args []interface{}) (r string) 26 } 27 28 func init() { 29 tpls = make(map[string]ITPLContext) 30 tplCaches = cmap.New(8) 31 32 Register("oracle", ATTPLContext{name: "oracle", prefix: ":"}) 33 Register("ora", ATTPLContext{name: "ora", prefix: ":"}) 34 Register("mysql", MTPLContext{name: "mysql", prefix: "?"}) 35 Register("sqlite", MTPLContext{name: "sqlite", prefix: "?"}) 36 Register("postgres", ATTPLContext{name: "postgres", prefix: "$"}) 37 } 38 func Register(name string, tpl ITPLContext) { 39 if _, ok := tpls[name]; ok { 40 panic("重复的注册:" + name) 41 } 42 tpls[name] = tpl 43 } 44 45 //GetDBContext 获取数据库上下文操作 46 func GetDBContext(name string) (ITPLContext, error) { 47 if v, ok := tpls[strings.ToLower(name)]; ok { 48 return v, nil 49 } 50 return nil, fmt.Errorf("不支持的数据库类型:%s", name) 51 }