github.com/GnawNom/sqlinternals@v0.0.0-20200413232442-a2dcc5655e0f/sqlinternals_test.go (about)

     1  // sqlinternals - retrieve driver.Rows from sql.*Row / sql.*Rows
     2  //
     3  // Copyright 2013 Arne Hormann. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package sqlinternals
    10  
    11  import (
    12  	"database/sql"
    13  	"database/sql/driver"
    14  	"io"
    15  	"testing"
    16  )
    17  
    18  type omnithing struct {
    19  	numInputs int
    20  	columns   []string
    21  	rows      [][]interface{}
    22  }
    23  
    24  func (t *omnithing) Close() error { return nil }
    25  
    26  // driver.Driver
    27  func (t *omnithing) Open(name string) (driver.Conn, error) { return t, nil }
    28  
    29  // driver.Conn
    30  func (t *omnithing) Prepare(query string) (driver.Stmt, error) { return t, nil }
    31  func (t *omnithing) Begin() (driver.Tx, error)                 { return t, nil }
    32  
    33  // driver.Tx
    34  func (t *omnithing) Commit() error   { return nil }
    35  func (t *omnithing) Rollback() error { return nil }
    36  
    37  // driver.Stmt
    38  func (t *omnithing) NumInput() int                                   { return t.numInputs }
    39  func (t *omnithing) Exec(args []driver.Value) (driver.Result, error) { return t, nil }
    40  func (t *omnithing) Query(args []driver.Value) (driver.Rows, error)  { return t, nil }
    41  
    42  // driver.Result
    43  func (t *omnithing) LastInsertId() (int64, error) { return 0, nil }
    44  func (t *omnithing) RowsAffected() (int64, error) { return 0, nil }
    45  
    46  // driver.Rows
    47  func (t *omnithing) Columns() []string { return t.columns }
    48  func (t *omnithing) Next(dest []driver.Value) error {
    49  	if len(t.rows) == 0 {
    50  		return io.EOF
    51  	}
    52  	var row []interface{}
    53  	row, t.rows = t.rows[0], t.rows[1:]
    54  	for i, v := range row {
    55  		dest[i] = v
    56  	}
    57  	return nil
    58  }
    59  
    60  func (o *omnithing) setDB(numInputs int, columns []string, cells ...interface{}) *omnithing {
    61  	o.numInputs = numInputs
    62  	o.columns = columns
    63  	numCols, numCells := len(columns), len(cells)
    64  	numRows := numCells / numCols
    65  	if numCols*numRows != numCells {
    66  		panic("wrong number of cells")
    67  	}
    68  	rows := [][]interface{}{}
    69  	for r := 0; r < numRows; r++ {
    70  		cols := []interface{}{}
    71  		for c := 0; c < numCols; c++ {
    72  			cols = append(cols, cells[r*numCols+c])
    73  		}
    74  		rows = append(rows, cols)
    75  	}
    76  	o.rows = rows
    77  	return o
    78  }
    79  
    80  type querier func(conn *sql.DB) (interface{}, error)
    81  
    82  var (
    83  	testdriver = &omnithing{}
    84  	// make sure the test type implements the interfaces
    85  	_ driver.Driver = testdriver
    86  	_ driver.Conn   = testdriver
    87  	_ driver.Tx     = testdriver
    88  	_ driver.Stmt   = testdriver
    89  	_ driver.Result = testdriver
    90  	_ driver.Rows   = testdriver
    91  )
    92  
    93  const driverType = "test"
    94  
    95  func init() {
    96  	sql.Register(driverType, testdriver)
    97  }
    98  
    99  func runRowsTest(t *testing.T, query querier, numInputs int, columns []string, cells ...interface{}) {
   100  	// set intial state before usage
   101  	testdriver.setDB(numInputs, columns, cells...)
   102  	// run a query, retrieve *sql.Rows
   103  	conn, err := sql.Open(driverType, "")
   104  	defer conn.Close()
   105  	rowOrRows, err := query(conn)
   106  	if closer, ok := rowOrRows.(io.Closer); ok {
   107  		defer closer.Close()
   108  	}
   109  	// check that it is accessible and matches the one in testdriver.rows
   110  	unwrapped, err := Inspect(rowOrRows)
   111  	if err != nil {
   112  		t.Error(err)
   113  		return
   114  	}
   115  	myrows, ok := unwrapped.(*omnithing)
   116  	if !ok || myrows != testdriver {
   117  		t.Errorf("returned driver.Rows must match those passed in.")
   118  	}
   119  }
   120  
   121  func TestRowWithoutArgs(t *testing.T) {
   122  	query := func(conn *sql.DB) (interface{}, error) {
   123  		return conn.QueryRow(`SELECT "test"`), nil
   124  	}
   125  	runRowsTest(t, query, 0, []string{"header"}, "test")
   126  }
   127  
   128  func TestRowWithArgs(t *testing.T) {
   129  	query := func(conn *sql.DB) (interface{}, error) {
   130  		return conn.QueryRow(`SELECT ?`, "test"), nil
   131  	}
   132  	runRowsTest(t, query, 1, []string{"header"}, "test")
   133  }
   134  
   135  func TestRowsWithoutArgs(t *testing.T) {
   136  	query := func(conn *sql.DB) (interface{}, error) {
   137  		return conn.Query(`SELECT "test"`)
   138  	}
   139  	runRowsTest(t, query, 0, []string{"header"}, "test")
   140  }
   141  
   142  func TestRowsWithArgs(t *testing.T) {
   143  	query := func(conn *sql.DB) (interface{}, error) {
   144  		return conn.Query(`SELECT ?`, "test")
   145  	}
   146  	runRowsTest(t, query, 1, []string{"header"}, "test")
   147  }