github.com/qxnw/lib4go@v0.0.0-20180426074627-c80c7e84b925/db/tpl/tpl.at.go (about)

     1  package tpl
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strconv"
     7  	"strings"
     8  )
     9  
    10  //ATTPLContext 参数化时使用@+参数名作为占位符的SQL数据库如:oracle,sql server
    11  type ATTPLContext struct {
    12  	name string
    13  }
    14  
    15  func (o ATTPLContext) getSPName(query string) string {
    16  	return fmt.Sprintf("begin %s;end;", strings.Trim(strings.Trim(query, ";"), ","))
    17  }
    18  
    19  //GetSQLContext 获取查询串
    20  func (o ATTPLContext) GetSQLContext(tpl string, input map[string]interface{}) (sql string, args []interface{}) {
    21  	index := 0
    22  	f := func() string {
    23  		index++
    24  		return fmt.Sprint(":", index)
    25  	}
    26  	return AnalyzeTPLFromCache(o.name, tpl, input, f)
    27  }
    28  
    29  //GetSPContext 获取
    30  func (o ATTPLContext) GetSPContext(tpl string, input map[string]interface{}) (sql string, args []interface{}) {
    31  	q, args := o.GetSQLContext(tpl, input)
    32  	sql = o.getSPName(q)
    33  	return
    34  }
    35  
    36  //Replace 替换SQL中的占位符
    37  func (o ATTPLContext) Replace(sql string, args []interface{}) (r string) {
    38  	if strings.EqualFold(sql, "") || args == nil {
    39  		return sql
    40  	}
    41  	word, _ := regexp.Compile(`:\d+([,|\) ;]|$)`)
    42  	sql = word.ReplaceAllStringFunc(sql, func(s string) string {
    43  		c := len(s)
    44  		num := s[1 : c-1]
    45  		// 处理匹配到结尾
    46  		if num == "" {
    47  			num = s[1:c]
    48  			c++
    49  		}
    50  		k, err := strconv.Atoi(num)
    51  		if err != nil || len(args) < k {
    52  			return "NULL" + s[c-1:]
    53  		}
    54  		return fmt.Sprintf("'%v'%s", args[k-1], s[c-1:])
    55  	})
    56  	/*end*/
    57  	return sql
    58  }
    59  func init() {
    60  	// Register("oracle", ATTPLContext{name: "oracle"})
    61  	// Register("mysql", ATTPLContext{name: "mysql"})
    62  }