github.com/RevenueMonster/sqlike@v1.0.6/plugin/opentracing/helper.go (about)

     1  package opentracing
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"regexp"
    10  	"strconv"
    11  	"time"
    12  
    13  	"github.com/RevenueMonster/sqlike/util"
    14  	"github.com/opentracing/opentracing-go"
    15  	"github.com/opentracing/opentracing-go/ext"
    16  	"github.com/opentracing/opentracing-go/log"
    17  )
    18  
    19  var r = regexp.MustCompile(`(\$\d|\?|\:\w+)`)
    20  
    21  func (ot *OpenTracingInterceptor) logQuery(span opentracing.Span, query string) {
    22  	ot.logQueryArgs(span, query, nil)
    23  }
    24  
    25  func (ot *OpenTracingInterceptor) logQueryArgs(span opentracing.Span, query string, args []driver.NamedValue) {
    26  	if span == nil {
    27  		return
    28  	}
    29  
    30  	if !ot.opts.Args || len(args) == 0 {
    31  		span.LogFields(
    32  			log.String(string(ext.DBStatement), query),
    33  		)
    34  		return
    35  	}
    36  
    37  	blr := util.AcquireString()
    38  	defer util.ReleaseString(blr)
    39  
    40  	mapQuery(query, blr, args)
    41  
    42  	span.LogFields(
    43  		log.String(string(ext.DBStatement), blr.String()),
    44  	)
    45  }
    46  
    47  func mapQuery(query string, w io.StringWriter, args []driver.NamedValue) {
    48  	var (
    49  		i      int
    50  		paths  []int
    51  		value  string
    52  		length = len(args)
    53  	)
    54  
    55  	for {
    56  		paths = r.FindStringIndex(query)
    57  		if len(paths) < 2 {
    58  			w.WriteString(query)
    59  			break
    60  		}
    61  
    62  		w.WriteString(query[:paths[0]])
    63  
    64  		// by default, query string won't be have invalid arguments
    65  		// TODO: if it's :name argument, we should store the value in map
    66  		switch v := args[i].Value.(type) {
    67  		case string:
    68  			value = strconv.Quote(v)
    69  		case int64:
    70  			value = strconv.FormatInt(v, 10)
    71  		case uint64:
    72  			value = strconv.FormatUint(v, 10)
    73  		case float64:
    74  			value = strconv.FormatFloat(v, 'e', -1, 64)
    75  		case bool:
    76  			value = strconv.FormatBool(v)
    77  		case time.Time:
    78  			value = `"` + v.Format(time.RFC3339) + `"`
    79  		case []byte:
    80  			value = strconv.Quote(util.UnsafeString(v))
    81  		case json.RawMessage:
    82  			value = strconv.Quote(util.UnsafeString(v))
    83  		case sql.RawBytes:
    84  			value = string(v)
    85  		case fmt.Stringer:
    86  			value = strconv.Quote(v.String())
    87  		case nil:
    88  			value = "NULL"
    89  		default:
    90  			value = strconv.Quote(fmt.Sprintf("%v", v))
    91  		}
    92  
    93  		w.WriteString(value)
    94  		query = query[paths[1]:]
    95  		i++
    96  
    97  		if i >= length {
    98  			w.WriteString(query)
    99  			break
   100  		}
   101  	}
   102  }
   103  
   104  func (ot *OpenTracingInterceptor) logError(span opentracing.Span, err error) {
   105  	if err != nil && err != driver.ErrSkip {
   106  		// we didn't want to log driver.ErrSkip, because the native sql package will handle
   107  		ext.LogError(span, err)
   108  	}
   109  }