github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/utils.go (about) 1 package sqlx 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/lingyao2333/mo-zero/core/logx" 12 "github.com/lingyao2333/mo-zero/core/mapping" 13 ) 14 15 var errUnbalancedEscape = errors.New("no char after escape char") 16 17 func desensitize(datasource string) string { 18 // remove account 19 pos := strings.LastIndex(datasource, "@") 20 if 0 <= pos && pos+1 < len(datasource) { 21 datasource = datasource[pos+1:] 22 } 23 24 return datasource 25 } 26 27 func escape(input string) string { 28 var b strings.Builder 29 30 for _, ch := range input { 31 switch ch { 32 case '\x00': 33 b.WriteString(`\x00`) 34 case '\r': 35 b.WriteString(`\r`) 36 case '\n': 37 b.WriteString(`\n`) 38 case '\\': 39 b.WriteString(`\\`) 40 case '\'': 41 b.WriteString(`\'`) 42 case '"': 43 b.WriteString(`\"`) 44 case '\x1a': 45 b.WriteString(`\x1a`) 46 default: 47 b.WriteRune(ch) 48 } 49 } 50 51 return b.String() 52 } 53 54 func format(query string, args ...interface{}) (string, error) { 55 numArgs := len(args) 56 if numArgs == 0 { 57 return query, nil 58 } 59 60 var b strings.Builder 61 var argIndex int 62 bytes := len(query) 63 64 for i := 0; i < bytes; i++ { 65 ch := query[i] 66 switch ch { 67 case '?': 68 if argIndex >= numArgs { 69 return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex) 70 } 71 72 writeValue(&b, args[argIndex]) 73 argIndex++ 74 case ':', '$': 75 var j int 76 for j = i + 1; j < bytes; j++ { 77 char := query[j] 78 if char < '0' || '9' < char { 79 break 80 } 81 } 82 83 if j > i+1 { 84 index, err := strconv.Atoi(query[i+1 : j]) 85 if err != nil { 86 return "", err 87 } 88 89 // index starts from 1 for pg or oracle 90 if index > argIndex { 91 argIndex = index 92 } 93 94 index-- 95 if index < 0 || numArgs <= index { 96 return "", fmt.Errorf("error: wrong index %d in sql", index) 97 } 98 99 writeValue(&b, args[index]) 100 i = j - 1 101 } 102 case '\'', '"', '`': 103 b.WriteByte(ch) 104 105 for j := i + 1; j < bytes; j++ { 106 cur := query[j] 107 b.WriteByte(cur) 108 109 if cur == '\\' { 110 j++ 111 if j >= bytes { 112 return "", errUnbalancedEscape 113 } 114 115 b.WriteByte(query[j]) 116 } else if cur == ch { 117 i = j 118 break 119 } 120 } 121 default: 122 b.WriteByte(ch) 123 } 124 } 125 126 if argIndex < numArgs { 127 return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex) 128 } 129 130 return b.String(), nil 131 } 132 133 func logInstanceError(datasource string, err error) { 134 datasource = desensitize(datasource) 135 logx.Errorf("Error on getting sql instance of %s: %v", datasource, err) 136 } 137 138 func logSqlError(ctx context.Context, stmt string, err error) { 139 if err != nil && err != ErrNotFound { 140 logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error()) 141 } 142 } 143 144 func writeValue(buf *strings.Builder, arg interface{}) { 145 switch v := arg.(type) { 146 case bool: 147 if v { 148 buf.WriteByte('1') 149 } else { 150 buf.WriteByte('0') 151 } 152 case string: 153 buf.WriteByte('\'') 154 buf.WriteString(escape(v)) 155 buf.WriteByte('\'') 156 case time.Time: 157 buf.WriteByte('\'') 158 buf.WriteString(v.String()) 159 buf.WriteByte('\'') 160 case *time.Time: 161 buf.WriteByte('\'') 162 buf.WriteString(v.String()) 163 buf.WriteByte('\'') 164 default: 165 buf.WriteString(mapping.Repr(v)) 166 } 167 }