github.com/GnawNom/sqlinternals@v0.0.0-20200413232442-a2dcc5655e0f/mysqlinternals/mysql_test.go (about)

     1  // sqlinternals for github.com/go-sql-driver/mysql - retrieve column metadata from sql.*Row / sql.*Rows
     2  //
     3  // Copyright 2013 Arne Hormann. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysqlinternals
    10  
    11  import (
    12  	"database/sql"
    13  	_ "github.com/go-sql-driver/mysql"
    14  	"os"
    15  	"reflect"
    16  	"testing"
    17  )
    18  
    19  var dsn string
    20  
    21  type Scanner interface {
    22  	Scan(values ...interface{}) error
    23  }
    24  
    25  func init() {
    26  	if envdsn := os.Getenv("MYSQL_DSN"); envdsn != "" {
    27  		dsn = envdsn
    28  	} else {
    29  		dsn = "root@tcp(127.0.0.1:3306)/"
    30  	}
    31  }
    32  
    33  func testRow(t *testing.T, test typeTest, useQueryRow bool) {
    34  	db, err := sql.Open("mysql", dsn)
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	defer db.Close()
    39  	// check that it is accessible and matches the one in tester.rows
    40  	var source Scanner
    41  	var cols []Column
    42  	switch {
    43  	case useQueryRow:
    44  		var row *sql.Row
    45  		row, err = db.QueryRow(test.query, test.queryArgs...), nil
    46  		if err != nil {
    47  			break
    48  		}
    49  		source = row
    50  	case !useQueryRow:
    51  		var rows *sql.Rows
    52  		rows, err = db.Query(test.query, test.queryArgs...)
    53  		if err != nil {
    54  			break
    55  		}
    56  		defer rows.Close()
    57  		source = rows
    58  	}
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	cols, err = Columns(source)
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	col := cols[0]
    68  	decl, derr := col.MysqlDeclaration(test.sqlDeclParams...)
    69  	if test.sqlTypeError && derr == nil {
    70  		t.Errorf("SQL: expected an error in MysqlDeclaration\n")
    71  	}
    72  	if !test.sqlTypeError && derr != nil {
    73  		t.Errorf("SQL: did not expect an error in MysqlDeclaration, got '%v'\n", derr)
    74  	}
    75  	if decl != test.sqlType {
    76  		t.Errorf("SQL: type '%s' did not match expected '%s'\n", decl, test.sqlType)
    77  	}
    78  	refl, rerr := col.ReflectGoType()
    79  	if test.goTypeError && rerr == nil {
    80  		t.Errorf("Go: expected an error in ReflectType\n")
    81  	}
    82  	if !test.goTypeError && rerr != nil {
    83  		t.Errorf("Go: did not expect an error in ReflectType, got '%v'\n", rerr)
    84  	}
    85  	if refl != test.goType {
    86  		t.Errorf("Go: type '%s' did not match expected '%s'\n", refl, test.goType)
    87  	}
    88  	if test.hasValue {
    89  		if rows, ok := source.(*sql.Rows); ok && !rows.Next() {
    90  			t.Error("could not scan from sql.Rows")
    91  		}
    92  		err = source.Scan(&test.receiver)
    93  		if err != nil {
    94  			t.Error(err)
    95  			return
    96  		}
    97  		eVal := reflect.ValueOf(test.expectedValue)
    98  		rVal := reflect.ValueOf(test.receiver)
    99  		if eVal.Type() != rVal.Type() {
   100  			t.Errorf("types of expected (%s) and received (%s) values didn't match\n",
   101  				eVal.Type(), rVal.Type())
   102  		}
   103  		// TODO: compare value and assignability
   104  	}
   105  }
   106  
   107  func args(v ...interface{}) []interface{} {
   108  	return v
   109  }
   110  
   111  type typeTest struct {
   112  	id            string
   113  	query         string
   114  	queryArgs     []interface{}
   115  	sqlType       string
   116  	sqlDeclParams []interface{}
   117  	sqlTypeError  bool
   118  	goType        reflect.Type
   119  	goTypeError   bool
   120  	hasValue      bool
   121  	expectedValue interface{}
   122  	receiver      interface{}
   123  }
   124  
   125  func TestRows(t *testing.T) {
   126  	testSetups := []typeTest{
   127  		typeTest{
   128  			id:            "select string (text mode)",
   129  			query:         "select 'Hi'",
   130  			sqlType:       "VARCHAR(2) NOT NULL",
   131  			sqlDeclParams: args(2),
   132  			sqlTypeError:  false,
   133  			goType:        reflect.TypeOf(""),
   134  			goTypeError:   false,
   135  			hasValue:      true,
   136  			expectedValue: []byte("Hi"),
   137  		},
   138  		typeTest{
   139  			id:            "select string (binary mode)",
   140  			query:         "select ?",
   141  			queryArgs:     args("Hi"),
   142  			sqlType:       "CHAR(2) NOT NULL",
   143  			sqlDeclParams: args(2),
   144  			sqlTypeError:  false,
   145  			goType:        reflect.TypeOf(""),
   146  			goTypeError:   false,
   147  			hasValue:      true,
   148  			expectedValue: []byte("Hi"),
   149  		},
   150  		// TODO: add more tests (many columns, NULL column, different types...)
   151  	}
   152  	for _, setup := range testSetups {
   153  		testRow(t, setup, false)
   154  	}
   155  }
   156  
   157  func TestIsBinary(t *testing.T) {
   158  	tests := []struct {
   159  		result bool
   160  		query  string
   161  		args   []interface{}
   162  	}{
   163  		{result: false, query: "SELECT 1"},
   164  		{result: true, query: "SELECT ?", args: args(1)},
   165  	}
   166  	db, err := sql.Open("mysql", dsn)
   167  	if err != nil {
   168  		t.Fatal(err)
   169  	}
   170  	defer db.Close()
   171  	for _, setup := range tests {
   172  		rows, err := db.Query(setup.query, setup.args...)
   173  		if err != nil {
   174  			t.Fatal(err)
   175  		}
   176  		defer rows.Close()
   177  		if bin, err := IsBinary(rows); err != nil || bin != setup.result {
   178  			t.Errorf("test %#v failed", setup)
   179  		}
   180  	}
   181  }