github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/transports/params.go (about)

     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   * 	http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    17  
    18  package transports
    19  
    20  import (
    21  	"bytes"
    22  	"fmt"
    23  	"github.com/aacfactory/errors"
    24  	"github.com/aacfactory/fns/commons/bytex"
    25  	"github.com/aacfactory/fns/commons/objects"
    26  	"github.com/valyala/bytebufferpool"
    27  	"reflect"
    28  	"sort"
    29  	"strconv"
    30  	"strings"
    31  	"time"
    32  )
    33  
    34  type Params interface {
    35  	Get(name []byte) []byte
    36  	Set(name []byte, value []byte)
    37  	Add(name []byte, value []byte)
    38  	Values(name []byte) [][]byte
    39  	Remove(name []byte)
    40  	Len() int
    41  	Encode() (p []byte)
    42  }
    43  
    44  func ObjectParams(params Params) objects.Object {
    45  	return paramsObject{
    46  		value: params,
    47  	}
    48  }
    49  
    50  type paramsObject struct {
    51  	value Params
    52  }
    53  
    54  func (p paramsObject) Valid() (ok bool) {
    55  	ok = p.value.Len() > 0
    56  	return
    57  }
    58  
    59  func (p paramsObject) Value() (v any) {
    60  	v = p.value
    61  	return
    62  }
    63  
    64  func (p paramsObject) Unmarshal(dst interface{}) (err error) {
    65  	err = DecodeParams(p.value, dst)
    66  	return
    67  }
    68  
    69  func (p paramsObject) Marshal() ([]byte, error) {
    70  	encoded := p.value.Encode()
    71  	capacity := len(encoded) + 2
    72  	b := make([]byte, capacity)
    73  	b[0] = '"'
    74  	b[capacity-1] = '"'
    75  	copy(b[1:], encoded)
    76  	return b, nil
    77  }
    78  
    79  type paramValues [][]byte
    80  
    81  func (values paramValues) Len() int {
    82  	return len(values)
    83  }
    84  
    85  func (values paramValues) Less(i, j int) bool {
    86  	return bytes.Compare(values[i], values[j]) < 0
    87  }
    88  
    89  func (values paramValues) Swap(i, j int) {
    90  	values[i], values[j] = values[j], values[i]
    91  }
    92  
    93  type param struct {
    94  	key []byte
    95  	val paramValues
    96  }
    97  
    98  func NewParams() Params {
    99  	pp := make(defaultParams, 0, 1)
   100  	return &pp
   101  }
   102  
   103  type defaultParams []param
   104  
   105  func (params *defaultParams) Less(i, j int) bool {
   106  	pp := *params
   107  	return bytes.Compare(pp[i].key, pp[j].key) < 0
   108  }
   109  
   110  func (params *defaultParams) Swap(i, j int) {
   111  	pp := *params
   112  	pp[i], pp[j] = pp[j], pp[i]
   113  	*params = pp
   114  }
   115  
   116  func (params *defaultParams) Get(name []byte) []byte {
   117  	if name == nil {
   118  		return nil
   119  	}
   120  	if len(name) == 0 {
   121  		return nil
   122  	}
   123  	pp := *params
   124  	for _, p := range pp {
   125  		if bytes.Equal(p.key, name) {
   126  			return p.val[0]
   127  		}
   128  	}
   129  	return nil
   130  }
   131  
   132  func (params *defaultParams) Set(name []byte, value []byte) {
   133  	if name == nil || value == nil {
   134  		return
   135  	}
   136  	if len(name) == 0 {
   137  		return
   138  	}
   139  	pp := *params
   140  	for _, p := range pp {
   141  		if bytes.Equal(p.key, name) {
   142  			p.val = [][]byte{value}
   143  			*params = pp
   144  			return
   145  		}
   146  	}
   147  	pp = append(pp, param{
   148  		key: name,
   149  		val: [][]byte{value},
   150  	})
   151  	*params = pp
   152  }
   153  
   154  func (params *defaultParams) Add(name []byte, value []byte) {
   155  	if name == nil || value == nil {
   156  		return
   157  	}
   158  	if len(name) == 0 {
   159  		return
   160  	}
   161  	pp := *params
   162  	for i, p := range pp {
   163  		if bytes.Equal(p.key, name) {
   164  			p.val = append(p.val, value)
   165  			pp[i] = p
   166  			*params = pp
   167  			return
   168  		}
   169  	}
   170  	pp = append(pp, param{
   171  		key: name,
   172  		val: [][]byte{value},
   173  	})
   174  	*params = pp
   175  }
   176  
   177  func (params *defaultParams) Values(name []byte) [][]byte {
   178  	if name == nil {
   179  		return nil
   180  	}
   181  	if len(name) == 0 {
   182  		return nil
   183  	}
   184  	pp := *params
   185  	for _, p := range pp {
   186  		if bytes.Equal(p.key, name) {
   187  			return p.val
   188  		}
   189  	}
   190  	return nil
   191  }
   192  
   193  func (params *defaultParams) Remove(name []byte) {
   194  	if name == nil {
   195  		return
   196  	}
   197  	if len(name) == 0 {
   198  		return
   199  	}
   200  	pp := *params
   201  	n := -1
   202  	for i, p := range pp {
   203  		if bytes.Equal(p.key, name) {
   204  			n = i
   205  			break
   206  		}
   207  	}
   208  	if n == -1 {
   209  		return
   210  	}
   211  	pp = append(pp[:n], pp[n+1:]...)
   212  	*params = pp
   213  }
   214  
   215  func (params *defaultParams) Len() int {
   216  	return len(*params)
   217  }
   218  
   219  func (params *defaultParams) Encode() []byte {
   220  	if params.Len() == 0 {
   221  		return nil
   222  	}
   223  	sort.Sort(params)
   224  	pp := *params
   225  	buf := bytebufferpool.Get()
   226  	for _, p := range pp {
   227  		if p.val.Len() == 1 {
   228  			_, _ = buf.WriteString(fmt.Sprintf("&%s=%s", bytex.ToString(p.key), bytex.ToString(p.val[0])))
   229  			continue
   230  		}
   231  		values := p.val
   232  		sort.Sort(values)
   233  		for _, value := range values {
   234  			_, _ = buf.WriteString(fmt.Sprintf("&%s=%s", bytex.ToString(p.key), bytex.ToString(value)))
   235  		}
   236  	}
   237  	p := buf.Bytes()[1:]
   238  	bytebufferpool.Put(buf)
   239  	return p
   240  }
   241  
   242  var (
   243  	stringType  = reflect.TypeOf("")
   244  	boolType    = reflect.TypeOf(false)
   245  	intType     = reflect.TypeOf(0)
   246  	int32Type   = reflect.TypeOf(int32(0))
   247  	int64Type   = reflect.TypeOf(int64(0))
   248  	float32Type = reflect.TypeOf(float32(0.0))
   249  	float64Type = reflect.TypeOf(float64(0))
   250  	uintType    = reflect.TypeOf(uint(0))
   251  	uint32Type  = reflect.TypeOf(uint32(0))
   252  	uint64Type  = reflect.TypeOf(uint64(0))
   253  	timeType    = reflect.TypeOf(time.Time{})
   254  )
   255  
   256  func DecodeParams(params Params, dst interface{}) (err error) {
   257  	if dst == nil {
   258  		err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("dst target is nil"))
   259  		return
   260  	}
   261  	if params.Len() == 0 {
   262  		return
   263  	}
   264  	rv := reflect.ValueOf(dst)
   265  	if rv.Kind() != reflect.Pointer {
   266  		err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("dst target is not pointer"))
   267  		return
   268  	}
   269  	rv = rv.Elem()
   270  	if rv.Kind() != reflect.Struct {
   271  		err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("dst target is not pointer struct"))
   272  		return
   273  	}
   274  	rt := rv.Type()
   275  	fieldNum := rt.NumField()
   276  	for i := 0; i < fieldNum; i++ {
   277  		ft := rt.Field(i)
   278  		if !ft.IsExported() {
   279  			continue
   280  		}
   281  		if ft.Anonymous {
   282  			if ft.Type.Kind() != reflect.Struct {
   283  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("dst target can not has not struct typed anonymous field"))
   284  				return
   285  			}
   286  			anonymous := rv.Field(i).Addr().Interface()
   287  			err = DecodeParams(params, anonymous)
   288  			if err != nil {
   289  				return
   290  			}
   291  		}
   292  		name := ft.Name
   293  		tag, hasTag := ft.Tag.Lookup("json")
   294  		if hasTag {
   295  			if tag == "-" {
   296  				continue
   297  			}
   298  			n := strings.Index(tag, ",")
   299  			if n > 0 {
   300  				tag = tag[0:n]
   301  			}
   302  			name = tag
   303  		}
   304  		pv := bytes.TrimSpace(params.Get(bytex.FromString(name)))
   305  		if len(pv) == 0 {
   306  			continue
   307  		}
   308  		fv := rv.Field(i)
   309  		switch ft.Type.Kind() {
   310  		case reflect.String:
   311  			s := bytex.ToString(pv)
   312  			if ft.Type == stringType {
   313  				fv.SetString(s)
   314  			} else {
   315  				fv.Set(reflect.ValueOf(s).Convert(ft.Type))
   316  			}
   317  			break
   318  		case reflect.Bool:
   319  			b, parseErr := strconv.ParseBool(bytex.ToString(pv))
   320  			if parseErr != nil {
   321  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not bool", name))
   322  				return
   323  			}
   324  			if ft.Type == boolType {
   325  				fv.SetBool(b)
   326  			} else {
   327  				fv.Set(reflect.ValueOf(b).Convert(ft.Type))
   328  			}
   329  			break
   330  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   331  			n, parseErr := strconv.ParseInt(bytex.ToString(pv), 10, 64)
   332  			if parseErr != nil {
   333  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not int", name))
   334  				return
   335  			}
   336  			if ft.Type == intType || ft.Type == int32Type || ft.Type == int64Type {
   337  				fv.SetInt(n)
   338  			} else {
   339  				fv.Set(reflect.ValueOf(n).Convert(ft.Type))
   340  			}
   341  			break
   342  		case reflect.Float32, reflect.Float64:
   343  			f, parseErr := strconv.ParseFloat(bytex.ToString(pv), 64)
   344  			if parseErr != nil {
   345  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not float", name))
   346  				return
   347  			}
   348  			if ft.Type == float32Type || ft.Type == float64Type {
   349  				fv.SetFloat(f)
   350  			} else {
   351  				fv.Set(reflect.ValueOf(f).Convert(ft.Type))
   352  			}
   353  			break
   354  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   355  			u, parseErr := strconv.ParseUint(bytex.ToString(pv), 10, 64)
   356  			if parseErr != nil {
   357  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not uint", name))
   358  				return
   359  			}
   360  			if ft.Type == uintType || ft.Type == uint32Type || ft.Type == uint64Type {
   361  				fv.SetUint(u)
   362  			} else {
   363  				fv.Set(reflect.ValueOf(u).Convert(ft.Type))
   364  			}
   365  			break
   366  		case reflect.Struct:
   367  			if ft.Type == timeType || timeType.ConvertibleTo(ft.Type) {
   368  				t, parseErr := time.Parse(time.RFC3339, bytex.ToString(pv))
   369  				if parseErr != nil {
   370  					err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not RFC3339 time", name))
   371  					return
   372  				}
   373  				if ft.Type == timeType {
   374  					fv.Set(reflect.ValueOf(t))
   375  				} else {
   376  					fv.Set(reflect.ValueOf(t).Convert(ft.Type))
   377  				}
   378  			} else {
   379  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("type of %s is not supported", name))
   380  				return
   381  			}
   382  			break
   383  		case reflect.Slice:
   384  			// pvv values or splits
   385  			pvv := params.Values(bytex.FromString(name))
   386  			if len(pvv) == 1 {
   387  				pvv = bytes.Split(pvv[0], []byte{','})
   388  			}
   389  			for pi, pvx := range pvv {
   390  				pvv[pi] = bytes.TrimSpace(pvx)
   391  			}
   392  			eft := ft.Type.Elem()
   393  			switch eft.Kind() {
   394  			case reflect.String:
   395  				ss := reflect.MakeSlice(ft.Type, 0, 1)
   396  				for _, pvx := range pvv {
   397  					s := bytex.ToString(pvx)
   398  					e := reflect.New(eft).Elem()
   399  					if e.Type() == stringType {
   400  						e.SetString(s)
   401  					} else {
   402  						e.Set(reflect.ValueOf(s).Convert(e.Type()))
   403  					}
   404  					ss = reflect.Append(ss, e)
   405  				}
   406  				fv.Set(ss)
   407  				break
   408  			case reflect.Bool:
   409  				bb := reflect.MakeSlice(ft.Type, 0, 1)
   410  				for _, pvx := range pvv {
   411  					b, parseErr := strconv.ParseBool(bytex.ToString(pvx))
   412  					if parseErr != nil {
   413  						err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not bool", name))
   414  						return
   415  					}
   416  					e := reflect.New(eft).Elem()
   417  					if e.Type() == boolType {
   418  						e.SetBool(b)
   419  					} else {
   420  						e.Set(reflect.ValueOf(b).Convert(e.Type()))
   421  					}
   422  					bb = reflect.Append(bb, e)
   423  				}
   424  				fv.Set(bb)
   425  				break
   426  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   427  				nn := reflect.MakeSlice(ft.Type, 0, 1)
   428  				for _, pvx := range pvv {
   429  					n, parseErr := strconv.ParseInt(bytex.ToString(pvx), 10, 64)
   430  					if parseErr != nil {
   431  						err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not int", name))
   432  						return
   433  					}
   434  					e := reflect.New(eft).Elem()
   435  					if e.Type() == intType || e.Type() == int32Type || e.Type() == int64Type {
   436  						e.SetInt(n)
   437  					} else {
   438  						e.Set(reflect.ValueOf(n).Convert(e.Type()))
   439  					}
   440  					nn = reflect.Append(nn, e)
   441  				}
   442  				fv.Set(nn)
   443  				break
   444  			case reflect.Float32, reflect.Float64:
   445  				ff := reflect.MakeSlice(ft.Type, 0, 1)
   446  				for _, pvx := range pvv {
   447  					f, parseErr := strconv.ParseFloat(bytex.ToString(pvx), 64)
   448  					if parseErr != nil {
   449  						err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not float", name))
   450  						return
   451  					}
   452  					e := reflect.New(eft).Elem()
   453  					if e.Type() == float32Type || e.Type() == float64Type {
   454  						e.SetFloat(f)
   455  					} else {
   456  						e.Set(reflect.ValueOf(f).Convert(e.Type()))
   457  					}
   458  					ff = reflect.Append(ff, e)
   459  				}
   460  				fv.Set(ff)
   461  				break
   462  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   463  				uu := reflect.MakeSlice(ft.Type, 0, 1)
   464  				for _, pvx := range pvv {
   465  					u, parseErr := strconv.ParseUint(bytex.ToString(pvx), 10, 64)
   466  					if parseErr != nil {
   467  						err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not uint", name))
   468  						return
   469  					}
   470  					e := reflect.New(eft).Elem()
   471  					if e.Type() == uintType || e.Type() == uint32Type || e.Type() == uint64Type {
   472  						e.SetUint(u)
   473  					} else {
   474  						e.Set(reflect.ValueOf(u).Convert(e.Type()))
   475  					}
   476  					uu = reflect.Append(uu, e)
   477  				}
   478  				fv.Set(uu)
   479  				break
   480  			case reflect.Struct:
   481  				if eft == timeType || timeType.ConvertibleTo(eft) {
   482  					tt := reflect.MakeSlice(ft.Type, 0, 1)
   483  					for _, pvx := range pvv {
   484  						t, parseErr := time.Parse(time.RFC3339, bytex.ToString(pvx))
   485  						if parseErr != nil {
   486  							err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("%s is not RFC3339 time", name))
   487  							return
   488  						}
   489  						e := reflect.New(eft).Elem()
   490  						if e.Type() == timeType {
   491  							e.Set(reflect.ValueOf(t))
   492  						} else {
   493  							e.Set(reflect.ValueOf(t).Convert(e.Type()))
   494  						}
   495  						tt = reflect.Append(tt, e)
   496  					}
   497  					fv.Set(tt)
   498  				} else {
   499  					err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("type of %s is not supported", name))
   500  					return
   501  				}
   502  				break
   503  			default:
   504  				err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("type of %s is not supported", name))
   505  				return
   506  			}
   507  			break
   508  		default:
   509  			err = errors.Warning("fns: decode param failed").WithCause(fmt.Errorf("type of %s is not supported", name))
   510  			return
   511  		}
   512  	}
   513  	return
   514  }