github.com/sereiner/library@v0.0.0-20200518095232-1fa3e640cc5f/db/tpl/tpl.go (about)

     1  package tpl
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/sereiner/library/concurrent/cmap"
     8  )
     9  
    10  const (
    11  	cOra    = "ora"
    12  	cOracle = "oracle"
    13  	cSqlite = "sqlite"
    14  )
    15  
    16  var (
    17  	tpls      map[string]ITPLContext
    18  	tplCaches cmap.ConcurrentMap
    19  )
    20  
    21  //ITPLContext 模板上下文
    22  type ITPLContext interface {
    23  	GetSQLContext(tpl string, input map[string]interface{}) (query string, args []interface{})
    24  	GetSPContext(tpl string, input map[string]interface{}) (query string, args []interface{})
    25  	Replace(sql string, args []interface{}) (r string)
    26  }
    27  
    28  func init() {
    29  	tpls = make(map[string]ITPLContext)
    30  	tplCaches = cmap.New(8)
    31  
    32  	Register("oracle", ATTPLContext{name: "oracle", prefix: ":"})
    33  	Register("ora", ATTPLContext{name: "ora", prefix: ":"})
    34  	Register("mysql", MTPLContext{name: "mysql", prefix: "?"})
    35  	Register("sqlite", MTPLContext{name: "sqlite", prefix: "?"})
    36  	Register("postgres", ATTPLContext{name: "postgres", prefix: "$"})
    37  }
    38  func Register(name string, tpl ITPLContext) {
    39  	if _, ok := tpls[name]; ok {
    40  		panic("重复的注册:" + name)
    41  	}
    42  	tpls[name] = tpl
    43  }
    44  
    45  //GetDBContext 获取数据库上下文操作
    46  func GetDBContext(name string) (ITPLContext, error) {
    47  	if v, ok := tpls[strings.ToLower(name)]; ok {
    48  		return v, nil
    49  	}
    50  	return nil, fmt.Errorf("不支持的数据库类型:%s", name)
    51  }