// Copyright (C) 2023 Gobalsky Labs Limited
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package sqlstore

import (
	"encoding/hex"
	"fmt"
	"regexp"
	"strconv"
	"strings"
	"time"
)

// SanitizeSql returns sql with every positional placeholder ($1, $2, ...)
// replaced by a SQL literal rendering of the matching element of args
// ($1 refers to args[0]). A placeholder may appear more than once.
//
// Supported argument types are nil, string, []byte, bool, time.Time, all
// signed and unsigned integer types, float32, float64, and
// []int16/[]int32/[]int64 (rendered as quoted array literals such as '{1,2}').
//
// On failure the returned string is empty and the error describes the first
// problem found: a placeholder with no matching argument ($0, or an index
// greater than len(args)), or an argument of an unsupported type.
//
// Placeholders are located with a plain regular expression, not a SQL lexer,
// so a "$N" sequence inside a quoted string, a dollar-quoted body or a comment
// in sql is substituted as well.
func SanitizeSql(sql string, args ...any) (string, error) {
	var firstErr error
	output := literalPattern.ReplaceAllStringFunc(sql, func(placeholder string) string {
		if firstErr != nil {
			// The result is discarded once anything has failed, so skip the
			// remaining work and never overwrite the first error.
			return placeholder
		}
		literal, err := sanitizeArg(placeholder, args)
		if err != nil {
			firstErr = err
			return placeholder
		}
		return literal
	})
	if firstErr != nil {
		return "", firstErr
	}
	return output, nil
}

// literalPattern matches positional placeholders such as $1 or $12.
var literalPattern = regexp.MustCompile(`\$\d+`)

// sanitizeArg renders the argument referenced by placeholder (e.g. "$3") as a
// SQL literal.
func sanitizeArg(placeholder string, args []any) (string, error) {
	// Atoi rejects digit runs too large for an int; n < 1 rejects "$0".
	n, err := strconv.Atoi(placeholder[1:])
	if err != nil || n < 1 || n > len(args) {
		return "", fmt.Errorf("placeholder %s has no matching argument (%d provided)", placeholder, len(args))
	}

	switch arg := args[n-1].(type) {
	case nil:
		return "null", nil
	case string:
		return quoteString(arg), nil
	case []byte:
		// Hex bytea form inside an escape string constant: E'\\x<hex>'.
		return `E'\\x` + hex.EncodeToString(arg) + `'`, nil
	case bool:
		return strconv.FormatBool(arg), nil
	case time.Time:
		return quoteString(arg.Format("2006-01-02 15:04:05.999999 -0700")), nil
	case int:
		return intLiteral(int64(arg)), nil
	case int8:
		return intLiteral(int64(arg)), nil
	case int16:
		return intLiteral(int64(arg)), nil
	case int32:
		return intLiteral(int64(arg)), nil
	case int64:
		return intLiteral(arg), nil
	case uint:
		return strconv.FormatUint(uint64(arg), 10), nil
	case uint8:
		return strconv.FormatUint(uint64(arg), 10), nil
	case uint16:
		return strconv.FormatUint(uint64(arg), 10), nil
	case uint32:
		return strconv.FormatUint(uint64(arg), 10), nil
	case uint64:
		return strconv.FormatUint(arg, 10), nil
	case float32:
		return numericLiteral(strconv.FormatFloat(float64(arg), 'f', -1, 32)), nil
	case float64:
		return numericLiteral(strconv.FormatFloat(arg, 'f', -1, 64)), nil
	case []int16:
		return quotedIntArray(arg)
	case []int32:
		return quotedIntArray(arg)
	case []int64:
		return quotedIntArray(arg)
	default:
		return "", fmt.Errorf("unable to sanitize type: %T", arg)
	}
}

// intLiteral formats a signed integer for inlining into SQL.
func intLiteral(v int64) string {
	return numericLiteral(strconv.FormatInt(v, 10))
}

// numericLiteral guards a formatted number against SQL line-comment
// injection. Substituting a negative value directly after a minus sign
// ("-$1" with -5) would otherwise yield "--5", which PostgreSQL parses as the
// start of a comment that swallows the rest of the line. The leading space
// turns it into "- -5" instead.
func numericLiteral(s string) string {
	if strings.HasPrefix(s, "-") {
		return " " + s
	}
	return s
}

// quotedIntArray renders nums as a quoted array literal such as '{1,2,3}'.
func quotedIntArray[T any](nums []T) (string, error) {
	s, err := intSliceToArrayString(nums)
	if err != nil {
		return "", err
	}
	return quoteString(s), nil
}

// quoteString wraps input in single quotes, doubling any embedded single
// quotes. Backslashes are passed through unchanged, which is only safe when
// the server's standard_conforming_strings setting is on.
func quoteString(input string) string {
	return "'" + strings.ReplaceAll(input, "'", "''") + "'"
}

// intSliceToArrayString renders nums as an unquoted array body such as
// "{1,2,3}". Only int16, int32 and int64 elements are accepted; any other
// element type yields an error.
func intSliceToArrayString[T any](nums []T) (string, error) {
	w := strings.Builder{}
	w.WriteString("{")
	for i, n := range nums {
		if i > 0 {
			w.WriteString(",")
		}
		var intx int64
		switch n := any(n).(type) {
		case int16:
			intx = int64(n)
		case int32:
			intx = int64(n)
		case int64:
			intx = n
		default:
			return "", fmt.Errorf("unexpected type %T", n)
		}
		w.WriteString(strconv.FormatInt(intx, 10))
	}
	w.WriteString("}")
	return w.String(), nil
}