github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/transport_http/transform/parameter_meta_test.go (about) 1 package transform 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "reflect" 9 "strconv" 10 "strings" 11 "testing" 12 13 "github.com/go-courier/ptr" 14 "github.com/stretchr/testify/assert" 15 16 "github.com/artisanhe/tools/courier/status_error" 17 "github.com/artisanhe/tools/reflectx" 18 "github.com/artisanhe/tools/timelib" 19 ) 20 21 type JSON string 22 23 func (p JSON) MarshalJSON() ([]byte, error) { 24 return []byte("json"), nil 25 } 26 27 func (p *JSON) UnmarshalJSON(data []byte) error { 28 *p = JSON("json") 29 return nil 30 } 31 32 type P string 33 34 func (p P) MarshalText() ([]byte, error) { 35 return []byte("parameter"), nil 36 } 37 38 func (p *P) UnmarshalText(data []byte) error { 39 *p = P(string(data)) 40 if string(*p) == "error" { 41 return fmt.Errorf("error") 42 } 43 *p = P("parameter") 44 return nil 45 } 46 47 type Uint64List []uint64 48 49 func (list Uint64List) MarshalJSON() ([]byte, error) { 50 if len(list) == 0 { 51 return []byte(`[]`), nil 52 } 53 strValues := make([]string, 0) 54 for _, v := range list { 55 strValues = append(strValues, fmt.Sprintf(`"%d"`, v)) 56 } 57 return []byte(`[` + strings.Join(strValues, ",") + `]`), nil 58 } 59 60 func (list *Uint64List) UnmarshalJSON(data []byte) (err error) { 61 strValues := make([]string, 0) 62 err = json.Unmarshal(data, &strValues) 63 if err != nil { 64 return err 65 } 66 finalList := Uint64List{} 67 for i, strValue := range strValues { 68 v, parseErr := strconv.ParseUint(strValue, 10, 64) 69 if parseErr != nil { 70 err = fmt.Errorf(`[%d] cannot unmarshal string into value of type uint64`, i) 71 return 72 } 73 finalList = append(finalList, v) 74 } 75 *list = finalList 76 return 77 } 78 79 func TestParameterMeta(t *testing.T) { 80 tt := assert.New(t) 81 82 type Data struct { 83 String string `json:"string" validate:"@string[3,]"` 84 Pointer *string `json:"pointer" validate:"@string[3,]"` 85 } 86 87 type SomeReq struct { 88 P P `name:"p" in:"query"` 89 Slice []string `name:"slice" in:"query"` 90 Array [5]string `name:"array" in:"query"` 91 Query string `name:"query" in:"query" validate:"@string[3,]"` 92 Uint64List Uint64List `name:"Uint64List" in:"query" default:"" validate:"@array[1,]:@uint64[3,]"` 93 QueryWithDefaults string `name:"queryWithDefaults" in:"query" default:"123123" validate:"@string[3,]"` 94 Pointer *string `name:"pointerWithDefaults" in:"path" default:"123123" validate:"@string[3,]"` 95 PointerWithDefaults *string `name:"pointer" in:"path" validate:"@string[3,]"` 96 Bytes []byte `name:"bytes" in:"query" style:"form,explode" ` 97 MustCreateTime timelib.MySQLTimestamp `in:"query"` 98 CreateTime timelib.MySQLTimestamp `name:"createTime" in:"query" default:""` 99 UpdateTime timelib.MySQLTimestamp `name:"updateTime,omitempty" in:"query"` 100 Data Data `in:"body"` 101 PtrData *Data `in:"body"` 102 DataSlice []string `in:"body"` 103 DataArray [5]string `in:"body"` 104 } 105 106 req := &SomeReq{} 107 tpe := reflectx.IndirectType(reflect.TypeOf(req)) 108 rv := reflect.Indirect(reflect.ValueOf(req)) 109 110 for i := 0; i < tpe.NumField(); i++ { 111 field := tpe.Field(i) 112 fieldValue := rv.Field(i) 113 114 tagIn, _, tagInFlags := GetTagIn(&field) 115 116 p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags) 117 118 switch p.Name { 119 case "Uint64List": 120 { 121 err := p.UnmarshalStringAndValidate() 122 tt.Equal(status_error.ErrorFields{ 123 { 124 In: "query", 125 Field: "Uint64List", 126 Msg: "切片元素个数不在[1, 0]范围内,当前个数:0", 127 }, 128 }, status_error.FromError(err).ErrorFields) 129 } 130 { 131 err := p.UnmarshalStringAndValidate("123", "123") 132 tt.NoError(err) 133 tt.Equal(Uint64List{123, 123}, p.Value.Interface()) 134 } 135 case "queryWithDefaults": 136 { 137 err := p.UnmarshalStringAndValidate("") 138 tt.NoError(err) 139 tt.Equal("123123", p.Value.Interface()) 140 } 141 case "pointerWithDefaults": 142 { 143 err := p.UnmarshalStringAndValidate("") 144 tt.NoError(err) 145 tt.Equal(ptr.String("123123"), p.Value.Interface()) 146 } 147 case "MustCreateTime": 148 { 149 err := p.UnmarshalStringAndValidate("") 150 tt.NotNil(err) 151 tt.Equal(status_error.ErrorFields{ 152 { 153 In: "query", 154 Field: "MustCreateTime", 155 Msg: ErrMsgForRequired, 156 }, 157 }, status_error.FromError(err).ErrorFields) 158 } 159 case "query": 160 { 161 err := p.UnmarshalStringAndValidate("") 162 tt.NotNil(err) 163 tt.Equal(status_error.ErrorFields{ 164 { 165 In: "query", 166 Field: "query", 167 Msg: ErrMsgForRequired, 168 }, 169 }, status_error.FromError(err).ErrorFields) 170 } 171 172 { 173 err := p.UnmarshalStringAndValidate("100") 174 tt.Nil(err) 175 tt.Equal("100", p.Value.Interface()) 176 } 177 case "p": 178 { 179 err := p.UnmarshalStringAndValidate("error") 180 tt.Error(err) 181 tt.Equal(status_error.ErrorFields{ 182 { 183 In: "query", 184 Field: "p", 185 Msg: "error", 186 }, 187 }, status_error.FromError(err).ErrorFields) 188 } 189 { 190 err := p.UnmarshalStringAndValidate("100") 191 tt.Nil(err) 192 tt.Equal(P("parameter"), p.Value.Interface()) 193 } 194 case "pointer": 195 { 196 err := p.UnmarshalStringAndValidate("") 197 tt.Nil(p.Value.Interface()) 198 tt.NotNil(err) 199 tt.Equal(status_error.ErrorFields{ 200 status_error.NewErrorField("path", "pointer", ErrMsgForRequired), 201 }, status_error.FromError(err).ErrorFields.Sort()) 202 } 203 { 204 err := p.UnmarshalStringAndValidate("10") 205 tt.Equal(ptr.String("10"), p.Value.Interface()) 206 tt.Error(err) 207 tt.Equal(status_error.ErrorFields{ 208 status_error.NewErrorField("path", "pointer", "字符串长度不在[3, 1024]范围内,当前长度:2"), 209 }, status_error.FromError(err).ErrorFields.Sort()) 210 } 211 { 212 err := p.UnmarshalStringAndValidate("100") 213 tt.Equal(ptr.String("100"), p.Value.Interface()) 214 tt.NoError(err) 215 } 216 case "bytes": 217 { 218 err := p.UnmarshalStringAndValidate("111") 219 tt.NoError(err) 220 tt.Equal([]byte("111"), p.Value.Interface()) 221 } 222 case "slice": 223 { 224 err := p.UnmarshalStringAndValidate("111", "222") 225 tt.NoError(err) 226 tt.Equal([]string{"111", "222"}, p.Value.Interface()) 227 } 228 case "array": 229 { 230 err := p.UnmarshalStringAndValidate("111", "222") 231 tt.NoError(err) 232 tt.Equal([5]string{"111", "222", "", "", ""}, p.Value.Interface()) 233 } 234 case "updateTime": 235 { 236 err := p.UnmarshalStringAndValidate() 237 tt.NoError(err) 238 } 239 case "createTime": 240 { 241 err := p.UnmarshalStringAndValidate("2017-10-10T00:00:00Z") 242 tt.Nil(err) 243 d, _ := timelib.ParseMySQLTimestampFromString("2017-10-10T00:00:00Z") 244 tt.Equal(d, p.Value.Interface()) 245 } 246 case "PtrData": 247 { 248 buf := bytes.NewBufferString(`{"string":"1"}`) 249 err := p.UnmarshalFromReader(buf) 250 tt.NotNil(err) 251 tt.Equal(status_error.ErrorFields{ 252 status_error.NewErrorField("body", "pointer", ErrMsgForRequired), 253 status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"), 254 }, status_error.FromError(err).ErrorFields.Sort()) 255 } 256 case "Data": 257 { 258 err := p.UnmarshalFromReader(nil) 259 tt.NotNil(err) 260 tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code) 261 } 262 263 { 264 file, _ := ioutil.TempFile("", "") 265 file.Close() 266 err := p.UnmarshalFromReader(file) 267 tt.NotNil(err) 268 tt.Equal(int64(status_error.ReadFailed), status_error.FromError(err).Code) 269 } 270 271 { 272 file, _ := ioutil.TempFile("", "") 273 err := p.UnmarshalFromReader(file) 274 tt.NotNil(err) 275 tt.Equal(int64(status_error.InvalidBodyStruct), status_error.FromError(err).Code) 276 file.Close() 277 } 278 { 279 buf := bytes.NewBufferString(`{"string":"1"}`) 280 err := p.UnmarshalFromReader(buf) 281 tt.NotNil(err) 282 tt.Equal(status_error.ErrorFields{ 283 status_error.NewErrorField("body", "pointer", ErrMsgForRequired), 284 status_error.NewErrorField("body", "string", "字符串长度不在[3, 1024]范围内,当前长度:1"), 285 }, status_error.FromError(err).ErrorFields.Sort()) 286 } 287 { 288 buf := bytes.NewBufferString(`{"string":"111", "pointer":1}`) 289 err := p.UnmarshalFromReader(buf) 290 tt.NotNil(err) 291 tt.Equal(status_error.ErrorFields{ 292 status_error.NewErrorField("body", "pointer", "json: cannot unmarshal number into Go struct field Data.pointer of type string"), 293 }, status_error.FromError(err).ErrorFields.Sort()) 294 } 295 { 296 buf := bytes.NewBufferString(`{"string":"111","pointer":"111"}`) 297 err := p.UnmarshalFromReader(buf) 298 tt.Nil(err) 299 tt.Equal(Data{ 300 String: "111", 301 Pointer: ptr.String("111"), 302 }, p.Value.Interface()) 303 } 304 case "DataSlice": 305 { 306 buf := bytes.NewBufferString(`["123","123"]`) 307 err := p.UnmarshalFromReader(buf) 308 tt.Nil(err) 309 tt.Equal([]string{ 310 "123", 311 "123", 312 }, p.Value.Interface()) 313 } 314 case "DataArray": 315 { 316 buf := bytes.NewBufferString(`["123","123"]`) 317 err := p.UnmarshalFromReader(buf) 318 tt.Nil(err) 319 tt.Equal([5]string{ 320 "123", 321 "123", 322 "", 323 "", 324 "", 325 }, p.Value.Interface()) 326 } 327 } 328 } 329 } 330 331 func TestParameterMeta_Marshal(t *testing.T) { 332 tt := assert.New(t) 333 334 type Data struct { 335 String string `json:"string" validate:"@string[3,]"` 336 } 337 338 type SomeReq struct { 339 JSON JSON `name:"json" in:"query"` 340 P P `name:"p" in:"query"` 341 Bytes []byte `name:"bytes" in:"query" style:"form,explode" ` 342 Slice []string `name:"slice" in:"query"` 343 Query string `name:"query" in:"query" validate:"@string[3,]"` 344 Pointer *string `name:"pointer" in:"path" validate:"@string[3,]"` 345 Pointer2 *string `name:"pointerIgnore" in:"path" validate:"@string[3,]"` 346 CreateTime timelib.MySQLTimestamp `name:"createTime" in:"query" default:""` 347 Data Data `in:"body"` 348 DataSlice []string `in:"body"` 349 FormData string `in:"formData"` 350 } 351 352 req := &SomeReq{ 353 P: "!", 354 Query: "query", 355 Pointer: ptr.String("pointer"), 356 Bytes: []byte("bytes"), 357 Slice: []string{ 358 "1", "2", 359 }, 360 Data: Data{ 361 String: "string", 362 }, 363 DataSlice: []string{ 364 "1", "2", 365 }, 366 FormData: "1", 367 CreateTime: timelib.MySQLTimestamp(timelib.Now()), 368 } 369 370 tpe := reflectx.IndirectType(reflect.TypeOf(req)) 371 rv := reflect.Indirect(reflect.ValueOf(req)) 372 373 for i := 0; i < tpe.NumField(); i++ { 374 field := tpe.Field(i) 375 fieldValue := rv.Field(i) 376 377 tagIn, _, tagInFlags := GetTagIn(&field) 378 379 p := NewParameterMeta(&field, fieldValue, tagIn, tagInFlags) 380 381 switch p.Name { 382 case "json": 383 dataList, _ := p.Marshal() 384 tt.Equal(BytesList("json"), dataList) 385 case "query": 386 dataList, _ := p.Marshal() 387 tt.Equal(BytesList("query"), dataList) 388 case "p": 389 dataList, _ := p.Marshal() 390 tt.Equal(BytesList("parameter"), dataList) 391 case "pointer": 392 dataList, _ := p.Marshal() 393 tt.Equal(BytesList("pointer"), dataList) 394 case "createTime": 395 dataList, _ := p.Marshal() 396 tt.Equal(BytesList(req.CreateTime.String()), dataList) 397 case "slice": 398 dataList, _ := p.Marshal() 399 tt.Equal(BytesList("1", "2"), dataList) 400 case "pointerIgnore": 401 dataList, _ := p.Marshal() 402 tt.Nil(dataList) 403 case "bytes": 404 dataList, _ := p.Marshal() 405 tt.Equal(BytesList("bytes"), dataList) 406 case "Data": 407 dataList, _ := p.Marshal() 408 tt.Equal(BytesList(`{"string":"string"}`), dataList) 409 case "DataSlice": 410 dataList, _ := p.Marshal() 411 tt.Equal(BytesList(`["1","2"]`), dataList) 412 case "FormData": 413 dataList, _ := p.Marshal() 414 tt.Equal(BytesList("1"), dataList) 415 } 416 } 417 }