github.com/systematiccaos/gorm@v1.22.6/logger/sql.go (about)

     1  package logger
     2  
     3  import (
     4  	"database/sql/driver"
     5  	"fmt"
     6  	"reflect"
     7  	"regexp"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  	"unicode"
    12  
    13  	"github.com/systematiccaos/gorm/utils"
    14  )
    15  
    16  const (
    17  	tmFmtWithMS = "2006-01-02 15:04:05.999"
    18  	tmFmtZero   = "0000-00-00 00:00:00"
    19  	nullStr     = "NULL"
    20  )
    21  
    22  func isPrintable(s []byte) bool {
    23  	for _, r := range s {
    24  		if !unicode.IsPrint(rune(r)) {
    25  			return false
    26  		}
    27  	}
    28  	return true
    29  }
    30  
    31  var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
    32  
    33  func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
    34  	var convertParams func(interface{}, int)
    35  	var vars = make([]string, len(avars))
    36  
    37  	convertParams = func(v interface{}, idx int) {
    38  		switch v := v.(type) {
    39  		case bool:
    40  			vars[idx] = strconv.FormatBool(v)
    41  		case time.Time:
    42  			if v.IsZero() {
    43  				vars[idx] = escaper + tmFmtZero + escaper
    44  			} else {
    45  				vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
    46  			}
    47  		case *time.Time:
    48  			if v != nil {
    49  				if v.IsZero() {
    50  					vars[idx] = escaper + tmFmtZero + escaper
    51  				} else {
    52  					vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
    53  				}
    54  			} else {
    55  				vars[idx] = nullStr
    56  			}
    57  		case driver.Valuer:
    58  			reflectValue := reflect.ValueOf(v)
    59  			if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
    60  				r, _ := v.Value()
    61  				convertParams(r, idx)
    62  			} else {
    63  				vars[idx] = nullStr
    64  			}
    65  		case fmt.Stringer:
    66  			reflectValue := reflect.ValueOf(v)
    67  			if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
    68  				vars[idx] = escaper + strings.Replace(fmt.Sprintf("%v", v), escaper, "\\"+escaper, -1) + escaper
    69  			} else {
    70  				vars[idx] = nullStr
    71  			}
    72  		case []byte:
    73  			if isPrintable(v) {
    74  				vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
    75  			} else {
    76  				vars[idx] = escaper + "<binary>" + escaper
    77  			}
    78  		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
    79  			vars[idx] = utils.ToString(v)
    80  		case float64, float32:
    81  			vars[idx] = fmt.Sprintf("%.6f", v)
    82  		case string:
    83  			vars[idx] = escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
    84  		default:
    85  			rv := reflect.ValueOf(v)
    86  			if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
    87  				vars[idx] = nullStr
    88  			} else if valuer, ok := v.(driver.Valuer); ok {
    89  				v, _ = valuer.Value()
    90  				convertParams(v, idx)
    91  			} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
    92  				convertParams(reflect.Indirect(rv).Interface(), idx)
    93  			} else {
    94  				for _, t := range convertibleTypes {
    95  					if rv.Type().ConvertibleTo(t) {
    96  						convertParams(rv.Convert(t).Interface(), idx)
    97  						return
    98  					}
    99  				}
   100  				vars[idx] = escaper + strings.Replace(fmt.Sprint(v), escaper, "\\"+escaper, -1) + escaper
   101  			}
   102  		}
   103  	}
   104  
   105  	for idx, v := range avars {
   106  		convertParams(v, idx)
   107  	}
   108  
   109  	if numericPlaceholder == nil {
   110  		var idx int
   111  		var newSQL strings.Builder
   112  
   113  		for _, v := range []byte(sql) {
   114  			if v == '?' {
   115  				if len(vars) > idx {
   116  					newSQL.WriteString(vars[idx])
   117  					idx++
   118  					continue
   119  				}
   120  			}
   121  			newSQL.WriteByte(v)
   122  		}
   123  
   124  		sql = newSQL.String()
   125  	} else {
   126  		sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
   127  		for idx, v := range vars {
   128  			sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1)
   129  		}
   130  	}
   131  
   132  	return sql
   133  }