github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/utils.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/lingyao2333/mo-zero/core/logx"
    12  	"github.com/lingyao2333/mo-zero/core/mapping"
    13  )
    14  
    15  var errUnbalancedEscape = errors.New("no char after escape char")
    16  
    17  func desensitize(datasource string) string {
    18  	// remove account
    19  	pos := strings.LastIndex(datasource, "@")
    20  	if 0 <= pos && pos+1 < len(datasource) {
    21  		datasource = datasource[pos+1:]
    22  	}
    23  
    24  	return datasource
    25  }
    26  
    27  func escape(input string) string {
    28  	var b strings.Builder
    29  
    30  	for _, ch := range input {
    31  		switch ch {
    32  		case '\x00':
    33  			b.WriteString(`\x00`)
    34  		case '\r':
    35  			b.WriteString(`\r`)
    36  		case '\n':
    37  			b.WriteString(`\n`)
    38  		case '\\':
    39  			b.WriteString(`\\`)
    40  		case '\'':
    41  			b.WriteString(`\'`)
    42  		case '"':
    43  			b.WriteString(`\"`)
    44  		case '\x1a':
    45  			b.WriteString(`\x1a`)
    46  		default:
    47  			b.WriteRune(ch)
    48  		}
    49  	}
    50  
    51  	return b.String()
    52  }
    53  
    54  func format(query string, args ...interface{}) (string, error) {
    55  	numArgs := len(args)
    56  	if numArgs == 0 {
    57  		return query, nil
    58  	}
    59  
    60  	var b strings.Builder
    61  	var argIndex int
    62  	bytes := len(query)
    63  
    64  	for i := 0; i < bytes; i++ {
    65  		ch := query[i]
    66  		switch ch {
    67  		case '?':
    68  			if argIndex >= numArgs {
    69  				return "", fmt.Errorf("error: %d ? in sql, but less arguments provided", argIndex)
    70  			}
    71  
    72  			writeValue(&b, args[argIndex])
    73  			argIndex++
    74  		case ':', '$':
    75  			var j int
    76  			for j = i + 1; j < bytes; j++ {
    77  				char := query[j]
    78  				if char < '0' || '9' < char {
    79  					break
    80  				}
    81  			}
    82  
    83  			if j > i+1 {
    84  				index, err := strconv.Atoi(query[i+1 : j])
    85  				if err != nil {
    86  					return "", err
    87  				}
    88  
    89  				// index starts from 1 for pg or oracle
    90  				if index > argIndex {
    91  					argIndex = index
    92  				}
    93  
    94  				index--
    95  				if index < 0 || numArgs <= index {
    96  					return "", fmt.Errorf("error: wrong index %d in sql", index)
    97  				}
    98  
    99  				writeValue(&b, args[index])
   100  				i = j - 1
   101  			}
   102  		case '\'', '"', '`':
   103  			b.WriteByte(ch)
   104  
   105  			for j := i + 1; j < bytes; j++ {
   106  				cur := query[j]
   107  				b.WriteByte(cur)
   108  
   109  				if cur == '\\' {
   110  					j++
   111  					if j >= bytes {
   112  						return "", errUnbalancedEscape
   113  					}
   114  
   115  					b.WriteByte(query[j])
   116  				} else if cur == ch {
   117  					i = j
   118  					break
   119  				}
   120  			}
   121  		default:
   122  			b.WriteByte(ch)
   123  		}
   124  	}
   125  
   126  	if argIndex < numArgs {
   127  		return "", fmt.Errorf("error: %d arguments provided, not matching sql", argIndex)
   128  	}
   129  
   130  	return b.String(), nil
   131  }
   132  
   133  func logInstanceError(datasource string, err error) {
   134  	datasource = desensitize(datasource)
   135  	logx.Errorf("Error on getting sql instance of %s: %v", datasource, err)
   136  }
   137  
   138  func logSqlError(ctx context.Context, stmt string, err error) {
   139  	if err != nil && err != ErrNotFound {
   140  		logx.WithContext(ctx).Errorf("stmt: %s, error: %s", stmt, err.Error())
   141  	}
   142  }
   143  
   144  func writeValue(buf *strings.Builder, arg interface{}) {
   145  	switch v := arg.(type) {
   146  	case bool:
   147  		if v {
   148  			buf.WriteByte('1')
   149  		} else {
   150  			buf.WriteByte('0')
   151  		}
   152  	case string:
   153  		buf.WriteByte('\'')
   154  		buf.WriteString(escape(v))
   155  		buf.WriteByte('\'')
   156  	case time.Time:
   157  		buf.WriteByte('\'')
   158  		buf.WriteString(v.String())
   159  		buf.WriteByte('\'')
   160  	case *time.Time:
   161  		buf.WriteByte('\'')
   162  		buf.WriteString(v.String())
   163  		buf.WriteByte('\'')
   164  	default:
   165  		buf.WriteString(mapping.Repr(v))
   166  	}
   167  }