github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/transport_http/transform/parameter_meta_test.go (about)

     1  package transform
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/go-courier/ptr"
    14  	"github.com/stretchr/testify/assert"
    15  
    16  	"github.com/artisanhe/tools/courier/status_error"
    17  	"github.com/artisanhe/tools/reflectx"
    18  	"github.com/artisanhe/tools/timelib"
    19  )
    20  
    21  type JSON string
    22  
    23  func (p JSON) MarshalJSON() ([]byte, error) {
    24  	return []byte("json"), nil
    25  }
    26  
    27  func (p *JSON) UnmarshalJSON(data []byte) error {
    28  	*p = JSON("json")
    29  	return nil
    30  }
    31  
    32  type P string
    33  
    34  func (p P) MarshalText() ([]byte, error) {
    35  	return []byte("parameter"), nil
    36  }
    37  
    38  func (p *P) UnmarshalText(data []byte) error {
    39  	*p = P(string(data))
    40  	if string(*p) == "error" {
    41  		return fmt.Errorf("error")
    42  	}
    43  	*p = P("parameter")
    44  	return nil
    45  }
    46  
    47  type Uint64List []uint64
    48  
    49  func (list Uint64List) MarshalJSON() ([]byte, error) {
    50  	if len(list) == 0 {
    51  		return []byte(`[]`), nil
    52  	}
    53  	strValues := make([]string, 0)
    54  	for _, v := range list {
    55  		strValues = append(strValues, fmt.Sprintf(`"%d"`, v))
    56  	}
    57  	return []byte(`[` + strings.Join(strValues, ",") + `]`), nil
    58  }
    59  
    60  func (list *Uint64List) UnmarshalJSON(data []byte) (err error) {
    61  	strValues := make([]string, 0)
    62  	err = json.Unmarshal(data, &strValues)
    63  	if err != nil {
    64  		return err
    65  	}
    66  	finalList := Uint64List{}
    67  	for i, strValue := range strValues {
    68  		v, parseErr := strconv.ParseUint(strValue, 10, 64)
    69  		if parseErr != nil {
    70  			err = fmt.Errorf(`[%d] cannot unmarshal string into value of type uint64`, i)
    71  			return
    72  		}
    73  		finalList = append(finalList, v)
    74  	}
    75  	*list = finalList
    76  	return
    77  }
    78  
    79  func TestParameterMeta(t *testing.T) {
    80  	tt := assert.New(t)
    81  
    82  	type Data struct {
    83  		String  string  `json:"string" validate:"@string[3,]"`
    84  		Pointer *string `json:"pointer" validate:"@string[3,]"`
    85  	}
    86  
    87  	type SomeReq struct {
    88  		P                   P                      `name:"p" in:"query"`
    89  		Slice               []string               `name:"slice" in:"query"`
    90  		Array               [5]string              `name:"array" in:"query"`
    91  		Query               string                 `name:"query" in:"query" validate:"@string[3,]"`
    92  		Uint64List          Uint64List             `name:"Uint64List" in:"query" default:"" validate:"@array[1,]:@uint64[3,]"`
    93  		QueryWithDefaults   string                 `name:"queryWithDefaults" in:"query" default:"123123" validate:"@string[3,]"`
    94  		Pointer             *string                `name:"pointerWithDefaults" in:"path" default:"123123" validate:"@string[3,]"`
    95  		PointerWithDefaults *string                `name:"pointer" in:"path" validate:"@string[3,]"`
    96  		Bytes               []byte                 `name:"bytes" in:"query" style:"form,explode" `
    97  		MustCreateTime      timelib.MySQLTimestamp `in:"query"`
    98  		CreateTime          timelib.MySQLTimestamp `name:"createTime" in:"query" default:""`
    99  		UpdateTime          timelib.MySQLTimestamp `name:"updateTime,omitempty" in:"query"`
   100  		Data                Data                   `in:"body"`
   101  		PtrData             *Data                  `in:"body"`
   102  		DataSlice           []string               `in:"body"`
   103  		DataArray           [5]string              `in:"body"`
   104  	}
   105  
   106  	req := &SomeReq{}
   107  	tpe := reflectx.IndirectType(reflect.TypeOf(req))
   108  	rv := reflect.Indirect(reflect.ValueOf(req))
   109  
   110  	for i := 0; i < tpe.NumField(); i++ {
   111  		field := tpe.Field(i)
   112  		fieldValue := rv.Field(i)
   113  
   114  		tagIn, _, tagInFlags := GetTagIn(&field)
   115  
   116  		p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags)
   117  
   118  		switch p.Name {
   119  		case "Uint64List":
   120  			{
   121  				err := p.UnmarshalStringAndValidate()
   122  				tt.Equal(status_error.ErrorFields{
   123  					{
   124  						In:    "query",
   125  						Field: "Uint64List",
   126  						Msg:   "切片元素个数不在[1, 0]范围内,当前个数:0",
   127  					},
   128  				}, status_error.FromError(err).ErrorFields)
   129  			}
   130  			{
   131  				err := p.UnmarshalStringAndValidate("123", "123")
   132  				tt.NoError(err)
   133  				tt.Equal(Uint64List{123, 123}, p.Value.Interface())
   134  			}
   135  		case "queryWithDefaults":
   136  			{
   137  				err := p.UnmarshalStringAndValidate("")
   138  				tt.NoError(err)
   139  				tt.Equal("123123", p.Value.Interface())
   140  			}
   141  		case "pointerWithDefaults":
   142  			{
   143  				err := p.UnmarshalStringAndValidate("")
   144  				tt.NoError(err)
   145  				tt.Equal(ptr.String("123123"), p.Value.Interface())
   146  			}
   147  		case "MustCreateTime":
   148  			{
   149  				err := p.UnmarshalStringAndValidate("")
   150  				tt.NotNil(err)
   151  				tt.Equal(status_error.ErrorFields{
   152  					{
   153  						In:    "query",
   154  						Field: "MustCreateTime",
   155  						Msg:   ErrMsgForRequired,
   156  					},
   157  				}, status_error.FromError(err).ErrorFields)
   158  			}
   159  		case "query":
   160  			{
   161  				err := p.UnmarshalStringAndValidate("")
   162  				tt.NotNil(err)
   163  				tt.Equal(status_error.ErrorFields{
   164  					{
   165  						In:    "query",
   166  						Field: "query",
   167  						Msg:   ErrMsgForRequired,
   168  					},
   169  				}, status_error.FromError(err).ErrorFields)
   170  			}
   171  
   172  			{
   173  				err := p.UnmarshalStringAndValidate("100")
   174  				tt.Nil(err)
   175  				tt.Equal("100", p.Value.Interface())
   176  			}
   177  		case "p":
   178  			{
   179  				err := p.UnmarshalStringAndValidate("error")
   180  				tt.Error(err)
   181  				tt.Equal(status_error.ErrorFields{
   182  					{
   183  						In:    "query",
   184  						Field: "p",
   185  						Msg:   "error",
   186  					},
   187  				}, status_error.FromError(err).ErrorFields)
   188  			}
   189  			{
   190  				err := p.UnmarshalStringAndValidate("100")
   191  				tt.Nil(err)
   192  				tt.Equal(P("parameter"), p.Value.Interface())
   193  			}
   194  		case "pointer":
   195  			{
   196  				err := p.UnmarshalStringAndValidate("")
   197  				tt.Nil(p.Value.Interface())
   198  				tt.NotNil(err)
   199  				tt.Equal(status_error.ErrorFields{
   200  					status_error.NewErrorField("path", "pointer", ErrMsgForRequired),
   201  				}, status_error.FromError(err).ErrorFields.Sort())
   202  			}
   203  			{
   204  				err := p.UnmarshalStringAndValidate("10")
   205  				tt.Equal(ptr.String("10"), p.Value.Interface())
   206  				tt.Error(err)
   207  				tt.Equal(status_error.ErrorFields{
   208  					status_error.NewErrorField("path", "pointer", "字符串长度不在[3, 1024]范围内,当前长度:2"),
   209  				}, status_error.FromError(err).ErrorFields.Sort())
   210  			}
   211  			{
   212  				err := p.UnmarshalStringAndValidate("100")
   213  				tt.Equal(ptr.String("100"), p.Value.Interface())
   214  				tt.NoError(err)
   215  			}
   216  		case "bytes":
   217  			{
   218  				err := p.UnmarshalStringAndValidate("111")
   219  				tt.NoError(err)
   220  				tt.Equal([]byte("111"), p.Value.Interface())
   221  			}
   222  		case "slice":
   223  			{
   224  				err := p.UnmarshalStringAndValidate("111", "222")
   225  				tt.NoError(err)
   226  				tt.Equal([]string{"111", "222"}, p.Value.Interface())
   227  			}
   228  		case "array":
   229  			{
   230  				err := p.UnmarshalStringAndValidate("111", "222")
   231  				tt.NoError(err)
   232  				tt.Equal([5]string{"111", "222", "", "", ""}, p.Value.Interface())
   233  			}
   234  		case "updateTime":
   235  			{
   236  				err := p.UnmarshalStringAndValidate()
   237  				tt.NoError(err)
   238  			}
   239  		case "createTime":
   240  			{
   241  				err := p.UnmarshalStringAndValidate("2017-10-10T00:00:00Z")
   242  				tt.Nil(err)
   243  				d, _ := timelib.ParseMySQLTimestampFromString("2017-10-10T00:00:00Z")
   244  				tt.Equal(d, p.Value.Interface())
   245  			}
   246  		case "PtrData":
   247  			{
   248  				buf := bytes.NewBufferString(`{"string":"1"}`)
   249  				err := p.UnmarshalFromReader(buf)
   250  				tt.NotNil(err)
   251  				tt.Equal(status_error.ErrorFields{
   252  					status_error.NewErrorField("body", "pointer", ErrMsgForRequired),
   253  					status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"),
   254  				}, status_error.FromError(err).ErrorFields.Sort())
   255  			}
   256  		case "Data":
   257  			{
   258  				err := p.UnmarshalFromReader(nil)
   259  				tt.NotNil(err)
   260  				tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code)
   261  			}
   262  
   263  			{
   264  				file, _ := ioutil.TempFile("", "")
   265  				file.Close()
   266  				err := p.UnmarshalFromReader(file)
   267  				tt.NotNil(err)
   268  				tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code)
   269  			}
   270  
   271  			{
   272  				file, _ := ioutil.TempFile("", "")
   273  				err := p.UnmarshalFromReader(file)
   274  				tt.NotNil(err)
   275  				tt.Equal(int64(status_error.InvalidBodyStruct), status_error.FromError(err).Code)
   276  				file.Close()
   277  			}
   278  			{
   279  				buf := bytes.NewBufferString(`{"string":"1"}`)
   280  				err := p.UnmarshalFromReader(buf)
   281  				tt.NotNil(err)
   282  				tt.Equal(status_error.ErrorFields{
   283  					status_error.NewErrorField("body", "pointer", ErrMsgForRequired),
   284  					status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"),
   285  				}, status_error.FromError(err).ErrorFields.Sort())
   286  			}
   287  			{
   288  				buf := bytes.NewBufferString(`{"string":"111", "pointer":1}`)
   289  				err := p.UnmarshalFromReader(buf)
   290  				tt.NotNil(err)
   291  				tt.Equal(status_error.ErrorFields{
   292  					status_error.NewErrorField("body", "pointer", "json: cannot unmarshal number into Go struct field Data.pointer of type string"),
   293  				}, status_error.FromError(err).ErrorFields.Sort())
   294  			}
   295  			{
   296  				buf := bytes.NewBufferString(`{"string":"111","pointer":"111"}`)
   297  				err := p.UnmarshalFromReader(buf)
   298  				tt.Nil(err)
   299  				tt.Equal(Data{
   300  					String:  "111",
   301  					Pointer: ptr.String("111"),
   302  				}, p.Value.Interface())
   303  			}
   304  		case "DataSlice":
   305  			{
   306  				buf := bytes.NewBufferString(`["123","123"]`)
   307  				err := p.UnmarshalFromReader(buf)
   308  				tt.Nil(err)
   309  				tt.Equal([]string{
   310  					"123",
   311  					"123",
   312  				}, p.Value.Interface())
   313  			}
   314  		case "DataArray":
   315  			{
   316  				buf := bytes.NewBufferString(`["123","123"]`)
   317  				err := p.UnmarshalFromReader(buf)
   318  				tt.Nil(err)
   319  				tt.Equal([5]string{
   320  					"123",
   321  					"123",
   322  					"",
   323  					"",
   324  					"",
   325  				}, p.Value.Interface())
   326  			}
   327  		}
   328  	}
   329  }
   330  
   331  func TestParameterMeta_Marshal(t *testing.T) {
   332  	tt := assert.New(t)
   333  
   334  	type Data struct {
   335  		String string `json:"string" validate:"@string[3,]"`
   336  	}
   337  
   338  	type SomeReq struct {
   339  		JSON       JSON                   `name:"json" in:"query"`
   340  		P          P                      `name:"p" in:"query"`
   341  		Bytes      []byte                 `name:"bytes" in:"query" style:"form,explode" `
   342  		Slice      []string               `name:"slice" in:"query"`
   343  		Query      string                 `name:"query" in:"query" validate:"@string[3,]"`
   344  		Pointer    *string                `name:"pointer" in:"path" validate:"@string[3,]"`
   345  		Pointer2   *string                `name:"pointerIgnore" in:"path" validate:"@string[3,]"`
   346  		CreateTime timelib.MySQLTimestamp `name:"createTime" in:"query" default:""`
   347  		Data       Data                   `in:"body"`
   348  		DataSlice  []string               `in:"body"`
   349  		FormData   string                 `in:"formData"`
   350  	}
   351  
   352  	req := &SomeReq{
   353  		P:       "!",
   354  		Query:   "query",
   355  		Pointer: ptr.String("pointer"),
   356  		Bytes:   []byte("bytes"),
   357  		Slice: []string{
   358  			"1", "2",
   359  		},
   360  		Data: Data{
   361  			String: "string",
   362  		},
   363  		DataSlice: []string{
   364  			"1", "2",
   365  		},
   366  		FormData:   "1",
   367  		CreateTime: timelib.MySQLTimestamp(timelib.Now()),
   368  	}
   369  
   370  	tpe := reflectx.IndirectType(reflect.TypeOf(req))
   371  	rv := reflect.Indirect(reflect.ValueOf(req))
   372  
   373  	for i := 0; i < tpe.NumField(); i++ {
   374  		field := tpe.Field(i)
   375  		fieldValue := rv.Field(i)
   376  
   377  		tagIn, _, tagInFlags := GetTagIn(&field)
   378  
   379  		p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags)
   380  
   381  		switch p.Name {
   382  		case "json":
   383  			dataList, _ := p.Marshal()
   384  			tt.Equal(BytesList("json"), dataList)
   385  		case "query":
   386  			dataList, _ := p.Marshal()
   387  			tt.Equal(BytesList("query"), dataList)
   388  		case "p":
   389  			dataList, _ := p.Marshal()
   390  			tt.Equal(BytesList("parameter"), dataList)
   391  		case "pointer":
   392  			dataList, _ := p.Marshal()
   393  			tt.Equal(BytesList("pointer"), dataList)
   394  		case "createTime":
   395  			dataList, _ := p.Marshal()
   396  			tt.Equal(BytesList(req.CreateTime.String()), dataList)
   397  		case "slice":
   398  			dataList, _ := p.Marshal()
   399  			tt.Equal(BytesList("1", "2"), dataList)
   400  		case "pointerIgnore":
   401  			dataList, _ := p.Marshal()
   402  			tt.Nil(dataList)
   403  		case "bytes":
   404  			dataList, _ := p.Marshal()
   405  			tt.Equal(BytesList("bytes"), dataList)
   406  		case "Data":
   407  			dataList, _ := p.Marshal()
   408  			tt.Equal(BytesList(`{"string":"string"}`), dataList)
   409  		case "DataSlice":
   410  			dataList, _ := p.Marshal()
   411  			tt.Equal(BytesList(`["1","2"]`), dataList)
   412  		case "FormData":
   413  			dataList, _ := p.Marshal()
   414  			tt.Equal(BytesList("1"), dataList)
   415  		}
   416  	}
   417  }