github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/go-xorm/xorm/helpers.go (about)

     1  // Copyright 2015 The Xorm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package xorm
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"reflect"
    11  	"sort"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/insionng/yougam/libraries/go-xorm/core"
    17  )
    18  
    19  // str2PK convert string value to primary key value according to tp
    20  func str2PK(s string, tp reflect.Type) (interface{}, error) {
    21  	var err error
    22  	var result interface{}
    23  	switch tp.Kind() {
    24  	case reflect.Int:
    25  		result, err = strconv.Atoi(s)
    26  		if err != nil {
    27  			return nil, errors.New("convert " + s + " as int: " + err.Error())
    28  		}
    29  	case reflect.Int8:
    30  		x, err := strconv.Atoi(s)
    31  		if err != nil {
    32  			return nil, errors.New("convert " + s + " as int16: " + err.Error())
    33  		}
    34  		result = int8(x)
    35  	case reflect.Int16:
    36  		x, err := strconv.Atoi(s)
    37  		if err != nil {
    38  			return nil, errors.New("convert " + s + " as int16: " + err.Error())
    39  		}
    40  		result = int16(x)
    41  	case reflect.Int32:
    42  		x, err := strconv.Atoi(s)
    43  		if err != nil {
    44  			return nil, errors.New("convert " + s + " as int32: " + err.Error())
    45  		}
    46  		result = int32(x)
    47  	case reflect.Int64:
    48  		result, err = strconv.ParseInt(s, 10, 64)
    49  		if err != nil {
    50  			return nil, errors.New("convert " + s + " as int64: " + err.Error())
    51  		}
    52  	case reflect.Uint:
    53  		x, err := strconv.ParseUint(s, 10, 64)
    54  		if err != nil {
    55  			return nil, errors.New("convert " + s + " as uint: " + err.Error())
    56  		}
    57  		result = uint(x)
    58  	case reflect.Uint8:
    59  		x, err := strconv.ParseUint(s, 10, 64)
    60  		if err != nil {
    61  			return nil, errors.New("convert " + s + " as uint8: " + err.Error())
    62  		}
    63  		result = uint8(x)
    64  	case reflect.Uint16:
    65  		x, err := strconv.ParseUint(s, 10, 64)
    66  		if err != nil {
    67  			return nil, errors.New("convert " + s + " as uint16: " + err.Error())
    68  		}
    69  		result = uint16(x)
    70  	case reflect.Uint32:
    71  		x, err := strconv.ParseUint(s, 10, 64)
    72  		if err != nil {
    73  			return nil, errors.New("convert " + s + " as uint32: " + err.Error())
    74  		}
    75  		result = uint32(x)
    76  	case reflect.Uint64:
    77  		result, err = strconv.ParseUint(s, 10, 64)
    78  		if err != nil {
    79  			return nil, errors.New("convert " + s + " as uint64: " + err.Error())
    80  		}
    81  	case reflect.String:
    82  		result = s
    83  	default:
    84  		panic("unsupported convert type")
    85  	}
    86  	result = reflect.ValueOf(result).Convert(tp).Interface()
    87  	return result, nil
    88  }
    89  
    90  func splitTag(tag string) (tags []string) {
    91  	tag = strings.TrimSpace(tag)
    92  	var hasQuote = false
    93  	var lastIdx = 0
    94  	for i, t := range tag {
    95  		if t == '\'' {
    96  			hasQuote = !hasQuote
    97  		} else if t == ' ' {
    98  			if lastIdx < i && !hasQuote {
    99  				tags = append(tags, strings.TrimSpace(tag[lastIdx:i]))
   100  				lastIdx = i + 1
   101  			}
   102  		}
   103  	}
   104  	if lastIdx < len(tag) {
   105  		tags = append(tags, strings.TrimSpace(tag[lastIdx:len(tag)]))
   106  	}
   107  	return
   108  }
   109  
   110  type zeroable interface {
   111  	IsZero() bool
   112  }
   113  
   114  func isZero(k interface{}) bool {
   115  	switch k.(type) {
   116  	case int:
   117  		return k.(int) == 0
   118  	case int8:
   119  		return k.(int8) == 0
   120  	case int16:
   121  		return k.(int16) == 0
   122  	case int32:
   123  		return k.(int32) == 0
   124  	case int64:
   125  		return k.(int64) == 0
   126  	case uint:
   127  		return k.(uint) == 0
   128  	case uint8:
   129  		return k.(uint8) == 0
   130  	case uint16:
   131  		return k.(uint16) == 0
   132  	case uint32:
   133  		return k.(uint32) == 0
   134  	case uint64:
   135  		return k.(uint64) == 0
   136  	case float32:
   137  		return k.(float32) == 0
   138  	case float64:
   139  		return k.(float64) == 0
   140  	case bool:
   141  		return k.(bool) == false
   142  	case string:
   143  		return k.(string) == ""
   144  	case zeroable:
   145  		return k.(zeroable).IsZero()
   146  	}
   147  	return false
   148  }
   149  
   150  func isStructZero(v reflect.Value) bool {
   151  	if !v.IsValid() {
   152  		return true
   153  	}
   154  
   155  	for i := 0; i < v.NumField(); i++ {
   156  		field := v.Field(i)
   157  		switch field.Kind() {
   158  		case reflect.Ptr:
   159  			field = field.Elem()
   160  			fallthrough
   161  		case reflect.Struct:
   162  			if !isStructZero(field) {
   163  				return false
   164  			}
   165  		default:
   166  			if field.CanInterface() && !isZero(field.Interface()) {
   167  				return false
   168  			}
   169  		}
   170  	}
   171  	return true
   172  }
   173  
   174  func int64ToIntValue(id int64, tp reflect.Type) reflect.Value {
   175  	var v interface{}
   176  	switch tp.Kind() {
   177  	case reflect.Int16:
   178  		v = int16(id)
   179  	case reflect.Int32:
   180  		v = int32(id)
   181  	case reflect.Int:
   182  		v = int(id)
   183  	case reflect.Int64:
   184  		v = id
   185  	case reflect.Uint16:
   186  		v = uint16(id)
   187  	case reflect.Uint32:
   188  		v = uint32(id)
   189  	case reflect.Uint64:
   190  		v = uint64(id)
   191  	case reflect.Uint:
   192  		v = uint(id)
   193  	}
   194  	return reflect.ValueOf(v).Convert(tp)
   195  }
   196  
   197  func int64ToInt(id int64, tp reflect.Type) interface{} {
   198  	return int64ToIntValue(id, tp).Interface()
   199  }
   200  
   201  func isPKZero(pk core.PK) bool {
   202  	for _, k := range pk {
   203  		if isZero(k) {
   204  			return true
   205  		}
   206  	}
   207  	return false
   208  }
   209  
   210  func indexNoCase(s, sep string) int {
   211  	return strings.Index(strings.ToLower(s), strings.ToLower(sep))
   212  }
   213  
   214  func splitNoCase(s, sep string) []string {
   215  	idx := indexNoCase(s, sep)
   216  	if idx < 0 {
   217  		return []string{s}
   218  	}
   219  	return strings.Split(s, s[idx:idx+len(sep)])
   220  }
   221  
   222  func splitNNoCase(s, sep string, n int) []string {
   223  	idx := indexNoCase(s, sep)
   224  	if idx < 0 {
   225  		return []string{s}
   226  	}
   227  	return strings.SplitN(s, s[idx:idx+len(sep)], n)
   228  }
   229  
   230  func makeArray(elem string, count int) []string {
   231  	res := make([]string, count)
   232  	for i := 0; i < count; i++ {
   233  		res[i] = elem
   234  	}
   235  	return res
   236  }
   237  
   238  func rValue(bean interface{}) reflect.Value {
   239  	return reflect.Indirect(reflect.ValueOf(bean))
   240  }
   241  
   242  func rType(bean interface{}) reflect.Type {
   243  	sliceValue := reflect.Indirect(reflect.ValueOf(bean))
   244  	//return reflect.TypeOf(sliceValue.Interface())
   245  	return sliceValue.Type()
   246  }
   247  
   248  func structName(v reflect.Type) string {
   249  	for v.Kind() == reflect.Ptr {
   250  		v = v.Elem()
   251  	}
   252  	return v.Name()
   253  }
   254  
   255  func col2NewCols(columns ...string) []string {
   256  	newColumns := make([]string, 0, len(columns))
   257  	for _, col := range columns {
   258  		col = strings.Replace(col, "`", "", -1)
   259  		col = strings.Replace(col, `"`, "", -1)
   260  		ccols := strings.Split(col, ",")
   261  		for _, c := range ccols {
   262  			newColumns = append(newColumns, strings.TrimSpace(c))
   263  		}
   264  	}
   265  	return newColumns
   266  }
   267  
   268  func sliceEq(left, right []string) bool {
   269  	if len(left) != len(right) {
   270  		return false
   271  	}
   272  	sort.Sort(sort.StringSlice(left))
   273  	sort.Sort(sort.StringSlice(right))
   274  	for i := 0; i < len(left); i++ {
   275  		if left[i] != right[i] {
   276  			return false
   277  		}
   278  	}
   279  	return true
   280  }
   281  
   282  func reflect2value(rawValue *reflect.Value) (str string, err error) {
   283  	aa := reflect.TypeOf((*rawValue).Interface())
   284  	vv := reflect.ValueOf((*rawValue).Interface())
   285  	switch aa.Kind() {
   286  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   287  		str = strconv.FormatInt(vv.Int(), 10)
   288  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   289  		str = strconv.FormatUint(vv.Uint(), 10)
   290  	case reflect.Float32, reflect.Float64:
   291  		str = strconv.FormatFloat(vv.Float(), 'f', -1, 64)
   292  	case reflect.String:
   293  		str = vv.String()
   294  	case reflect.Array, reflect.Slice:
   295  		switch aa.Elem().Kind() {
   296  		case reflect.Uint8:
   297  			data := rawValue.Interface().([]byte)
   298  			str = string(data)
   299  		default:
   300  			err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
   301  		}
   302  	// time type
   303  	case reflect.Struct:
   304  		if aa.ConvertibleTo(core.TimeType) {
   305  			str = vv.Convert(core.TimeType).Interface().(time.Time).Format(time.RFC3339Nano)
   306  		} else {
   307  			err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
   308  		}
   309  	case reflect.Bool:
   310  		str = strconv.FormatBool(vv.Bool())
   311  	case reflect.Complex128, reflect.Complex64:
   312  		str = fmt.Sprintf("%v", vv.Complex())
   313  	/* TODO: unsupported types below
   314  	   case reflect.Map:
   315  	   case reflect.Ptr:
   316  	   case reflect.Uintptr:
   317  	   case reflect.UnsafePointer:
   318  	   case reflect.Chan, reflect.Func, reflect.Interface:
   319  	*/
   320  	default:
   321  		err = fmt.Errorf("Unsupported struct type %v", vv.Type().Name())
   322  	}
   323  	return
   324  }
   325  
   326  func value2Bytes(rawValue *reflect.Value) (data []byte, err error) {
   327  	var str string
   328  	str, err = reflect2value(rawValue)
   329  	if err != nil {
   330  		return
   331  	}
   332  	data = []byte(str)
   333  	return
   334  }
   335  
   336  func value2String(rawValue *reflect.Value) (data string, err error) {
   337  	data, err = reflect2value(rawValue)
   338  	if err != nil {
   339  		return
   340  	}
   341  	return
   342  }
   343  
   344  func rows2Strings(rows *core.Rows) (resultsSlice []map[string]string, err error) {
   345  	fields, err := rows.Columns()
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  	for rows.Next() {
   350  		result, err := row2mapStr(rows, fields)
   351  		if err != nil {
   352  			return nil, err
   353  		}
   354  		resultsSlice = append(resultsSlice, result)
   355  	}
   356  
   357  	return resultsSlice, nil
   358  }
   359  
   360  func rows2maps(rows *core.Rows) (resultsSlice []map[string][]byte, err error) {
   361  	fields, err := rows.Columns()
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  	for rows.Next() {
   366  		result, err := row2map(rows, fields)
   367  		if err != nil {
   368  			return nil, err
   369  		}
   370  		resultsSlice = append(resultsSlice, result)
   371  	}
   372  
   373  	return resultsSlice, nil
   374  }
   375  
   376  func row2map(rows *core.Rows, fields []string) (resultsMap map[string][]byte, err error) {
   377  	result := make(map[string][]byte)
   378  	scanResultContainers := make([]interface{}, len(fields))
   379  	for i := 0; i < len(fields); i++ {
   380  		var scanResultContainer interface{}
   381  		scanResultContainers[i] = &scanResultContainer
   382  	}
   383  	if err := rows.Scan(scanResultContainers...); err != nil {
   384  		return nil, err
   385  	}
   386  
   387  	for ii, key := range fields {
   388  		rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
   389  		//if row is null then ignore
   390  		if rawValue.Interface() == nil {
   391  			//fmt.Println("ignore ...", key, rawValue)
   392  			continue
   393  		}
   394  
   395  		if data, err := value2Bytes(&rawValue); err == nil {
   396  			result[key] = data
   397  		} else {
   398  			return nil, err // !nashtsai! REVIEW, should return err or just error log?
   399  		}
   400  	}
   401  	return result, nil
   402  }
   403  
   404  func row2mapStr(rows *core.Rows, fields []string) (resultsMap map[string]string, err error) {
   405  	result := make(map[string]string)
   406  	scanResultContainers := make([]interface{}, len(fields))
   407  	for i := 0; i < len(fields); i++ {
   408  		var scanResultContainer interface{}
   409  		scanResultContainers[i] = &scanResultContainer
   410  	}
   411  	if err := rows.Scan(scanResultContainers...); err != nil {
   412  		return nil, err
   413  	}
   414  
   415  	for ii, key := range fields {
   416  		rawValue := reflect.Indirect(reflect.ValueOf(scanResultContainers[ii]))
   417  		//if row is null then ignore
   418  		if rawValue.Interface() == nil {
   419  			//fmt.Println("ignore ...", key, rawValue)
   420  			continue
   421  		}
   422  
   423  		if data, err := value2String(&rawValue); err == nil {
   424  			result[key] = data
   425  		} else {
   426  			return nil, err // !nashtsai! REVIEW, should return err or just error log?
   427  		}
   428  	}
   429  	return result, nil
   430  }
   431  
   432  func txQuery2(tx *core.Tx, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
   433  	rows, err := tx.Query(sqlStr, params...)
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  	defer rows.Close()
   438  
   439  	return rows2Strings(rows)
   440  }
   441  
   442  func query2(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string]string, err error) {
   443  	s, err := db.Prepare(sqlStr)
   444  	if err != nil {
   445  		return nil, err
   446  	}
   447  	defer s.Close()
   448  	rows, err := s.Query(params...)
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  	defer rows.Close()
   453  	return rows2Strings(rows)
   454  }
   455  
   456  func setColumnInt(bean interface{}, col *core.Column, t int64) {
   457  	v, err := col.ValueOf(bean)
   458  	if err != nil {
   459  		return
   460  	}
   461  	if v.CanSet() {
   462  		switch v.Type().Kind() {
   463  		case reflect.Int, reflect.Int64, reflect.Int32:
   464  			v.SetInt(t)
   465  		case reflect.Uint, reflect.Uint64, reflect.Uint32:
   466  			v.SetUint(uint64(t))
   467  		}
   468  	}
   469  }
   470  
   471  func setColumnTime(bean interface{}, col *core.Column, t time.Time) {
   472  	v, err := col.ValueOf(bean)
   473  	if err != nil {
   474  		return
   475  	}
   476  	if v.CanSet() {
   477  		switch v.Type().Kind() {
   478  		case reflect.Struct:
   479  			v.Set(reflect.ValueOf(t).Convert(v.Type()))
   480  		case reflect.Int, reflect.Int64, reflect.Int32:
   481  			v.SetInt(t.Unix())
   482  		case reflect.Uint, reflect.Uint64, reflect.Uint32:
   483  			v.SetUint(uint64(t.Unix()))
   484  		}
   485  	}
   486  }
   487  
   488  func genCols(table *core.Table, session *Session, bean interface{}, useCol bool, includeQuote bool) ([]string, []interface{}, error) {
   489  	colNames := make([]string, 0, len(table.ColumnsSeq()))
   490  	args := make([]interface{}, 0, len(table.ColumnsSeq()))
   491  
   492  	for _, col := range table.Columns() {
   493  		if useCol && !col.IsVersion && !col.IsCreated && !col.IsUpdated {
   494  			if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
   495  				continue
   496  			}
   497  		}
   498  		if col.MapType == core.ONLYFROMDB {
   499  			continue
   500  		}
   501  
   502  		fieldValuePtr, err := col.ValueOf(bean)
   503  		if err != nil {
   504  			return nil, nil, err
   505  		}
   506  		fieldValue := *fieldValuePtr
   507  
   508  		if col.IsAutoIncrement {
   509  			switch fieldValue.Type().Kind() {
   510  			case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
   511  				if fieldValue.Int() == 0 {
   512  					continue
   513  				}
   514  			case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint, reflect.Uint64:
   515  				if fieldValue.Uint() == 0 {
   516  					continue
   517  				}
   518  			case reflect.String:
   519  				if len(fieldValue.String()) == 0 {
   520  					continue
   521  				}
   522  			}
   523  		}
   524  
   525  		if col.IsDeleted {
   526  			continue
   527  		}
   528  
   529  		if session.Statement.ColumnStr != "" {
   530  			if _, ok := getFlagForColumn(session.Statement.columnMap, col); !ok {
   531  				continue
   532  			}
   533  		}
   534  		if session.Statement.OmitStr != "" {
   535  			if _, ok := getFlagForColumn(session.Statement.columnMap, col); ok {
   536  				continue
   537  			}
   538  		}
   539  
   540  		// !evalphobia! set fieldValue as nil when column is nullable and zero-value
   541  		if _, ok := getFlagForColumn(session.Statement.nullableMap, col); ok {
   542  			if col.Nullable && isZero(fieldValue.Interface()) {
   543  				var nilValue *int
   544  				fieldValue = reflect.ValueOf(nilValue)
   545  			}
   546  		}
   547  
   548  		if (col.IsCreated || col.IsUpdated) && session.Statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
   549  			// if time is non-empty, then set to auto time
   550  			val, t := session.Engine.NowTime2(col.SQLType.Name)
   551  			args = append(args, val)
   552  
   553  			var colName = col.Name
   554  			session.afterClosures = append(session.afterClosures, func(bean interface{}) {
   555  				col := table.GetColumn(colName)
   556  				setColumnTime(bean, col, t)
   557  			})
   558  		} else if col.IsVersion && session.Statement.checkVersion {
   559  			args = append(args, 1)
   560  		} else {
   561  			arg, err := session.value2Interface(col, fieldValue)
   562  			if err != nil {
   563  				return colNames, args, err
   564  			}
   565  			args = append(args, arg)
   566  		}
   567  
   568  		if includeQuote {
   569  			colNames = append(colNames, session.Engine.Quote(col.Name)+" = ?")
   570  		} else {
   571  			colNames = append(colNames, col.Name)
   572  		}
   573  	}
   574  	return colNames, args, nil
   575  }
   576  
   577  func indexName(tableName, idxName string) string {
   578  	return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
   579  }
   580  
   581  func getFlagForColumn(m map[string]bool, col *core.Column) (val bool, has bool) {
   582  
   583  	if len(m) == 0 {
   584  		return false, false
   585  	}
   586  
   587  	n := len(col.Name)
   588  
   589  	for mk := range m {
   590  		if len(mk) != n {
   591  			continue
   592  		}
   593  		if strings.EqualFold(mk, col.Name) {
   594  			return m[mk], true
   595  		}
   596  	}
   597  
   598  	return false, false
   599  }