github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/transport_http/transform/parameter_meta_test.go (about)

     1  package transform
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  
    15  	"github.com/johnnyeven/libtools/courier/status_error"
    16  	"github.com/johnnyeven/libtools/ptr"
    17  	"github.com/johnnyeven/libtools/reflectx"
    18  	"github.com/johnnyeven/libtools/timelib"
    19  )
    20  
    21  type JSON string
    22  
    23  func (p JSON) MarshalJSON() ([]byte, error) {
    24  	return []byte("json"), nil
    25  }
    26  
    27  func (p *JSON) UnmarshalJSON(data []byte) error {
    28  	*p = JSON("json")
    29  	return nil
    30  }
    31  
    32  type P string
    33  
    34  func (p P) MarshalText() ([]byte, error) {
    35  	return []byte("parameter"), nil
    36  }
    37  
    38  func (p *P) UnmarshalText(data []byte) error {
    39  	*p = P(string(data))
    40  	if string(*p) == "error" {
    41  		return fmt.Errorf("error")
    42  	}
    43  	*p = P("parameter")
    44  	return nil
    45  }
    46  
    47  type Uint64List []uint64
    48  
    49  func (list Uint64List) MarshalJSON() ([]byte, error) {
    50  	if len(list) == 0 {
    51  		return []byte(`[]`), nil
    52  	}
    53  	strValues := make([]string, 0)
    54  	for _, v := range list {
    55  		strValues = append(strValues, fmt.Sprintf(`"%d"`, v))
    56  	}
    57  	return []byte(`[` + strings.Join(strValues, ",") + `]`), nil
    58  }
    59  
    60  func (list *Uint64List) UnmarshalJSON(data []byte) (err error) {
    61  	strValues := make([]string, 0)
    62  	err = json.Unmarshal(data, &strValues)
    63  	if err != nil {
    64  		return err
    65  	}
    66  	finalList := Uint64List{}
    67  	for i, strValue := range strValues {
    68  		v, parseErr := strconv.ParseUint(strValue, 10, 64)
    69  		if parseErr != nil {
    70  			err = fmt.Errorf(`[%d] cannot unmarshal string into value of type uint64`, i)
    71  			return
    72  		}
    73  		finalList = append(finalList, v)
    74  	}
    75  	*list = finalList
    76  	return
    77  }
    78  
    79  func TestParameterMeta(t *testing.T) {
    80  	tt := assert.New(t)
    81  
    82  	type Data struct {
    83  		String  string  `json:"string" validate:"@string[3,]"`
    84  		Pointer *string `json:"pointer" validate:"@string[3,]"`
    85  	}
    86  
    87  	type SomeReq struct {
    88  		P                   P                      `name:"p" in:"query"`
    89  		Slice               []string               `name:"slice" in:"query"`
    90  		Array               [5]string              `name:"array" in:"query"`
    91  		Query               string                 `name:"query" in:"query" validate:"@string[3,]"`
    92  		Uint64List          Uint64List             `name:"Uint64List" in:"query" default:"" validate:"@array[1,]:@uint64[3,]"`
    93  		QueryWithDefaults   string                 `name:"queryWithDefaults" in:"query" default:"123123" validate:"@string[3,]"`
    94  		Pointer             *string                `name:"pointerWithDefaults" in:"path" default:"123123" validate:"@string[3,]"`
    95  		PointerWithDefaults *string                `name:"pointer" in:"path" validate:"@string[3,]"`
    96  		Bytes               []byte                 `name:"bytes" in:"query" style:"form,explode" `
    97  		CreateTime          timelib.MySQLTimestamp `name:"createTime" in:"query" default:""`
    98  		UpdateTime          timelib.MySQLTimestamp `name:"updateTime,omitempty" in:"query"`
    99  		Data                Data                   `in:"body"`
   100  		PtrData             *Data                  `in:"body"`
   101  		DataSlice           []string               `in:"body"`
   102  		DataArray           [5]string              `in:"body"`
   103  	}
   104  
   105  	req := &SomeReq{}
   106  	tpe := reflectx.IndirectType(reflect.TypeOf(req))
   107  	rv := reflect.Indirect(reflect.ValueOf(req))
   108  
   109  	for i := 0; i < tpe.NumField(); i++ {
   110  		field := tpe.Field(i)
   111  		fieldValue := rv.Field(i)
   112  
   113  		tagIn, _, tagInFlags := GetTagIn(&field)
   114  
   115  		p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags)
   116  
   117  		switch p.Name {
   118  		case "Uint64List":
   119  			{
   120  				err := p.UnmarshalStringAndValidate()
   121  				tt.Equal(status_error.ErrorFields{
   122  					{
   123  						In:    "query",
   124  						Field: "Uint64List",
   125  						Msg:   "切片元素个数不在[1, 0]范围内,当前个数:0",
   126  					},
   127  				}, status_error.FromError(err).ErrorFields)
   128  			}
   129  			{
   130  				err := p.UnmarshalStringAndValidate("123", "123")
   131  				tt.NoError(err)
   132  				tt.Equal(Uint64List{123, 123}, p.Value.Interface())
   133  			}
   134  		case "queryWithDefaults":
   135  			{
   136  				err := p.UnmarshalStringAndValidate("")
   137  				tt.NoError(err)
   138  				tt.Equal("123123", p.Value.Interface())
   139  			}
   140  		case "pointerWithDefaults":
   141  			{
   142  				err := p.UnmarshalStringAndValidate("")
   143  				tt.NoError(err)
   144  				tt.Equal(ptr.String("123123"), p.Value.Interface())
   145  			}
   146  		case "query":
   147  			{
   148  				err := p.UnmarshalStringAndValidate("")
   149  				tt.NotNil(err)
   150  				tt.Equal(status_error.ErrorFields{
   151  					{
   152  						In:    "query",
   153  						Field: "query",
   154  						Msg:   ErrMsgForRequired,
   155  					},
   156  				}, status_error.FromError(err).ErrorFields)
   157  			}
   158  
   159  			{
   160  				err := p.UnmarshalStringAndValidate("100")
   161  				tt.Nil(err)
   162  				tt.Equal("100", p.Value.Interface())
   163  			}
   164  		case "p":
   165  			{
   166  				err := p.UnmarshalStringAndValidate("error")
   167  				tt.Error(err)
   168  				tt.Equal(status_error.ErrorFields{
   169  					{
   170  						In:    "query",
   171  						Field: "p",
   172  						Msg:   "error",
   173  					},
   174  				}, status_error.FromError(err).ErrorFields)
   175  			}
   176  			{
   177  				err := p.UnmarshalStringAndValidate("100")
   178  				tt.Nil(err)
   179  				tt.Equal(P("parameter"), p.Value.Interface())
   180  			}
   181  		case "pointer":
   182  			{
   183  				err := p.UnmarshalStringAndValidate("")
   184  				tt.Nil(p.Value.Interface())
   185  				tt.NotNil(err)
   186  				tt.Equal(status_error.ErrorFields{
   187  					status_error.NewErrorField("path", "pointer", ErrMsgForRequired),
   188  				}, status_error.FromError(err).ErrorFields.Sort())
   189  			}
   190  			{
   191  				err := p.UnmarshalStringAndValidate("10")
   192  				tt.Equal(ptr.String("10"), p.Value.Interface())
   193  				tt.Error(err)
   194  				tt.Equal(status_error.ErrorFields{
   195  					status_error.NewErrorField("path", "pointer", "字符串长度不在[3, 1024]范围内,当前长度:2"),
   196  				}, status_error.FromError(err).ErrorFields.Sort())
   197  			}
   198  			{
   199  				err := p.UnmarshalStringAndValidate("100")
   200  				tt.Equal(ptr.String("100"), p.Value.Interface())
   201  				tt.NoError(err)
   202  			}
   203  		case "bytes":
   204  			{
   205  				err := p.UnmarshalStringAndValidate("111")
   206  				tt.NoError(err)
   207  				tt.Equal([]byte("111"), p.Value.Interface())
   208  			}
   209  		case "slice":
   210  			{
   211  				err := p.UnmarshalStringAndValidate("111", "222")
   212  				tt.NoError(err)
   213  				tt.Equal([]string{"111", "222"}, p.Value.Interface())
   214  			}
   215  		case "array":
   216  			{
   217  				err := p.UnmarshalStringAndValidate("111", "222")
   218  				tt.NoError(err)
   219  				tt.Equal([5]string{"111", "222", "", "", ""}, p.Value.Interface())
   220  			}
   221  		case "updateTime":
   222  			{
   223  				err := p.UnmarshalStringAndValidate()
   224  				tt.NoError(err)
   225  			}
   226  		case "createTime":
   227  			{
   228  				err := p.UnmarshalStringAndValidate("2017-10-10T00:00:00Z")
   229  				tt.Nil(err)
   230  				d, _ := timelib.ParseMySQLTimestampFromString("2017-10-10T00:00:00Z")
   231  				tt.Equal(d, p.Value.Interface())
   232  			}
   233  		case "PtrData":
   234  			{
   235  				buf := bytes.NewBufferString(`{"string":"1"}`)
   236  				err := p.UnmarshalFromReader(buf)
   237  				tt.NotNil(err)
   238  				tt.Equal(status_error.ErrorFields{
   239  					status_error.NewErrorField("body", "pointer", ErrMsgForRequired),
   240  					status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"),
   241  				}, status_error.FromError(err).ErrorFields.Sort())
   242  			}
   243  		case "Data":
   244  			{
   245  				err := p.UnmarshalFromReader(nil)
   246  				tt.NotNil(err)
   247  				tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code)
   248  			}
   249  
   250  			{
   251  				file, _ := ioutil.TempFile("", "")
   252  				file.Close()
   253  				err := p.UnmarshalFromReader(file)
   254  				tt.NotNil(err)
   255  				tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code)
   256  			}
   257  
   258  			{
   259  				file, _ := ioutil.TempFile("", "")
   260  				err := p.UnmarshalFromReader(file)
   261  				tt.NotNil(err)
   262  				tt.Equal(int64(status_error.InvalidBodyStruct), status_error.FromError(err).Code)
   263  				file.Close()
   264  			}
   265  			{
   266  				buf := bytes.NewBufferString(`{"string":"1"}`)
   267  				err := p.UnmarshalFromReader(buf)
   268  				tt.NotNil(err)
   269  				tt.Equal(status_error.ErrorFields{
   270  					status_error.NewErrorField("body", "pointer", ErrMsgForRequired),
   271  					status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"),
   272  				}, status_error.FromError(err).ErrorFields.Sort())
   273  			}
   274  			{
   275  				buf := bytes.NewBufferString(`{"string":"111", "pointer":1}`)
   276  				err := p.UnmarshalFromReader(buf)
   277  				tt.NotNil(err)
   278  				tt.Equal(status_error.ErrorFields{
   279  					status_error.NewErrorField("body", "pointer", "json: cannot unmarshal number into Go struct field Data.pointer of type string"),
   280  				}, status_error.FromError(err).ErrorFields.Sort())
   281  			}
   282  			{
   283  				buf := bytes.NewBufferString(`{"string":"111","pointer":"111"}`)
   284  				err := p.UnmarshalFromReader(buf)
   285  				tt.Nil(err)
   286  				tt.Equal(Data{
   287  					String:  "111",
   288  					Pointer: ptr.String("111"),
   289  				}, p.Value.Interface())
   290  			}
   291  		case "DataSlice":
   292  			{
   293  				buf := bytes.NewBufferString(`["123","123"]`)
   294  				err := p.UnmarshalFromReader(buf)
   295  				tt.Nil(err)
   296  				tt.Equal([]string{
   297  					"123",
   298  					"123",
   299  				}, p.Value.Interface())
   300  			}
   301  		case "DataArray":
   302  			{
   303  				buf := bytes.NewBufferString(`["123","123"]`)
   304  				err := p.UnmarshalFromReader(buf)
   305  				tt.Nil(err)
   306  				tt.Equal([5]string{
   307  					"123",
   308  					"123",
   309  					"",
   310  					"",
   311  					"",
   312  				}, p.Value.Interface())
   313  			}
   314  		}
   315  	}
   316  }
   317  
   318  func TestParameterMeta_Marshal(t *testing.T) {
   319  	tt := assert.New(t)
   320  
   321  	type Data struct {
   322  		String string `json:"string" validate:"@string[3,]"`
   323  	}
   324  
   325  	type SomeReq struct {
   326  		JSON       JSON                   `name:"json" in:"query"`
   327  		P          P                      `name:"p" in:"query"`
   328  		Bytes      []byte                 `name:"bytes" in:"query" style:"form,explode" `
   329  		Slice      []string               `name:"slice" in:"query"`
   330  		Query      string                 `name:"query" in:"query" validate:"@string[3,]"`
   331  		Pointer    *string                `name:"pointer" in:"path" validate:"@string[3,]"`
   332  		Pointer2   *string                `name:"pointerIgnore" in:"path" validate:"@string[3,]"`
   333  		CreateTime timelib.MySQLTimestamp `name:"createTime" in:"query" default:""`
   334  		Data       Data                   `in:"body"`
   335  		DataSlice  []string               `in:"body"`
   336  		FormData   string                 `in:"formData"`
   337  	}
   338  
   339  	req := &SomeReq{
   340  		P:       "!",
   341  		Query:   "query",
   342  		Pointer: ptr.String("pointer"),
   343  		Bytes:   []byte("bytes"),
   344  		Slice: []string{
   345  			"1", "2",
   346  		},
   347  		Data: Data{
   348  			String: "string",
   349  		},
   350  		DataSlice: []string{
   351  			"1", "2",
   352  		},
   353  		FormData:   "1",
   354  		CreateTime: timelib.MySQLTimestamp(timelib.Now()),
   355  	}
   356  
   357  	tpe := reflectx.IndirectType(reflect.TypeOf(req))
   358  	rv := reflect.Indirect(reflect.ValueOf(req))
   359  
   360  	for i := 0; i < tpe.NumField(); i++ {
   361  		field := tpe.Field(i)
   362  		fieldValue := rv.Field(i)
   363  
   364  		tagIn, _, tagInFlags := GetTagIn(&field)
   365  
   366  		p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags)
   367  
   368  		switch p.Name {
   369  		case "json":
   370  			dataList, _ := p.Marshal()
   371  			tt.Equal(BytesList("json"), dataList)
   372  		case "query":
   373  			dataList, _ := p.Marshal()
   374  			tt.Equal(BytesList("query"), dataList)
   375  		case "p":
   376  			dataList, _ := p.Marshal()
   377  			tt.Equal(BytesList("parameter"), dataList)
   378  		case "pointer":
   379  			dataList, _ := p.Marshal()
   380  			tt.Equal(BytesList("pointer"), dataList)
   381  		case "createTime":
   382  			dataList, _ := p.Marshal()
   383  			tt.Equal(BytesList(req.CreateTime.String()), dataList)
   384  		case "slice":
   385  			dataList, _ := p.Marshal()
   386  			tt.Equal(BytesList("1", "2"), dataList)
   387  		case "pointerIgnore":
   388  			dataList, _ := p.Marshal()
   389  			tt.Nil(dataList)
   390  		case "bytes":
   391  			dataList, _ := p.Marshal()
   392  			tt.Equal(BytesList("bytes"), dataList)
   393  		case "Data":
   394  			dataList, _ := p.Marshal()
   395  			tt.Equal(BytesList(`{"string":"string"}`), dataList)
   396  		case "DataSlice":
   397  			dataList, _ := p.Marshal()
   398  			tt.Equal(BytesList(`["1","2"]`), dataList)
   399  		case "FormData":
   400  			dataList, _ := p.Marshal()
   401  			tt.Equal(BytesList("1"), dataList)
   402  		}
   403  	}
   404  }