github.com/XiaoMi/Gaea@v1.2.5/mysql/result.go (about)

     1  // Copyright 2016 The kingshard Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"): you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
    11  // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
    12  // License for the specific language governing permissions and limitations
    13  // under the License.
    14  
    15  // Copyright 2019 The Gaea Authors. All Rights Reserved.
    16  //
    17  // Licensed under the Apache License, Version 2.0 (the "License");
    18  // you may not use this file except in compliance with the License.
    19  // You may obtain a copy of the License at
    20  //
    21  //     http://www.apache.org/licenses/LICENSE-2.0
    22  //
    23  // Unless required by applicable law or agreed to in writing, software
    24  // distributed under the License is distributed on an "AS IS" BASIS,
    25  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    26  // See the License for the specific language governing permissions and
    27  // limitations under the License.
    28  
    29  package mysql
    30  
    31  import (
    32  	"bytes"
    33  	"encoding/binary"
    34  	"fmt"
    35  	"strconv"
    36  
    37  	"github.com/XiaoMi/Gaea/core/errors"
    38  	"github.com/XiaoMi/Gaea/util/hack"
    39  )
    40  
    41  // RowData row in []byte format
    42  type RowData []byte
    43  
    44  // Parse parse data to field
    45  func (p RowData) Parse(f []*Field, binary bool) ([]interface{}, error) {
    46  	if binary {
    47  		return p.ParseBinary(f)
    48  	}
    49  	return p.ParseText(f)
    50  }
    51  
    52  // ParseText parse text format data
    53  func (p RowData) ParseText(f []*Field) ([]interface{}, error) {
    54  	data := make([]interface{}, len(f))
    55  
    56  	var err error
    57  	var v []byte
    58  	var isNull, isUnsigned bool
    59  	var pos = 0
    60  	var ok = false
    61  
    62  	for i := range f {
    63  		v, pos, isNull, ok = ReadLenEncStringAsBytes(p, pos)
    64  		if !ok {
    65  			return nil, fmt.Errorf("ReadLenEncStringAsBytes in ParseText failed")
    66  		}
    67  
    68  		if isNull {
    69  			data[i] = nil
    70  		} else {
    71  			isUnsigned = (f[i].Flag&uint16(UnsignedFlag) > 0)
    72  			switch f[i].Type {
    73  			case TypeTiny, TypeShort, TypeLong, TypeInt24,
    74  				TypeLonglong, TypeYear:
    75  				if isUnsigned {
    76  					data[i], err = strconv.ParseUint(string(v), 10, 64)
    77  				} else {
    78  					data[i], err = strconv.ParseInt(string(v), 10, 64)
    79  				}
    80  			case TypeFloat, TypeDouble, TypeNewDecimal:
    81  				data[i], err = strconv.ParseFloat(string(v), 64)
    82  			case TypeVarchar, TypeVarString,
    83  				TypeString, TypeDatetime,
    84  				TypeDate, TypeDuration, TypeTimestamp:
    85  				data[i] = string(v)
    86  			default:
    87  				data[i] = v
    88  			}
    89  
    90  			if err != nil {
    91  				return nil, err
    92  			}
    93  		}
    94  	}
    95  
    96  	return data, nil
    97  }
    98  
    99  // ParseBinary parse binary format data
   100  func (p RowData) ParseBinary(f []*Field) ([]interface{}, error) {
   101  	data := make([]interface{}, len(f))
   102  
   103  	if p[0] != OKHeader {
   104  		return nil, ErrMalformPacket
   105  	}
   106  
   107  	pos := 1 + ((len(f) + 7 + 2) >> 3)
   108  
   109  	nullBitmap := p[1:pos]
   110  
   111  	var isUnsigned bool
   112  	var isNull bool
   113  	var err error
   114  	var v []byte
   115  	for i := range data {
   116  		if nullBitmap[(i+2)/8]&(1<<(uint(i+2)%8)) > 0 {
   117  			data[i] = nil
   118  			continue
   119  		}
   120  
   121  		isUnsigned = f[i].Flag&uint16(UnsignedFlag) > 0
   122  
   123  		switch f[i].Type {
   124  		case TypeNull:
   125  			data[i] = nil
   126  			continue
   127  
   128  		case TypeTiny:
   129  			if isUnsigned {
   130  				data[i] = uint64(p[pos])
   131  			} else {
   132  				data[i] = int64(p[pos])
   133  			}
   134  			pos++
   135  			continue
   136  
   137  		case TypeShort, TypeYear:
   138  			if isUnsigned {
   139  				data[i] = uint64(binary.LittleEndian.Uint16(p[pos : pos+2]))
   140  			} else {
   141  				var n int16
   142  				err = binary.Read(bytes.NewBuffer(p[pos:pos+2]), binary.LittleEndian, &n)
   143  				if err != nil {
   144  					return nil, err
   145  				}
   146  				data[i] = int64(n)
   147  			}
   148  			pos += 2
   149  			continue
   150  
   151  		case TypeInt24, TypeLong:
   152  			if isUnsigned {
   153  				data[i] = uint64(binary.LittleEndian.Uint32(p[pos : pos+4]))
   154  			} else {
   155  				var n int32
   156  				err = binary.Read(bytes.NewBuffer(p[pos:pos+4]), binary.LittleEndian, &n)
   157  				if err != nil {
   158  					return nil, err
   159  				}
   160  				data[i] = int64(n)
   161  			}
   162  			pos += 4
   163  			continue
   164  
   165  		case TypeLonglong:
   166  			if isUnsigned {
   167  				data[i] = binary.LittleEndian.Uint64(p[pos : pos+8])
   168  			} else {
   169  				var n int64
   170  				err = binary.Read(bytes.NewBuffer(p[pos:pos+8]), binary.LittleEndian, &n)
   171  				if err != nil {
   172  					return nil, err
   173  				}
   174  				data[i] = int64(n)
   175  			}
   176  			pos += 8
   177  			continue
   178  
   179  		case TypeFloat:
   180  			//data[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(p[pos : pos+4])))
   181  			var n float32
   182  			err = binary.Read(bytes.NewBuffer(p[pos:pos+4]), binary.LittleEndian, &n)
   183  			if err != nil {
   184  				return nil, err
   185  			}
   186  			data[i] = float64(n)
   187  			pos += 4
   188  			continue
   189  
   190  		case TypeDouble:
   191  			var n float64
   192  			err = binary.Read(bytes.NewBuffer(p[pos:pos+8]), binary.LittleEndian, &n)
   193  			if err != nil {
   194  				return nil, err
   195  			}
   196  			data[i] = n
   197  			pos += 8
   198  			continue
   199  
   200  		case TypeDecimal, TypeNewDecimal, TypeVarchar,
   201  			TypeBit, TypeEnum, TypeSet, TypeTinyBlob,
   202  			TypeMediumBlob, TypeLongBlob, TypeBlob,
   203  			TypeVarString, TypeString, TypeGeometry:
   204  			var ok = false
   205  			v, pos, isNull, ok = ReadLenEncStringAsBytes(p, pos)
   206  			if !ok {
   207  				return nil, fmt.Errorf("ReadLenEncStringAsBytes in ParseBinary failed")
   208  			}
   209  
   210  			if !isNull {
   211  				data[i] = v
   212  				continue
   213  			} else {
   214  				data[i] = nil
   215  				continue
   216  			}
   217  		case TypeDate, TypeNewDate:
   218  			var num uint64
   219  			num, pos, isNull, _ = ReadLenEncInt(p, pos)
   220  
   221  			if isNull {
   222  				data[i] = nil
   223  				continue
   224  			}
   225  
   226  			data[i], err = FormatBinaryDate(int(num), p[pos:])
   227  			pos += int(num)
   228  
   229  			if err != nil {
   230  				return nil, err
   231  			}
   232  
   233  		case TypeTimestamp, TypeDatetime:
   234  			var num uint64
   235  			num, pos, isNull, _ = ReadLenEncInt(p, pos)
   236  
   237  			if isNull {
   238  				data[i] = nil
   239  				continue
   240  			}
   241  
   242  			data[i], err = FormatBinaryDateTime(int(num), p[pos:])
   243  			pos += int(num)
   244  
   245  			if err != nil {
   246  				return nil, err
   247  			}
   248  
   249  		case TypeDuration:
   250  			var num uint64
   251  			num, pos, isNull, _ = ReadLenEncInt(p, pos)
   252  
   253  			if isNull {
   254  				data[i] = nil
   255  				continue
   256  			}
   257  
   258  			data[i], err = FormatBinaryTime(int(num), p[pos:])
   259  			pos += int(num)
   260  
   261  			if err != nil {
   262  				return nil, err
   263  			}
   264  
   265  		default:
   266  			return nil, fmt.Errorf("Stmt Unknown FieldType %d %s", f[i].Type, f[i].Name)
   267  		}
   268  	}
   269  
   270  	return data, nil
   271  }
   272  
// Result describes the outcome of executing a SQL statement. It holds a
// status word, the insert id and affected-row count, and, through the
// embedded *Resultset, any columns and rows that were returned.
type Result struct {
	// Status holds MySQL status flags for the statement.
	// NOTE(review): presumably the server status bits from the OK/EOF packet — confirm.
	Status uint16

	InsertID     uint64 // insert id for the statement (presumably LAST_INSERT_ID from the OK packet)
	AffectedRows uint64 // number of rows the statement affected

	// Resultset carries the returned columns and rows.
	// NOTE(review): presumably nil for statements that return no rows — confirm against callers.
	*Resultset
}
   282  
// Resultset holds the columns and rows of a SQL result. This may include a
// result merged from split-table queries. Rows are kept both as Go values
// (Values) and as wire-encoded rows (RowDatas).
type Resultset struct {
	Fields     []*Field        // column definitions, in column order
	FieldNames map[string]int  // column name -> index in Fields (used by NameIndex)
	Values     [][]interface{} // rows as Go values, addressed as Values[row][column] by GetValue

	RowDatas []RowData // wire-encoded rows: text protocol from BuildResultset, binary protocol from BuildBinaryResultset
}
   291  
   292  // RowNumber return row number of results
   293  func (r *Resultset) RowNumber() int {
   294  	return len(r.Values)
   295  }
   296  
   297  // ColumnNumber return column number of results
   298  func (r *Resultset) ColumnNumber() int {
   299  	return len(r.Fields)
   300  }
   301  
   302  // GetValue return value in special row and column
   303  func (r *Resultset) GetValue(row, column int) (interface{}, error) {
   304  	if row >= len(r.Values) || row < 0 {
   305  		return nil, fmt.Errorf("invalid row index %d", row)
   306  	}
   307  
   308  	if column >= len(r.Fields) || column < 0 {
   309  		return nil, fmt.Errorf("invalid column index %d", column)
   310  	}
   311  
   312  	return r.Values[row][column], nil
   313  }
   314  
   315  // NameIndex return column index in Fields
   316  func (r *Resultset) NameIndex(name string) (int, error) {
   317  	column, ok := r.FieldNames[name]
   318  	if ok {
   319  		return column, nil
   320  	}
   321  	return 0, fmt.Errorf("invalid field name %s", name)
   322  }
   323  
   324  // GetValueByName return value in special row and column
   325  func (r *Resultset) GetValueByName(row int, name string) (interface{}, error) {
   326  	column, err := r.NameIndex(name)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	return r.GetValue(row, column)
   331  }
   332  
   333  // IsNull check if value in special row and column is Null
   334  func (r *Resultset) IsNull(row, column int) (bool, error) {
   335  	d, err := r.GetValue(row, column)
   336  	if err != nil {
   337  		return false, err
   338  	}
   339  
   340  	return d == nil, nil
   341  }
   342  
   343  // IsNullByName check if value in special row and column is Null, but the entry param is column name
   344  func (r *Resultset) IsNullByName(row int, name string) (bool, error) {
   345  	column, err := r.NameIndex(name)
   346  	if err != nil {
   347  		return false, err
   348  	}
   349  	return r.IsNull(row, column)
   350  }
   351  
   352  // GetUint return value in special row and column in uint64 type
   353  func (r *Resultset) GetUint(row, column int) (uint64, error) {
   354  	d, err := r.GetValue(row, column)
   355  	if err != nil {
   356  		return 0, err
   357  	}
   358  
   359  	switch v := d.(type) {
   360  	case uint64:
   361  		return v, nil
   362  	case int64:
   363  		return uint64(v), nil
   364  	case float64:
   365  		return uint64(v), nil
   366  	case string:
   367  		return strconv.ParseUint(v, 10, 64)
   368  	case []byte:
   369  		return strconv.ParseUint(string(v), 10, 64)
   370  	case nil:
   371  		return 0, nil
   372  	default:
   373  		return 0, fmt.Errorf("data type is %T", v)
   374  	}
   375  }
   376  
   377  // GetUintByName return value in special row and column in uint64 type, but the entry param is name
   378  func (r *Resultset) GetUintByName(row int, name string) (uint64, error) {
   379  	column, err := r.NameIndex(name)
   380  	if err != nil {
   381  		return 0, err
   382  	}
   383  	return r.GetUint(row, column)
   384  }
   385  
   386  // GetIntByName return value in special row and column in int64 type, but the entry param is name
   387  func (r *Resultset) GetIntByName(row int, name string) (int64, error) {
   388  	column, err := r.NameIndex(name)
   389  	if err != nil {
   390  		return 0, err
   391  	}
   392  	return r.GetInt(row, column)
   393  }
   394  
   395  // GetInt return value in special row and column in int64 type
   396  func (r *Resultset) GetInt(row, column int) (int64, error) {
   397  	d, err := r.GetValue(row, column)
   398  	if err != nil {
   399  		return 0, err
   400  	}
   401  
   402  	switch v := d.(type) {
   403  	case uint64:
   404  		return int64(v), nil
   405  	case int64:
   406  		return v, nil
   407  	case float64:
   408  		return int64(v), nil
   409  	case string:
   410  		return strconv.ParseInt(v, 10, 64)
   411  	case []byte:
   412  		return strconv.ParseInt(string(v), 10, 64)
   413  	case nil:
   414  		return 0, nil
   415  	default:
   416  		return 0, fmt.Errorf("data type is %T", v)
   417  	}
   418  }
   419  
   420  // GetFloat return value in special row and column in float64 type
   421  func (r *Resultset) GetFloat(row, column int) (float64, error) {
   422  	d, err := r.GetValue(row, column)
   423  	if err != nil {
   424  		return 0, err
   425  	}
   426  
   427  	switch v := d.(type) {
   428  	case float64:
   429  		return v, nil
   430  	case uint64:
   431  		return float64(v), nil
   432  	case int64:
   433  		return float64(v), nil
   434  	case string:
   435  		return strconv.ParseFloat(v, 64)
   436  	case []byte:
   437  		return strconv.ParseFloat(string(v), 64)
   438  	case nil:
   439  		return 0, nil
   440  	default:
   441  		return 0, fmt.Errorf("data type is %T", v)
   442  	}
   443  }
   444  
   445  // GetFloatByName return value in special row and column in float64 type, but the entry param is name
   446  func (r *Resultset) GetFloatByName(row int, name string) (float64, error) {
   447  	column, err := r.NameIndex(name)
   448  	if err != nil {
   449  		return 0, err
   450  	}
   451  	return r.GetFloat(row, column)
   452  }
   453  
   454  // GetString return value in special row and column in string type
   455  func (r *Resultset) GetString(row, column int) (string, error) {
   456  	d, err := r.GetValue(row, column)
   457  	if err != nil {
   458  		return "", err
   459  	}
   460  
   461  	switch v := d.(type) {
   462  	case string:
   463  		return v, nil
   464  	case []byte:
   465  		return hack.String(v), nil
   466  	case int64:
   467  		return strconv.FormatInt(v, 10), nil
   468  	case uint64:
   469  		return strconv.FormatUint(v, 10), nil
   470  	case float64:
   471  		return strconv.FormatFloat(v, 'f', -1, 64), nil
   472  	case nil:
   473  		return "", nil
   474  	default:
   475  		return "", fmt.Errorf("data type is %T", v)
   476  	}
   477  }
   478  
   479  // GetStringByName return value in special row and column in string type, but the entry param is name
   480  func (r *Resultset) GetStringByName(row int, name string) (string, error) {
   481  	column, err := r.NameIndex(name)
   482  	if err != nil {
   483  		return "", err
   484  	}
   485  	return r.GetString(row, column)
   486  }
   487  
   488  // BuildResultset build resultset
   489  func BuildResultset(fields []*Field, names []string, values [][]interface{}) (*Resultset, error) {
   490  	var ExistFields bool
   491  	r := new(Resultset)
   492  
   493  	r.Fields = make([]*Field, len(names))
   494  	r.FieldNames = make(map[string]int, len(names))
   495  
   496  	//use the field def that get from true database
   497  	if len(fields) != 0 {
   498  		if len(r.Fields) == len(fields) {
   499  			ExistFields = true
   500  		} else {
   501  			return nil, errors.ErrInvalidArgument
   502  		}
   503  	}
   504  
   505  	var b []byte
   506  	var err error
   507  
   508  	for i, vs := range values {
   509  		if len(vs) != len(r.Fields) {
   510  			return nil, fmt.Errorf("row %d has %d column not equal %d", i, len(vs), len(r.Fields))
   511  		}
   512  
   513  		var row []byte
   514  		for j, value := range vs {
   515  			// build fields
   516  			if i == 0 {
   517  				if ExistFields {
   518  					r.Fields[j] = fields[j]
   519  					r.FieldNames[string(r.Fields[j].Name)] = j
   520  				} else {
   521  					field := &Field{}
   522  					r.Fields[j] = field
   523  					r.FieldNames[string(r.Fields[j].Name)] = j
   524  					field.Name = hack.Slice(names[j])
   525  					if err = formatField(field, value); err != nil {
   526  						return nil, err
   527  					}
   528  				}
   529  
   530  			}
   531  			// build row values
   532  			b, err = formatValue(value)
   533  			if err != nil {
   534  				return nil, err
   535  			}
   536  
   537  			row = AppendLenEncStringBytes(row, b)
   538  		}
   539  
   540  		r.RowDatas = append(r.RowDatas, row)
   541  	}
   542  	//assign the values to the result
   543  	r.Values = values
   544  
   545  	return r, nil
   546  }
   547  
   548  // BuildBinaryResultset build binary resultset
   549  // https://dev.mysql.com/doc/internals/en/binary-protocol-resultset.html
   550  func BuildBinaryResultset(fields []*Field, values [][]interface{}) (*Resultset, error) {
   551  	r := new(Resultset)
   552  	r.Fields = make([]*Field, len(fields))
   553  	for i := range fields {
   554  		r.Fields[i] = fields[i]
   555  	}
   556  
   557  	bitmapLen := ((len(fields) + 7 + 2) >> 3)
   558  	for i, v := range values {
   559  		if len(v) != len(r.Fields) {
   560  			return nil, fmt.Errorf("row %d has %d columns not equal %d", i, len(v), len(r.Fields))
   561  		}
   562  
   563  		var row []byte
   564  		nullBitMap := make([]byte, bitmapLen)
   565  		row = append(row, 0)
   566  		row = append(row, nullBitMap...)
   567  		for j, rowVal := range v {
   568  			if rowVal == nil {
   569  				bytePos := (j + 2) / 8
   570  				bitPos := byte((j + 2) % 8)
   571  				nullBitMap[bytePos] |= 1 << bitPos
   572  				continue
   573  			}
   574  
   575  			var err error
   576  			row, err = AppendBinaryValue(row, r.Fields[j].Type, rowVal)
   577  			if err != nil {
   578  				return nil, err
   579  			}
   580  		}
   581  		copy(row[1:], nullBitMap)
   582  		r.RowDatas = append(r.RowDatas, row)
   583  	}
   584  
   585  	return r, nil
   586  }
   587  
   588  // formatField encode field according to type of value if necessary
   589  func formatField(field *Field, value interface{}) error {
   590  	switch value.(type) {
   591  	case int8, int16, int32, int64, int:
   592  		field.Charset = 63
   593  		field.Type = TypeLonglong
   594  		field.Flag = uint16(BinaryFlag | NotNullFlag)
   595  	case uint8, uint16, uint32, uint64, uint:
   596  		field.Charset = 63
   597  		field.Type = TypeLonglong
   598  		field.Flag = uint16(BinaryFlag | NotNullFlag | UnsignedFlag)
   599  	case float32, float64:
   600  		field.Charset = 63
   601  		field.Type = TypeDouble
   602  		field.Flag = uint16(BinaryFlag | NotNullFlag)
   603  	case string, []byte:
   604  		field.Charset = 33
   605  		field.Type = TypeVarString
   606  	default:
   607  		return fmt.Errorf("unsupport type %T for resultset", value)
   608  	}
   609  	return nil
   610  }
   611  
   612  // formatValue encode value into a string format
   613  func formatValue(value interface{}) ([]byte, error) {
   614  	if value == nil {
   615  		return hack.Slice("NULL"), nil
   616  	}
   617  	switch v := value.(type) {
   618  	case int8:
   619  		return strconv.AppendInt(nil, int64(v), 10), nil
   620  	case int16:
   621  		return strconv.AppendInt(nil, int64(v), 10), nil
   622  	case int32:
   623  		return strconv.AppendInt(nil, int64(v), 10), nil
   624  	case int64:
   625  		return strconv.AppendInt(nil, int64(v), 10), nil
   626  	case int:
   627  		return strconv.AppendInt(nil, int64(v), 10), nil
   628  	case uint8:
   629  		return strconv.AppendUint(nil, uint64(v), 10), nil
   630  	case uint16:
   631  		return strconv.AppendUint(nil, uint64(v), 10), nil
   632  	case uint32:
   633  		return strconv.AppendUint(nil, uint64(v), 10), nil
   634  	case uint64:
   635  		return strconv.AppendUint(nil, uint64(v), 10), nil
   636  	case uint:
   637  		return strconv.AppendUint(nil, uint64(v), 10), nil
   638  	case float32:
   639  		return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil
   640  	case float64:
   641  		return strconv.AppendFloat(nil, float64(v), 'f', -1, 64), nil
   642  	case []byte:
   643  		return v, nil
   644  	case string:
   645  		return hack.Slice(v), nil
   646  	default:
   647  		return nil, fmt.Errorf("invalid type %T", value)
   648  	}
   649  }