code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/sanitize.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"encoding/hex"
    20  	"fmt"
    21  	"regexp"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    25  )
    26  
    27  // nolint:nakedret
    28  func SanitizeSql(sql string, args ...any) (output string, err error) {
    29  	replacer := func(match string) (replacement string) {
    30  		n, _ := strconv.ParseInt(match[1:], 10, 0)
    31  		switch arg := args[n-1].(type) {
    32  		case string:
    33  			return quoteString(arg)
    34  		case int:
    35  			return strconv.FormatInt(int64(arg), 10)
    36  		case int8:
    37  			return strconv.FormatInt(int64(arg), 10)
    38  		case int16:
    39  			return strconv.FormatInt(int64(arg), 10)
    40  		case int32:
    41  			return strconv.FormatInt(int64(arg), 10)
    42  		case int64:
    43  			return strconv.FormatInt(arg, 10)
    44  		case time.Time:
    45  			return quoteString(arg.Format("2006-01-02 15:04:05.999999 -0700"))
    46  		case uint:
    47  			return strconv.FormatUint(uint64(arg), 10)
    48  		case uint8:
    49  			return strconv.FormatUint(uint64(arg), 10)
    50  		case uint16:
    51  			return strconv.FormatUint(uint64(arg), 10)
    52  		case uint32:
    53  			return strconv.FormatUint(uint64(arg), 10)
    54  		case uint64:
    55  			return strconv.FormatUint(arg, 10)
    56  		case float32:
    57  			return strconv.FormatFloat(float64(arg), 'f', -1, 32)
    58  		case float64:
    59  			return strconv.FormatFloat(arg, 'f', -1, 64)
    60  		case bool:
    61  			return strconv.FormatBool(arg)
    62  		case []byte:
    63  			return `E'\\x` + hex.EncodeToString(arg) + `'`
    64  		case []int16:
    65  			var s string
    66  			s, err = intSliceToArrayString(arg)
    67  			return quoteString(s)
    68  		case []int32:
    69  			var s string
    70  			s, err = intSliceToArrayString(arg)
    71  			return quoteString(s)
    72  		case []int64:
    73  			var s string
    74  			s, err = intSliceToArrayString(arg)
    75  			return quoteString(s)
    76  		case nil:
    77  			return "null"
    78  		default:
    79  			err = fmt.Errorf("unable to sanitize type: %T", arg)
    80  			return ""
    81  		}
    82  	}
    83  
    84  	output = literalPattern.ReplaceAllStringFunc(sql, replacer)
    85  	return
    86  }
    87  
    88  var literalPattern = regexp.MustCompile(`\$\d+`)
    89  
    90  func quoteString(input string) (output string) {
    91  	output = "'" + strings.Replace(input, "'", "''", -1) + "'"
    92  	return
    93  }
    94  
    95  func intSliceToArrayString[T any](nums []T) (string, error) {
    96  	w := strings.Builder{}
    97  	w.WriteString("{")
    98  	for i, n := range nums {
    99  		if i > 0 {
   100  			w.WriteString(",")
   101  		}
   102  		var intx int64
   103  		switch n := any(n).(type) {
   104  		case int16:
   105  			intx = int64(n)
   106  		case int32:
   107  			intx = int64(n)
   108  		case int64:
   109  			intx = n
   110  		default:
   111  			return "", fmt.Errorf("unexpected type %T", n)
   112  		}
   113  		w.WriteString(strconv.FormatInt(intx, 10))
   114  	}
   115  	w.WriteString("}")
   116  	return w.String(), nil
   117  }