modernc.org/ql@v1.4.7/driver1.8.go (about) 1 // +build go1.8 2 3 package ql // import "modernc.org/ql" 4 5 import ( 6 "context" 7 "database/sql" 8 "database/sql/driver" 9 "errors" 10 "fmt" 11 "strconv" 12 "strings" 13 ) 14 15 const prefix = "$" 16 17 var ( 18 _ driver.ExecerContext = (*driverConn)(nil) 19 _ driver.QueryerContext = (*driverConn)(nil) 20 _ driver.ConnBeginTx = (*driverConn)(nil) 21 _ driver.ConnPrepareContext = (*driverConn)(nil) 22 ) 23 24 // BeginTx implements driver.ConnBeginTx. 25 func (c *driverConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 26 // Check the transaction level. If the transaction level is non-default 27 // then return an error here as the BeginTx driver value is not supported. 28 if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { 29 return nil, errors.New("ql: driver does not support non-default isolation level") 30 } 31 32 // If a read-only transaction is requested return an error as the 33 // BeginTx driver value is not supported. 34 if opts.ReadOnly { 35 return nil, errors.New("ql: driver does not support read-only transactions") 36 } 37 38 if c.ctx == nil { 39 c.ctx = NewRWCtx() 40 } 41 42 if _, _, err := c.db.db.Execute(c.ctx, txBegin); err != nil { 43 return nil, err 44 } 45 46 c.tnl++ 47 return c, nil 48 } 49 50 func (c *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 51 query, vals, err := replaceNamed(query, args) 52 if err != nil { 53 return nil, err 54 } 55 56 return c.Exec(query, vals) 57 } 58 59 func replaceNamed(query string, args []driver.NamedValue) (string, []driver.Value, error) { 60 toks, err := tokenize(query) 61 if err != nil { 62 return "", nil, err 63 } 64 65 a := make([]driver.Value, len(args)) 66 m := map[string]int{} 67 for _, v := range args { 68 m[v.Name] = v.Ordinal 69 a[v.Ordinal-1] = v.Value 70 } 71 for i, v := range toks { 72 if len(v) > 1 && strings.HasPrefix(v, prefix) { 73 if v[1] >= '1' && v[1] <= '9' { 74 continue 75 } 76 77 nm := v[1:] 78 k, ok := m[nm] 79 if !ok { 80 return query, nil, fmt.Errorf("unknown named parameter %s", nm) 81 } 82 83 toks[i] = fmt.Sprintf("$%d", k) 84 } 85 } 86 return strings.Join(toks, " "), a, nil 87 } 88 89 func (c *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 90 query, vals, err := replaceNamed(query, args) 91 if err != nil { 92 return nil, err 93 } 94 95 return c.Query(query, vals) 96 } 97 98 func (c *driverConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 99 query, err := filterNamedArgs(query) 100 if err != nil { 101 return nil, err 102 } 103 104 return c.Prepare(query) 105 } 106 107 func filterNamedArgs(query string) (string, error) { 108 toks, err := tokenize(query) 109 if err != nil { 110 return "", err 111 } 112 113 n := 0 114 for _, v := range toks { 115 if len(v) > 1 && strings.HasPrefix(v, prefix) && v[1] >= '1' && v[1] <= '9' { 116 m, err := strconv.ParseUint(v[1:], 10, 31) 117 if err != nil { 118 return "", err 119 } 120 121 if int(m) > n { 122 n = int(m) 123 } 124 } 125 } 126 for i, v := range toks { 127 if len(v) > 1 && strings.HasPrefix(v, prefix) { 128 if v[1] >= '1' && v[1] <= '9' { 129 continue 130 } 131 132 n++ 133 toks[i] = fmt.Sprintf("$%d", n) 134 } 135 } 136 return strings.Join(toks, " "), nil 137 } 138 139 func (s *driverStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 140 a := make([]driver.Value, len(args)) 141 for k, v := range args { 142 a[k] = v.Value 143 } 144 return s.Exec(a) 145 } 146 147 func (s *driverStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 148 a := make([]driver.Value, len(args)) 149 for k, v := range args { 150 a[k] = v.Value 151 } 152 return s.Query(a) 153 } 154 155 func tokenize(s string) (r []string, _ error) { 156 lx, err := newLexer(s) 157 if err != nil { 158 return nil, err 159 } 160 161 var lval yySymType 162 for lx.Lex(&lval) != 0 { 163 s := string(lx.TokenBytes(nil)) 164 if s != "" { 165 switch s[len(s)-1] { 166 case '"': 167 s = "\"" + s 168 case '`': 169 s = "`" + s 170 } 171 } 172 r = append(r, s) 173 } 174 return r, nil 175 }