github.com/GnawNom/sqlinternals@v0.0.0-20200413232442-a2dcc5655e0f/sqlinternals_test.go (about) 1 // sqlinternals - retrieve driver.Rows from sql.*Row / sql.*Rows 2 // 3 // Copyright 2013 Arne Hormann. All rights reserved. 4 // 5 // This Source Code Form is subject to the terms of the Mozilla Public 6 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 7 // You can obtain one at http://mozilla.org/MPL/2.0/. 8 9 package sqlinternals 10 11 import ( 12 "database/sql" 13 "database/sql/driver" 14 "io" 15 "testing" 16 ) 17 18 type omnithing struct { 19 numInputs int 20 columns []string 21 rows [][]interface{} 22 } 23 24 func (t *omnithing) Close() error { return nil } 25 26 // driver.Driver 27 func (t *omnithing) Open(name string) (driver.Conn, error) { return t, nil } 28 29 // driver.Conn 30 func (t *omnithing) Prepare(query string) (driver.Stmt, error) { return t, nil } 31 func (t *omnithing) Begin() (driver.Tx, error) { return t, nil } 32 33 // driver.Tx 34 func (t *omnithing) Commit() error { return nil } 35 func (t *omnithing) Rollback() error { return nil } 36 37 // driver.Stmt 38 func (t *omnithing) NumInput() int { return t.numInputs } 39 func (t *omnithing) Exec(args []driver.Value) (driver.Result, error) { return t, nil } 40 func (t *omnithing) Query(args []driver.Value) (driver.Rows, error) { return t, nil } 41 42 // driver.Result 43 func (t *omnithing) LastInsertId() (int64, error) { return 0, nil } 44 func (t *omnithing) RowsAffected() (int64, error) { return 0, nil } 45 46 // driver.Rows 47 func (t *omnithing) Columns() []string { return t.columns } 48 func (t *omnithing) Next(dest []driver.Value) error { 49 if len(t.rows) == 0 { 50 return io.EOF 51 } 52 var row []interface{} 53 row, t.rows = t.rows[0], t.rows[1:] 54 for i, v := range row { 55 dest[i] = v 56 } 57 return nil 58 } 59 60 func (o *omnithing) setDB(numInputs int, columns []string, cells ...interface{}) *omnithing { 61 o.numInputs = numInputs 62 o.columns = columns 63 numCols, numCells := len(columns), len(cells) 64 numRows := numCells / numCols 65 if numCols*numRows != numCells { 66 panic("wrong number of cells") 67 } 68 rows := [][]interface{}{} 69 for r := 0; r < numRows; r++ { 70 cols := []interface{}{} 71 for c := 0; c < numCols; c++ { 72 cols = append(cols, cells[r*numCols+c]) 73 } 74 rows = append(rows, cols) 75 } 76 o.rows = rows 77 return o 78 } 79 80 type querier func(conn *sql.DB) (interface{}, error) 81 82 var ( 83 testdriver = &omnithing{} 84 // make sure the test type implements the interfaces 85 _ driver.Driver = testdriver 86 _ driver.Conn = testdriver 87 _ driver.Tx = testdriver 88 _ driver.Stmt = testdriver 89 _ driver.Result = testdriver 90 _ driver.Rows = testdriver 91 ) 92 93 const driverType = "test" 94 95 func init() { 96 sql.Register(driverType, testdriver) 97 } 98 99 func runRowsTest(t *testing.T, query querier, numInputs int, columns []string, cells ...interface{}) { 100 // set intial state before usage 101 testdriver.setDB(numInputs, columns, cells...) 102 // run a query, retrieve *sql.Rows 103 conn, err := sql.Open(driverType, "") 104 defer conn.Close() 105 rowOrRows, err := query(conn) 106 if closer, ok := rowOrRows.(io.Closer); ok { 107 defer closer.Close() 108 } 109 // check that it is accessible and matches the one in testdriver.rows 110 unwrapped, err := Inspect(rowOrRows) 111 if err != nil { 112 t.Error(err) 113 return 114 } 115 myrows, ok := unwrapped.(*omnithing) 116 if !ok || myrows != testdriver { 117 t.Errorf("returned driver.Rows must match those passed in.") 118 } 119 } 120 121 func TestRowWithoutArgs(t *testing.T) { 122 query := func(conn *sql.DB) (interface{}, error) { 123 return conn.QueryRow(`SELECT "test"`), nil 124 } 125 runRowsTest(t, query, 0, []string{"header"}, "test") 126 } 127 128 func TestRowWithArgs(t *testing.T) { 129 query := func(conn *sql.DB) (interface{}, error) { 130 return conn.QueryRow(`SELECT ?`, "test"), nil 131 } 132 runRowsTest(t, query, 1, []string{"header"}, "test") 133 } 134 135 func TestRowsWithoutArgs(t *testing.T) { 136 query := func(conn *sql.DB) (interface{}, error) { 137 return conn.Query(`SELECT "test"`) 138 } 139 runRowsTest(t, query, 0, []string{"header"}, "test") 140 } 141 142 func TestRowsWithArgs(t *testing.T) { 143 query := func(conn *sql.DB) (interface{}, error) { 144 return conn.Query(`SELECT ?`, "test") 145 } 146 runRowsTest(t, query, 1, []string{"header"}, "test") 147 }