github.com/GnawNom/sqlinternals@v0.0.0-20200413232442-a2dcc5655e0f/mysqlinternals/unsafe.go (about)

     1  // sqlinternals for github.com/go-sql-driver/mysql - retrieve column metadata from sql.*Row / sql.*Rows
     2  //
     3  // Copyright 2013 Arne Hormann. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysqlinternals
    10  
    11  import (
    12  	"database/sql/driver"
    13  	"fmt"
    14  	"reflect"
    15  	"sync"
    16  	"unsafe"
    17  
    18  	"github.com/GnawNom/sqlinternals"
    19  )
    20  
    21  // keep in sync with github.com/go-sql-driver/mysql/const.go
    22  type fieldType byte
    23  const (
    24  	fieldTypeDecimal fieldType = iota
    25  	fieldTypeTiny
    26  	fieldTypeShort
    27  	fieldTypeLong
    28  	fieldTypeFloat
    29  	fieldTypeDouble
    30  	fieldTypeNULL
    31  	fieldTypeTimestamp
    32  	fieldTypeLongLong
    33  	fieldTypeInt24
    34  	fieldTypeDate
    35  	fieldTypeTime
    36  	fieldTypeDateTime
    37  	fieldTypeYear
    38  	fieldTypeNewDate
    39  	fieldTypeVarChar
    40  	fieldTypeBit
    41  )
    42  
    43  const (
    44  	fieldTypeJSON fieldType = iota + 0xf5
    45  	fieldTypeNewDecimal
    46  	fieldTypeEnum
    47  	fieldTypeSet
    48  	fieldTypeTinyBLOB
    49  	fieldTypeMediumBLOB
    50  	fieldTypeLongBLOB
    51  	fieldTypeBLOB
    52  	fieldTypeVarString
    53  	fieldTypeString
    54  	fieldTypeGeometry
    55  )
    56  
    57  type fieldFlag uint16
    58  
    59  const (
    60  	flagNotNULL fieldFlag = 1 << iota
    61  	flagPriKey
    62  	flagUniqueKey
    63  	flagMultipleKey
    64  	flagBLOB
    65  	flagUnsigned
    66  	flagZeroFill
    67  	flagBinary
    68  	flagEnum
    69  	flagAutoIncrement
    70  	flagTimestamp
    71  	flagSet
    72  	flagUnknown1
    73  	flagUnknown2
    74  	flagUnknown3
    75  	flagUnknown4
    76  )
    77  
// mysqlField mirrors the driver's per-column metadata struct (keep mysqlRows,
// resultSet and mysqlField in sync with the structs declared in
// github.com/go-sql-driver/mysql, rows.go).
//
// Field names, order and types must match the driver's declaration exactly:
// initOffsets compares both layouts with canConvert, and Columns reads the
// driver's values through this type via an unsafe pointer conversion.
type mysqlField struct {
	tableName string
	name      string
	length    uint32
	flags     fieldFlag // bit set of the flag* constants
	fieldType fieldType // one of the fieldType* constants
	decimals  byte
	charSet   uint8
}
    88  
// resultSet mirrors the driver's resultSet, reached through mysqlRows.rs.
//
// Field names, order and types must match the driver's declaration exactly:
// initOffsets compares both layouts with canConvert, and Columns reads
// rs.columns through an unsafe pointer conversion.
type resultSet struct {
	columns []mysqlField
	columnNames []string
	done    bool
}
    94  
// mysqlRows mirrors the driver's mysqlRows, which the driver's textRows and
// binaryRows types embed.
//
// Columns reinterprets a pointer to one of those types as a *mysqlRows, so
// field names, order and types must match the driver's declaration exactly;
// initOffsets verifies this with canConvert before Columns is allowed to run.
type mysqlRows struct {
	// mc only needs the same pointer-to-struct shape; see the mysqlConn placeholder.
	mc *mysqlConn
	rs resultSet
	finish func()
}
   100  
// emptyRows carries the name of the driver's empty-result rows type.
// NOTE(review): not referenced in this file; Columns matches the driver's type
// through the rowtypeEmpty name instead — confirm other files still need it.
type emptyRows struct{}

// rowEmbedder embeds mysqlRows the same way the driver's textRows and
// binaryRows do.
// NOTE(review): unused in this file; presumably referenced elsewhere in the
// package — confirm.
type rowEmbedder struct {
	mysqlRows
}

// mysqlConn is an empty placeholder for the driver's connection type. Only a
// pointer to it appears in mysqlRows, and canConvert compares just the name of
// a pointed-to struct, so its contents never need to match the driver's.
type mysqlConn struct{}
   109  
   110  // internals
   111  type mysqlError string
   112  
   113  func (e mysqlError) Error() string {
   114  	return string(e)
   115  }
   116  
   117  const (
   118  	errUnexpectedNil  = mysqlError("wrong argument, rows must not be nil")
   119  	errUnexpectedType = mysqlError("wrong argument, must be *mysql.mysqlRows")
   120  	rowtypeBinary     = "binaryRows"
   121  	rowtypeText       = "textRows"
   122  	rowtypeEmpty      = "emptyRows"
   123  )
   124  
   125  var (
   126  	// populate the offset only once
   127  	initMutex      sync.Mutex
   128  	failedInit     bool
   129  	structsChecked bool
   130  )
   131  
   132  // canConvert returns true if the memory layout and the struct field names of
   133  // 'from' match those of 'to'.
   134  func canConvert(from, to reflect.Type) bool {
   135  	switch {
   136  	case from.Kind() != reflect.Struct,
   137  		from.Kind() != to.Kind(),
   138  		from.Size() != to.Size(),
   139  		from.Name() != to.Name(),
   140  		from.NumField() != to.NumField():
   141  		return false
   142  	}
   143  	for i, max := 0, from.NumField(); i < max; i++ {
   144  		sf, tf := from.Field(i), to.Field(i)
   145  		if sf.Name != tf.Name || sf.Offset != tf.Offset {
   146  			return false
   147  		}
   148  		tsf, ttf := sf.Type, tf.Type
   149  		for done := false; !done; {
   150  			k := tsf.Kind()
   151  			if k != ttf.Kind() {
   152  				return false
   153  			}
   154  			switch k {
   155  			case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
   156  				tsf, ttf = tsf.Elem(), ttf.Elem()
   157  			case reflect.Interface:
   158  				// don't have to handle matching interfaces here
   159  				if tsf != ttf {
   160  					// there are none in our case, so we are extra strict
   161  					return false
   162  				}
   163  			case reflect.Struct:
   164  				if tsf.Name() != ttf.Name() {
   165  					return false
   166  				}
   167  				done = true
   168  			default:
   169  				done = true
   170  			}
   171  		}
   172  	}
   173  	return true
   174  }
   175  
   176  func initOffsets(rows driver.Rows) error {
   177  	const (
   178  		errWrapperMismatch   = mysqlError("unexpected structure of textRows or binaryRows")
   179  		errRowsMismatch      = mysqlError("unexpected structure of mysqlRows")
   180  		errResultsetMismatch = mysqlError("unexpected structure of resultSet")
   181  		errFieldMismatch     = mysqlError("unexpected structure of mysqlField")
   182  	)
   183  	// make sure mysqlRows is the right type (full certainty is impossible).
   184  	if rows == nil {
   185  		return errUnexpectedNil
   186  	}
   187  	argType := reflect.TypeOf(rows)
   188  	if argType.Kind() != reflect.Ptr {
   189  		return errUnexpectedType
   190  	}
   191  	elemType := argType.Elem()
   192  	if elemType.Kind() != reflect.Struct {
   193  		return errUnexpectedType
   194  	}
   195  	switch typeName := elemType.Name(); typeName {
   196  	case rowtypeBinary, rowtypeText:
   197  	default:
   198  		return errUnexpectedType
   199  	}
   200  	embedded, ok := elemType.FieldByName("mysqlRows")
   201  	if !ok {
   202  		return errWrapperMismatch
   203  	}
   204  	elemType = embedded.Type
   205  	// compare mysqlRows
   206  	if !canConvert(elemType, reflect.TypeOf(mysqlRows{})) {
   207  		return errRowsMismatch
   208  	}
   209  	resultSetField, ok := elemType.FieldByName("rs")
   210  	if !ok {
   211  		return errRowsMismatch
   212  	}
   213  	elemType = resultSetField.Type
   214  	// compare resultSet
   215  	if !canConvert(elemType, reflect.TypeOf(resultSet{})) {
   216  		return errRowsMismatch
   217  	}
   218  	colsField, ok := elemType.FieldByName("columns")
   219  	if !ok {
   220  		return errResultsetMismatch
   221  	}
   222  	// compare mysqlField
   223  	if !canConvert(colsField.Type.Elem(), reflect.TypeOf(mysqlField{})) {
   224  		fmt.Printf("=> %#v\n\n", reflect.Zero(colsField.Type.Elem()).Interface())
   225  		return errFieldMismatch
   226  	}
   227  	return nil
   228  }
   229  
   230  func driverRows(rowOrRows interface{}) (driver.Rows, bool) {
   231  	if rowOrRows == nil || failedInit {
   232  		return nil, false
   233  	}
   234  	rows, err := sqlinternals.Inspect(rowOrRows)
   235  	if err != nil || rows == nil {
   236  		return nil, false
   237  	}
   238  	dRows, ok := rows.(driver.Rows)
   239  	if !ok {
   240  		return nil, false
   241  	}
   242  	if uninitialized := !structsChecked; uninitialized {
   243  		ok = true
   244  		initMutex.Lock()
   245  		defer initMutex.Unlock()
   246  		if !failedInit {
   247  			switch err = initOffsets(dRows); err {
   248  			case nil:
   249  				structsChecked = true
   250  				uninitialized = false
   251  			case errUnexpectedType, errUnexpectedNil:
   252  				ok = false
   253  			default:
   254  				failedInit = true
   255  				ok = false
   256  			}
   257  			if !ok {
   258  				return nil, false
   259  			}
   260  		}
   261  	}
   262  	return dRows, true
   263  }
   264  
   265  // IsBinary reports whether the row value was retrieved using the binary protocol.
   266  //
   267  // MySQL results retrieved with prepared statements or Query with additional arguments
   268  // use the binary protocol. The results are typed, the driver will use the closest
   269  // matching Go type.
   270  // A plain Query call with only the query itself will not use the binary protocol but the
   271  // text protocol. The results are all strings in that case.
   272  func IsBinary(rowOrRows interface{}) (bool, error) {
   273  	const errUnavailable = mysqlError("IsBinary is not available")
   274  	dRows, ok := driverRows(rowOrRows)
   275  	if !ok {
   276  		return false, errUnavailable
   277  	}
   278  	argType := reflect.TypeOf(dRows)
   279  	return rowtypeBinary == argType.Elem().Name(), nil
   280  }
   281  
   282  // Columns retrieves a []Column for sql.Rows or sql.Row with type inspection abilities.
   283  //
   284  // The field indices match those of a call to Columns().
   285  // Returns an error if the argument is not sql.Rows or sql.Row based on github.com/go-sql-driver/mysql.
   286  func Columns(rowOrRows interface{}) ([]Column, error) {
   287  	const errUnavailable = mysqlError("Columns is not available")
   288  	dRows, ok := driverRows(rowOrRows)
   289  	if !ok {
   290  		return nil, errUnavailable
   291  	}
   292  	if rowtypeEmpty == reflect.TypeOf(dRows).Name() {
   293  		return nil, nil
   294  	}
   295  	cols := (*mysqlRows)((unsafe.Pointer)(reflect.ValueOf(dRows).Pointer())).rs.columns
   296  	columns := make([]Column, len(cols))
   297  	for i, c := range cols {
   298  		columns[i] = c
   299  	}
   300  	return columns, nil
   301  }