github.com/samiam2013/sqlvet@v0.0.0-20221210043606-d72f678fc0aa/pkg/parseutil/sqlx.go (about)

     1  package parseutil
     2  
     3  import (
     4  	"errors"
     5  	"strconv"
     6  	"unicode"
     7  )
     8  
     9  // copied from https://github.com/jmoiron/sqlx/blob/2ba0fc60eb4a54030f3a6d73ff0a047349c7eeca/bind.go
    10  
    11  // Bindvar types supported by Rebind, BindMap and BindStruct.
    12  const (
    13  	UNKNOWN = iota
    14  	QUESTION
    15  	DOLLAR
    16  	NAMED
    17  	AT
    18  )
    19  
    20  // BindType returns the bindtype for a given database given a drivername.
    21  func BindType(driverName string) int {
    22  	switch driverName {
    23  	case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql":
    24  		return DOLLAR
    25  	case "mysql":
    26  		return QUESTION
    27  	case "sqlite3":
    28  		return QUESTION
    29  	case "oci8", "ora", "goracle":
    30  		return NAMED
    31  	case "sqlserver":
    32  		return AT
    33  	}
    34  	return UNKNOWN
    35  }
    36  
    37  // copied from https://github.com/jmoiron/sqlx/blob/2ba0fc60eb4a54030f3a6d73ff0a047349c7eeca/named.go#L291
    38  
    39  var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
    40  
    41  func CompileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
    42  	names = make([]string, 0, 10)
    43  	rebound := make([]byte, 0, len(qs))
    44  
    45  	inName := false
    46  	last := len(qs) - 1
    47  	currentVar := 1
    48  	name := make([]byte, 0, 10)
    49  
    50  	for i, b := range qs {
    51  		// a ':' while we're in a name is an error
    52  		if b == ':' {
    53  			// if this is the second ':' in a '::' escape sequence, append a ':'
    54  			if inName && i > 0 && qs[i-1] == ':' {
    55  				rebound = append(rebound, ':')
    56  				inName = false
    57  				continue
    58  			} else if inName {
    59  				err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
    60  				return query, names, err
    61  			}
    62  			inName = true
    63  			name = []byte{}
    64  		} else if inName && i > 0 && b == '=' && len(name) == 0 {
    65  			rebound = append(rebound, ':', '=')
    66  			inName = false
    67  			continue
    68  			// if we're in a name, and this is an allowed character, continue
    69  		} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
    70  			// append the byte to the name if we are in a name and not on the last byte
    71  			name = append(name, b)
    72  			// if we're in a name and it's not an allowed character, the name is done
    73  		} else if inName {
    74  			inName = false
    75  			// if this is the final byte of the string and it is part of the name, then
    76  			// make sure to add it to the name
    77  			if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
    78  				name = append(name, b)
    79  			}
    80  			// add the string representation to the names list
    81  			names = append(names, string(name))
    82  			// add a proper bindvar for the bindType
    83  			switch bindType {
    84  			// oracle only supports named type bind vars even for positional
    85  			case NAMED:
    86  				rebound = append(rebound, ':')
    87  				rebound = append(rebound, name...)
    88  			case QUESTION, UNKNOWN:
    89  				rebound = append(rebound, '?')
    90  			case DOLLAR:
    91  				rebound = append(rebound, '$')
    92  				for _, b := range strconv.Itoa(currentVar) {
    93  					rebound = append(rebound, byte(b))
    94  				}
    95  				currentVar++
    96  			case AT:
    97  				rebound = append(rebound, '@', 'p')
    98  				for _, b := range strconv.Itoa(currentVar) {
    99  					rebound = append(rebound, byte(b))
   100  				}
   101  				currentVar++
   102  			}
   103  			// add this byte to string unless it was not part of the name
   104  			if i != last {
   105  				rebound = append(rebound, b)
   106  			} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
   107  				rebound = append(rebound, b)
   108  			}
   109  		} else {
   110  			// this is a normal byte and should just go onto the rebound query
   111  			rebound = append(rebound, b)
   112  		}
   113  	}
   114  	return string(rebound), names, err
   115  }