modernc.org/ql@v1.4.7/driver1.8.go (about)

     1  // +build go1.8
     2  
     3  package ql // import "modernc.org/ql"
     4  
     5  import (
     6  	"context"
     7  	"database/sql"
     8  	"database/sql/driver"
     9  	"errors"
    10  	"fmt"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  const prefix = "$"
    16  
    17  var (
    18  	_ driver.ExecerContext      = (*driverConn)(nil)
    19  	_ driver.QueryerContext     = (*driverConn)(nil)
    20  	_ driver.ConnBeginTx        = (*driverConn)(nil)
    21  	_ driver.ConnPrepareContext = (*driverConn)(nil)
    22  )
    23  
    24  // BeginTx implements driver.ConnBeginTx.
    25  func (c *driverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    26  	// Check the transaction level. If the transaction level is non-default
    27  	// then return an error here as the BeginTx driver value is not supported.
    28  	if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
    29  		return nil, errors.New("ql: driver does not support non-default isolation level")
    30  	}
    31  
    32  	// If a read-only transaction is requested return an error as the
    33  	// BeginTx driver value is not supported.
    34  	if opts.ReadOnly {
    35  		return nil, errors.New("ql: driver does not support read-only transactions")
    36  	}
    37  
    38  	if c.ctx == nil {
    39  		c.ctx = NewRWCtx()
    40  	}
    41  
    42  	if _, _, err := c.db.db.Execute(c.ctx, txBegin); err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	c.tnl++
    47  	return c, nil
    48  }
    49  
    50  func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
    51  	query, vals, err := replaceNamed(query, args)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	return c.Exec(query, vals)
    57  }
    58  
    59  func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value, error) {
    60  	toks, err := tokenize(query)
    61  	if err != nil {
    62  		return "", nil, err
    63  	}
    64  
    65  	a := make([]driver.Value, len(args))
    66  	m := map[string]int{}
    67  	for _, v := range args {
    68  		m[v.Name] = v.Ordinal
    69  		a[v.Ordinal-1] = v.Value
    70  	}
    71  	for i, v := range toks {
    72  		if len(v) > 1 && strings.HasPrefix(v, prefix) {
    73  			if v[1] >= '1' && v[1] <= '9' {
    74  				continue
    75  			}
    76  
    77  			nm := v[1:]
    78  			k, ok := m[nm]
    79  			if !ok {
    80  				return query, nil, fmt.Errorf("unknown named parameter %s", nm)
    81  			}
    82  
    83  			toks[i] = fmt.Sprintf("$%d", k)
    84  		}
    85  	}
    86  	return strings.Join(toks, " "), a, nil
    87  }
    88  
    89  func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
    90  	query, vals, err := replaceNamed(query, args)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return c.Query(query, vals)
    96  }
    97  
    98  func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
    99  	query, err := filterNamedArgs(query)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	return c.Prepare(query)
   105  }
   106  
   107  func filterNamedArgs(query string) (string, error) {
   108  	toks, err := tokenize(query)
   109  	if err != nil {
   110  		return "", err
   111  	}
   112  
   113  	n := 0
   114  	for _, v := range toks {
   115  		if len(v) > 1 && strings.HasPrefix(v, prefix) && v[1] >= '1' && v[1] <= '9' {
   116  			m, err := strconv.ParseUint(v[1:], 10, 31)
   117  			if err != nil {
   118  				return "", err
   119  			}
   120  
   121  			if int(m) > n {
   122  				n = int(m)
   123  			}
   124  		}
   125  	}
   126  	for i, v := range toks {
   127  		if len(v) > 1 && strings.HasPrefix(v, prefix) {
   128  			if v[1] >= '1' && v[1] <= '9' {
   129  				continue
   130  			}
   131  
   132  			n++
   133  			toks[i] = fmt.Sprintf("$%d", n)
   134  		}
   135  	}
   136  	return strings.Join(toks, " "), nil
   137  }
   138  
   139  func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   140  	a := make([]driver.Value, len(args))
   141  	for k, v := range args {
   142  		a[k] = v.Value
   143  	}
   144  	return s.Exec(a)
   145  }
   146  
   147  func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   148  	a := make([]driver.Value, len(args))
   149  	for k, v := range args {
   150  		a[k] = v.Value
   151  	}
   152  	return s.Query(a)
   153  }
   154  
   155  func tokenize(s string) (r []string, _ error) {
   156  	lx, err := newLexer(s)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	var lval yySymType
   162  	for lx.Lex(&lval) != 0 {
   163  		s := string(lx.TokenBytes(nil))
   164  		if s != "" {
   165  			switch s[len(s)-1] {
   166  			case '"':
   167  				s = "\"" + s
   168  			case '`':
   169  				s = "`" + s
   170  			}
   171  		}
   172  		r = append(r, s)
   173  	}
   174  	return r, nil
   175  }