github.com/GnawNom/sqlinternals@v0.0.0-20200413232442-a2dcc5655e0f/unsafe.go (about)

     1  // sqlinternals - retrieve driver.Rows from sql.*Row / sql.*Rows
     2  //
     3  // Copyright 2013 Arne Hormann. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package sqlinternals
    10  
    11  import (
    12  	"database/sql"
    13  	"database/sql/driver"
    14  	"reflect"
    15  	"unsafe"
    16  )
    17  
    18  var (
    19  	// field offsets for unsafe access (types are checked beforehand)
    20  	offsetRowRows   uintptr // sql.Row.rows: sql.*Rows
    21  	offsetRowsRowsi uintptr // sql.Rows.rowsi: driver.Rows
    22  )
    23  
    24  // internal error type
    25  type internalErr string
    26  
    27  func (e internalErr) Error() string {
    28  	return string(e)
    29  }
    30  
    31  const (
    32  	errArgNil       = internalErr("argument must not be nil")
    33  	errArgWrongType = internalErr("argument was not *sql.Row or *sql.Rows")
    34  	errRowRowsNil   = internalErr("'err' xor 'rows' in sql.Row must be nil")
    35  	errRowsRowsiNil = internalErr("'rowsi driver.Rows' in sql.Rows is nil")
    36  )
    37  
    38  // a driver.Rows implementatiton so we are able
    39  // to get a type assignable to driver.Rows with reflect
    40  type dummyRows struct{}
    41  
    42  func (d dummyRows) Columns() []string {
    43  	return nil
    44  }
    45  
    46  func (d dummyRows) Close() error {
    47  	return nil
    48  }
    49  
    50  func (d dummyRows) Next(dest []driver.Value) error {
    51  	return nil
    52  }
    53  
    54  // basic type assertion, panic on error
    55  func panicIfUnassignable(field reflect.StructField, assignable reflect.Type, panicMsg string) {
    56  	fType := field.Type
    57  	if assignable == fType || assignable.AssignableTo(fType) {
    58  		return
    59  	}
    60  	panic(panicMsg + "; " + assignable.String() + " is not assignable to " + fType.String())
    61  }
    62  
    63  func init() {
    64  	// all types we need to check as templates
    65  	var (
    66  		tRow        reflect.Type = reflect.TypeOf(sql.Row{})
    67  		tRows       reflect.Type = reflect.TypeOf(sql.Rows{})
    68  		tRowsPtr    reflect.Type = reflect.TypeOf(&sql.Rows{})
    69  		tDriverRows reflect.Type = reflect.TypeOf((driver.Rows)(dummyRows{}))
    70  	)
    71  	var i, expectFields, fields int
    72  	// sql.Row must have a field "rows sql.*Rows"
    73  	for i, expectFields, fields = 0, 1, tRow.NumField(); i < fields; i++ {
    74  		field := tRow.Field(i)
    75  		switch field.Name {
    76  		case "rows":
    77  			panicIfUnassignable(field, tRowsPtr,
    78  				"database/sql/Row.rows is not database/sql/*Rows")
    79  			offsetRowRows = field.Offset
    80  			expectFields--
    81  		}
    82  	}
    83  	if expectFields != 0 {
    84  		panic("unexpected structure of database/sql/Row")
    85  	}
    86  	// sql.Rows must have a field "rowsi driver.Rows"
    87  	for i, expectFields, fields = 0, 1, tRows.NumField(); i < fields; i++ {
    88  		if field := tRows.Field(i); field.Name == "rowsi" {
    89  			panicIfUnassignable(field, tDriverRows,
    90  				"database/sql/Rows.rowsi is not database/sql/driver/Rows")
    91  			offsetRowsRowsi = field.Offset
    92  			expectFields--
    93  		}
    94  	}
    95  	if expectFields != 0 {
    96  		panic("unexpected structure of database/sql/Rows")
    97  	}
    98  }
    99  
   100  // Inspect extracts the internal driver.Rows from sql.*Row or sql.*Rows.
   101  // This can be used by a driver to work around issue 5606 in Go until a better way exists.
   102  func Inspect(sqlStruct interface{}) (interface{}, error) {
   103  	// All of this has to use unsafe to access unexported fields, but it's robust:
   104  	// we checked the types and structure in init.
   105  	if sqlStruct == nil {
   106  		return nil, errArgNil
   107  	}
   108  	var rows *sql.Rows
   109  	switch v := sqlStruct.(type) {
   110  	case *sql.Row:
   111  		// extract rows from sql/*Row, if v.rows is nil, an error is returned.
   112  		rowsPtr := (uintptr)((unsafe.Pointer)(v)) + offsetRowRows
   113  		unsafeRows := *(**sql.Rows)((unsafe.Pointer)(rowsPtr))
   114  		if unsafeRows == nil {
   115  			return nil, errRowRowsNil
   116  		}
   117  		rows = unsafeRows
   118  	case *sql.Rows:
   119  		rows = v
   120  	default:
   121  		return errArgWrongType, nil
   122  	}
   123  	// return rowsi from sql.*Rows, if rows.rowsi is nil an error is returned.
   124  	rowsiPtr := offsetRowsRowsi + (uintptr)((unsafe.Pointer)(rows))
   125  	rowsi := *(*driver.Rows)((unsafe.Pointer)(rowsiPtr))
   126  	if rowsi == nil {
   127  		return nil, errRowsRowsiNil
   128  	}
   129  	return rowsi, nil
   130  }