github.com/RevenueMonster/sqlike@v1.0.6/sql/codec/decoder.go (about)

     1  package codec
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"encoding/base64"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"fmt"
    10  	"net/url"
    11  	"reflect"
    12  	"regexp"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"cloud.google.com/go/civil"
    18  	"cloud.google.com/go/datastore"
    19  	"github.com/RevenueMonster/sqlike/jsonb"
    20  	"github.com/paulmach/orb"
    21  	"github.com/paulmach/orb/encoding/wkb"
    22  	"golang.org/x/text/currency"
    23  	"golang.org/x/text/language"
    24  
    25  	"errors"
    26  )
    27  
    28  // DefaultDecoders :
    29  type DefaultDecoders struct {
    30  	codec *Registry
    31  }
    32  
    33  // DecodeByte :
    34  func (dec DefaultDecoders) DecodeByte(it interface{}, v reflect.Value) error {
    35  	var (
    36  		x   []byte
    37  		err error
    38  	)
    39  	switch vi := it.(type) {
    40  	case string:
    41  		x, err = base64.StdEncoding.DecodeString(vi)
    42  		if err != nil {
    43  			return err
    44  		}
    45  	case []byte:
    46  		x, err = base64.StdEncoding.DecodeString(string(vi))
    47  		if err != nil {
    48  			return err
    49  		}
    50  	case nil:
    51  		x = make([]byte, 0)
    52  	}
    53  	v.SetBytes(x)
    54  	return nil
    55  }
    56  
    57  // DecodeRawBytes :
    58  func (dec DefaultDecoders) DecodeRawBytes(it interface{}, v reflect.Value) error {
    59  	var (
    60  		x sql.RawBytes
    61  	)
    62  	switch vi := it.(type) {
    63  	case []byte:
    64  		x = sql.RawBytes(vi)
    65  	case string:
    66  		x = sql.RawBytes(vi)
    67  	case sql.RawBytes:
    68  		x = vi
    69  	case bool:
    70  		str := strconv.FormatBool(vi)
    71  		x = []byte(str)
    72  	case int64:
    73  		str := strconv.FormatInt(vi, 10)
    74  		x = []byte(str)
    75  	case uint64:
    76  		str := strconv.FormatUint(vi, 10)
    77  		x = []byte(str)
    78  	case float64:
    79  		str := strconv.FormatFloat(vi, 'e', -1, 64)
    80  		x = []byte(str)
    81  	case time.Time:
    82  		x = []byte(vi.Format(time.RFC3339))
    83  	case nil:
    84  	default:
    85  	}
    86  	v.SetBytes(x)
    87  	return nil
    88  }
    89  
    90  // DecodeCurrency :
    91  func (dec DefaultDecoders) DecodeCurrency(it interface{}, v reflect.Value) error {
    92  	var (
    93  		x   currency.Unit
    94  		err error
    95  	)
    96  	switch vi := it.(type) {
    97  	case string:
    98  		x, err = currency.ParseISO(vi)
    99  		if err != nil {
   100  			return err
   101  		}
   102  	case []byte:
   103  		x, err = currency.ParseISO(string(vi))
   104  		if err != nil {
   105  			return err
   106  		}
   107  	case nil:
   108  	}
   109  	v.Set(reflect.ValueOf(x))
   110  	return nil
   111  }
   112  
   113  // DecodeLanguage :
   114  func (dec DefaultDecoders) DecodeLanguage(it interface{}, v reflect.Value) error {
   115  	var (
   116  		x   language.Tag
   117  		str string
   118  		err error
   119  	)
   120  	switch vi := it.(type) {
   121  	case string:
   122  		str = vi
   123  	case []byte:
   124  		str = string(vi)
   125  	case nil:
   126  	default:
   127  		return errors.New("language tag is not well-formed")
   128  	}
   129  	if str != "" {
   130  		x, err = language.Parse(str)
   131  		if err != nil {
   132  			return err
   133  		}
   134  	}
   135  	v.Set(reflect.ValueOf(x))
   136  	return nil
   137  }
   138  
   139  // DecodeJSONRaw :
   140  func (dec DefaultDecoders) DecodeJSONRaw(it interface{}, v reflect.Value) error {
   141  	b := new(bytes.Buffer)
   142  	switch vi := it.(type) {
   143  	case string:
   144  		if err := json.Compact(b, []byte(vi)); err != nil {
   145  			return err
   146  		}
   147  	case []byte:
   148  		if err := json.Compact(b, vi); err != nil {
   149  			return err
   150  		}
   151  	case nil:
   152  	}
   153  	v.SetBytes(b.Bytes())
   154  	return nil
   155  }
   156  
   157  // DecodeDateTime :
   158  func (dec DefaultDecoders) DecodeDateTime(it interface{}, v reflect.Value) error {
   159  	var (
   160  		x   time.Time
   161  		err error
   162  	)
   163  	switch vi := it.(type) {
   164  	case time.Time:
   165  		x = vi
   166  	case string:
   167  		x, err = decodeTime(vi)
   168  		if err != nil {
   169  			return err
   170  		}
   171  	case []byte:
   172  		x, err = decodeTime(b2s(vi))
   173  		if err != nil {
   174  			return err
   175  		}
   176  	case int64:
   177  		x = time.Unix(vi, 0)
   178  	case nil:
   179  	}
   180  	// convert back to UTC
   181  	v.Set(reflect.ValueOf(x.UTC()))
   182  	return nil
   183  }
   184  
   185  // DecodeDate :
   186  func (dec DefaultDecoders) DecodeDate(it interface{}, v reflect.Value) error {
   187  	var (
   188  		x   civil.Date
   189  		err error
   190  	)
   191  	switch vi := it.(type) {
   192  	case time.Time:
   193  		x = civil.DateOf(vi)
   194  	case string:
   195  		x, err = civil.ParseDate(vi)
   196  		if err != nil {
   197  			return err
   198  		}
   199  	case []byte:
   200  		x, err = civil.ParseDate(b2s(vi))
   201  		if err != nil {
   202  			return err
   203  		}
   204  	case int64:
   205  		x = civil.DateOf(time.Unix(vi, 0))
   206  	case nil:
   207  	}
   208  	v.Set(reflect.ValueOf(x))
   209  	return nil
   210  }
   211  
   212  // DecodeTimeLocation :
   213  func (dec DefaultDecoders) DecodeTimeLocation(it interface{}, v reflect.Value) error {
   214  	var x time.Location
   215  	switch vi := it.(type) {
   216  	case string:
   217  		tz, err := time.LoadLocation(vi)
   218  		if err != nil {
   219  			return err
   220  		}
   221  		x = *tz
   222  	case []byte:
   223  		tz, err := time.LoadLocation(string(vi))
   224  		if err != nil {
   225  			return err
   226  		}
   227  		x = *tz
   228  	case nil:
   229  	}
   230  	v.Set(reflect.ValueOf(x))
   231  	return nil
   232  }
   233  
   234  // DecodeTime :
   235  func (dec DefaultDecoders) DecodeTime(it interface{}, v reflect.Value) error {
   236  	var (
   237  		x   civil.Time
   238  		err error
   239  	)
   240  	switch vi := it.(type) {
   241  	case time.Time:
   242  		x = civil.TimeOf(vi)
   243  	case string:
   244  		x, err = civil.ParseTime(vi)
   245  		if err != nil {
   246  			return err
   247  		}
   248  	case []byte:
   249  		x, err = civil.ParseTime(b2s(vi))
   250  		if err != nil {
   251  			return err
   252  		}
   253  	case int64:
   254  		x = civil.TimeOf(time.Unix(vi, 0))
   255  	case nil:
   256  	}
   257  	v.Set(reflect.ValueOf(x))
   258  	return nil
   259  }
   260  
   261  // date format :
   262  var (
   263  	DDMMYYYY         = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}$`)
   264  	DDMMYYYYHHMMSS   = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}$`)
   265  	DDMMYYYYHHMMSSTZ = regexp.MustCompile(`^\d{4}\-\d{2}\-\d{2}\s\d{2}\:\d{2}:\d{2}\.\d+$`)
   266  )
   267  
   268  // DecodeTime : this will decode time by using multiple format
   269  func decodeTime(str string) (t time.Time, err error) {
   270  	switch {
   271  	case DDMMYYYY.MatchString(str):
   272  		t, err = time.Parse("2006-01-02", str)
   273  	case DDMMYYYYHHMMSS.MatchString(str):
   274  		t, err = time.Parse("2006-01-02 15:04:05", str)
   275  	case DDMMYYYYHHMMSSTZ.MatchString(str):
   276  		t, err = time.Parse("2006-01-02 15:04:05.999999", str)
   277  	default:
   278  		t, err = time.Parse(time.RFC3339Nano, str)
   279  	}
   280  	return
   281  }
   282  
   283  // DecodePoint :
   284  func (dec DefaultDecoders) DecodePoint(it interface{}, v reflect.Value) error {
   285  	var p orb.Point
   286  	if it == nil {
   287  		v.Set(reflect.ValueOf(p))
   288  		return nil
   289  	}
   290  
   291  	data, ok := it.([]byte)
   292  	if !ok {
   293  		return errors.New("point must be []byte")
   294  	}
   295  
   296  	length := len(data)
   297  	if length == 0 {
   298  		// empty data, return empty go struct which in this case
   299  		// would be [0,0]
   300  		return nil
   301  	}
   302  
   303  	if length == 42 {
   304  		dst := make([]byte, 21)
   305  		_, err := hex.Decode(dst, data)
   306  		if err != nil {
   307  			return err
   308  		}
   309  		data = dst
   310  	}
   311  
   312  	scanner := wkb.Scanner(&p)
   313  	// if len(data) == 21 {
   314  	// 	// the length of a point type in WKB
   315  	// 	return scan.Scan(data[:])
   316  	// }
   317  
   318  	if length == 25 {
   319  		// Most likely MySQL's SRID+WKB format.
   320  		// However, could be a line string or multipoint with only one point.
   321  		// But those would be invalid for parsing a point.
   322  		// return p.unmarshalWKB(data[4:])
   323  		if err := scanner.Scan(data[4:]); err != nil {
   324  			return err
   325  		}
   326  		v.Set(reflect.ValueOf(p))
   327  		return nil
   328  	}
   329  
   330  	return errors.New("incorrect point")
   331  }
   332  
   333  // DecodeLineString :
   334  func (dec DefaultDecoders) DecodeLineString(it interface{}, v reflect.Value) error {
   335  	var ls orb.LineString
   336  	if it == nil {
   337  		v.Set(reflect.ValueOf(ls))
   338  		return nil
   339  	}
   340  
   341  	data, ok := it.([]byte)
   342  	if !ok {
   343  		return errors.New("line string must be []byte")
   344  	}
   345  
   346  	if len(data) == 0 {
   347  		return nil
   348  	}
   349  
   350  	scanner := wkb.Scanner(&ls)
   351  	if err := scanner.Scan(data[4:]); err != nil {
   352  		return err
   353  	}
   354  
   355  	v.Set(reflect.ValueOf(ls))
   356  	return nil
   357  }
   358  
   359  // DecodeString :
   360  func (dec DefaultDecoders) DecodeString(it interface{}, v reflect.Value) error {
   361  	var x string
   362  	switch vi := it.(type) {
   363  	case string:
   364  		x = vi
   365  	case []byte:
   366  		x = string(vi)
   367  	case int64:
   368  		x = strconv.FormatInt(vi, 10)
   369  	case uint64:
   370  		x = strconv.FormatUint(vi, 10)
   371  	case float64:
   372  		x = strconv.FormatFloat(vi, 'f', -1, 64)
   373  	case bool:
   374  		x = strconv.FormatBool(vi)
   375  	case nil:
   376  	}
   377  	v.SetString(x)
   378  	return nil
   379  }
   380  
   381  // DecodeBool :
   382  func (dec DefaultDecoders) DecodeBool(it interface{}, v reflect.Value) error {
   383  	var (
   384  		x   bool
   385  		err error
   386  	)
   387  	switch vi := it.(type) {
   388  	case []byte:
   389  		x, err = strconv.ParseBool(b2s(vi))
   390  		if err != nil {
   391  			return err
   392  		}
   393  	case string:
   394  		x, err = strconv.ParseBool(vi)
   395  		if err != nil {
   396  			return err
   397  		}
   398  	case bool:
   399  		x = vi
   400  	case int64:
   401  		if vi == 1 {
   402  			x = true
   403  		}
   404  	case uint64:
   405  		if vi == 1 {
   406  			x = true
   407  		}
   408  	case nil:
   409  	}
   410  	v.SetBool(x)
   411  	return nil
   412  }
   413  
   414  // DecodeInt :
   415  func (dec DefaultDecoders) DecodeInt(it interface{}, v reflect.Value) error {
   416  	var (
   417  		x   int64
   418  		err error
   419  	)
   420  	switch vi := it.(type) {
   421  	case []byte:
   422  		x, err = strconv.ParseInt(b2s(vi), 10, 64)
   423  		if err != nil {
   424  			return err
   425  		}
   426  	case string:
   427  		x, err = strconv.ParseInt(vi, 10, 64)
   428  		if err != nil {
   429  			return err
   430  		}
   431  	case int64:
   432  		x = vi
   433  	case uint64:
   434  		x = int64(vi)
   435  	case float64:
   436  		x = int64(vi)
   437  	case nil:
   438  	}
   439  	if v.OverflowInt(x) {
   440  		return errors.New("integer overflow")
   441  	}
   442  	v.SetInt(x)
   443  	return nil
   444  }
   445  
   446  // DecodeUint :
   447  func (dec DefaultDecoders) DecodeUint(it interface{}, v reflect.Value) error {
   448  	var (
   449  		x   uint64
   450  		err error
   451  	)
   452  	switch vi := it.(type) {
   453  	case []byte:
   454  		x, err = strconv.ParseUint(b2s(vi), 10, 64)
   455  		if err != nil {
   456  			return err
   457  		}
   458  	case string:
   459  		x, err = strconv.ParseUint(vi, 10, 64)
   460  		if err != nil {
   461  			return err
   462  		}
   463  	case int64:
   464  		x = uint64(vi)
   465  	case uint64:
   466  		x = vi
   467  	case float64:
   468  		if vi > 0 {
   469  			x = uint64(vi)
   470  		}
   471  	case nil:
   472  	}
   473  	if v.OverflowUint(x) {
   474  		return errors.New("unsigned integer overflow")
   475  	}
   476  	v.SetUint(x)
   477  	return nil
   478  }
   479  
   480  // DecodeFloat :
   481  func (dec DefaultDecoders) DecodeFloat(it interface{}, v reflect.Value) error {
   482  	var (
   483  		x   float64
   484  		err error
   485  	)
   486  	switch vi := it.(type) {
   487  	case []byte:
   488  		x, err = strconv.ParseFloat(b2s(vi), 64)
   489  		if err != nil {
   490  			return err
   491  		}
   492  	case string:
   493  		x, err = strconv.ParseFloat(vi, 64)
   494  		if err != nil {
   495  			return err
   496  		}
   497  	case float64:
   498  		x = vi
   499  	case int64:
   500  		x = float64(vi)
   501  	case uint64:
   502  		x = float64(vi)
   503  	case nil:
   504  
   505  	}
   506  	if v.OverflowFloat(x) {
   507  		return errors.New("float overflow")
   508  	}
   509  	v.SetFloat(x)
   510  	return nil
   511  }
   512  
   513  // DecodePtr :
   514  func (dec *DefaultDecoders) DecodePtr(it interface{}, v reflect.Value) error {
   515  	t := v.Type()
   516  	if it == nil {
   517  		v.Set(reflect.Zero(t))
   518  		return nil
   519  	}
   520  	t = t.Elem()
   521  	decoder, err := dec.codec.LookupDecoder(t)
   522  	if err != nil {
   523  		return err
   524  	}
   525  	return decoder(it, v.Elem())
   526  }
   527  
   528  // DecodeStruct :
   529  func (dec *DefaultDecoders) DecodeStruct(it interface{}, v reflect.Value) error {
   530  	var b []byte
   531  	switch vi := it.(type) {
   532  	case string:
   533  		b = []byte(vi)
   534  	case []byte:
   535  		b = vi
   536  	}
   537  	return jsonb.UnmarshalValue(b, v)
   538  }
   539  
   540  // DecodeArray :
   541  func (dec DefaultDecoders) DecodeArray(it interface{}, v reflect.Value) error {
   542  	var b []byte
   543  	switch vi := it.(type) {
   544  	case string:
   545  		b = []byte(vi)
   546  	case []byte:
   547  		b = vi
   548  	}
   549  	return jsonb.UnmarshalValue(b, v)
   550  }
   551  
   552  // DecodeMap :
   553  func (dec DefaultDecoders) DecodeMap(it interface{}, v reflect.Value) error {
   554  	var b []byte
   555  	switch vi := it.(type) {
   556  	case string:
   557  		b = []byte(vi)
   558  	case []byte:
   559  		b = vi
   560  	}
   561  	return jsonb.UnmarshalValue(b, v)
   562  }
   563  
   564  func (dec DefaultDecoders) DecodeDatastoreKey(it interface{}, v reflect.Value) error {
   565  	key, err := parseKey(fmt.Sprintf("%s", it))
   566  	if err != nil {
   567  		return err
   568  	}
   569  
   570  	v.Set(reflect.ValueOf(key).Elem())
   571  	return nil
   572  }
   573  
   574  func parseKey(str string) (*datastore.Key, error) {
   575  	str = strings.Trim(strings.TrimSpace(str), `"`)
   576  	if str == "" {
   577  		var k *datastore.Key
   578  		return k, nil
   579  	}
   580  
   581  	paths := strings.Split(strings.Trim(str, "/"), "/")
   582  	parentKey := new(datastore.Key)
   583  	endOfIndex := len(paths) - 1
   584  	for i, p := range paths {
   585  		path := strings.Split(p, ",")
   586  		if len(path) != 2 && i != endOfIndex {
   587  			return nil, fmt.Errorf("goloquent: incorrect key value: %q, suppose %q", p, "table,value")
   588  		}
   589  
   590  		kind, value := "", ""
   591  		if len(path) != 2 {
   592  			kind = ""
   593  			value = path[0]
   594  		} else {
   595  			kind = path[0]
   596  			value = path[1]
   597  		}
   598  
   599  		key := new(datastore.Key)
   600  		key.Kind = kind
   601  		if isNameKey(value) {
   602  			name, err := url.PathUnescape(strings.Trim(value, `'`))
   603  			if err != nil {
   604  				return nil, err
   605  			}
   606  			key.Name = name
   607  		} else {
   608  			n, err := strconv.ParseInt(value, 10, 64)
   609  			if err != nil {
   610  				return nil, fmt.Errorf("goloquent: incorrect key id, %v", value)
   611  			}
   612  			key.ID = n
   613  		}
   614  
   615  		if !parentKey.Incomplete() {
   616  			key.Parent = parentKey
   617  		}
   618  		parentKey = key
   619  	}
   620  
   621  	return parentKey, nil
   622  }
   623  
   624  func isNameKey(strKey string) bool {
   625  	if strKey == "" {
   626  		return false
   627  	}
   628  	if strings.HasPrefix(strKey, "name=") {
   629  		return true
   630  	}
   631  	_, err := strconv.ParseInt(strKey, 10, 64)
   632  	if err != nil {
   633  		return true
   634  	}
   635  	paths := strings.Split(strKey, "/")
   636  	if len(paths) != 2 {
   637  		return strings.HasPrefix(strKey, "'") && strings.HasSuffix(strKey, "'")
   638  	}
   639  	lastPath := strings.Split(paths[len(paths)-1], ",")[1]
   640  	return strings.HasPrefix(lastPath, "'") || strings.HasSuffix(lastPath, "'")
   641  }