github.com/bytedance/go-tagexpr/v2@v2.9.8/binding/param_info.go (about) 1 package binding 2 3 import ( 4 jsonpkg "encoding/json" 5 "errors" 6 "fmt" 7 "mime/multipart" 8 "net/http" 9 "net/url" 10 "reflect" 11 "strconv" 12 "strings" 13 14 "github.com/andeya/ameda" 15 "github.com/bytedance/go-tagexpr/v2" 16 gjson "github.com/bytedance/go-tagexpr/v2/binding/tidwall_gjson" 17 ) 18 19 const ( 20 specialChar = "\x07" 21 ) 22 23 type paramInfo struct { 24 fieldSelector string 25 structField reflect.StructField 26 tagInfos []*tagInfo 27 omitIns map[in]bool 28 bindErrFactory func(failField, msg string) error 29 looseZeroMode bool 30 defaultVal []byte 31 } 32 33 func (p *paramInfo) name(_ in) string { 34 var name string 35 for _, info := range p.tagInfos { 36 if info.paramIn == json { 37 name = info.paramName 38 break 39 } 40 } 41 if name == "" { 42 return p.structField.Name 43 } 44 return name 45 } 46 47 func (p *paramInfo) getField(expr *tagexpr.TagExpr, initZero bool) (reflect.Value, error) { 48 fh, found := expr.Field(p.fieldSelector) 49 if found { 50 v := fh.Value(initZero) 51 if v.IsValid() { 52 return v, nil 53 } 54 } 55 return reflect.Value{}, nil 56 } 57 58 func (p *paramInfo) bindRawBody(info *tagInfo, expr *tagexpr.TagExpr, bodyBytes []byte) error { 59 if len(bodyBytes) == 0 { 60 if info.required { 61 return info.requiredError 62 } 63 return nil 64 } 65 v, err := p.getField(expr, true) 66 if err != nil || !v.IsValid() { 67 return err 68 } 69 v = ameda.DereferenceValue(v) 70 switch v.Kind() { 71 case reflect.Slice: 72 if v.Type().Elem().Kind() != reflect.Uint8 { 73 return info.typeError 74 } 75 v.Set(reflect.ValueOf(bodyBytes)) 76 return nil 77 case reflect.String: 78 v.Set(reflect.ValueOf(ameda.UnsafeBytesToString(bodyBytes))) 79 return nil 80 default: 81 return info.typeError 82 } 83 } 84 85 func (p *paramInfo) bindPath(info *tagInfo, expr *tagexpr.TagExpr, pathParams PathParams) (bool, error) { 86 if pathParams == nil { 87 return false, nil 88 } 89 r, found := pathParams.Get(info.paramName) 90 if !found { 91 if info.required { 92 return false, info.requiredError 93 } 94 return false, nil 95 } 96 return true, p.bindStringSlice(info, expr, []string{r}) 97 } 98 99 func (p *paramInfo) bindQuery(info *tagInfo, expr *tagexpr.TagExpr, queryValues url.Values) (bool, error) { 100 return p.bindMapStrings(info, expr, queryValues) 101 } 102 103 func (p *paramInfo) bindHeader(info *tagInfo, expr *tagexpr.TagExpr, header http.Header) (bool, error) { 104 return p.bindMapStrings(info, expr, header) 105 } 106 107 func (p *paramInfo) bindCookie(info *tagInfo, expr *tagexpr.TagExpr, cookies []*http.Cookie) (bool, error) { 108 var r []string 109 for _, c := range cookies { 110 if c.Name == info.paramName { 111 r = append(r, c.Value) 112 } 113 } 114 if len(r) == 0 { 115 if info.required { 116 return false, info.requiredError 117 } 118 return false, nil 119 } 120 return true, p.bindStringSlice(info, expr, r) 121 } 122 123 func (p *paramInfo) bindOrRequireBody( 124 info *tagInfo, expr *tagexpr.TagExpr, bodyCodec codec, bodyString string, 125 postForm map[string][]string, fileHeaders map[string][]*multipart.FileHeader, hasDefaultVal bool) (bool, error) { 126 switch bodyCodec { 127 case bodyForm: 128 found, err := p.bindMapStrings(info, expr, postForm) 129 if !found { 130 return p.bindFileHeaders(info, expr, fileHeaders) 131 } 132 return found, err 133 case bodyJSON: 134 return p.checkRequireJSON(info, expr, bodyString, hasDefaultVal) 135 case bodyProtobuf: 136 // It has been checked when binding, no need to check now 137 return true, nil 138 // err := p.checkRequireProtobuf(info, expr, false) 139 // return err == nil, err 140 default: 141 return false, info.contentTypeError 142 } 143 } 144 145 func (p *paramInfo) checkRequireProtobuf(info *tagInfo, expr *tagexpr.TagExpr, checkOpt bool) error { 146 if checkOpt && !info.required { 147 v, err := p.getField(expr, false) 148 if err != nil || !v.IsValid() { 149 return info.requiredError 150 } 151 } 152 return nil 153 } 154 155 func (p *paramInfo) checkRequireJSON(info *tagInfo, expr *tagexpr.TagExpr, bodyString string, hasDefaultVal bool) (bool, error) { 156 var requiredError error 157 if info.required { // only return error if it's a required field 158 requiredError = info.requiredError 159 } else if !hasDefaultVal { 160 return true, nil 161 } 162 if !gjson.Get(bodyString, info.namePath).Exists() { 163 idx := strings.LastIndex(info.namePath, ".") 164 // There should be a superior but it is empty, no error is reported 165 if idx > 0 && !gjson.Get(bodyString, info.namePath[:idx]).Exists() { 166 return true, nil 167 } 168 return false, requiredError 169 } 170 v, err := p.getField(expr, false) 171 if err != nil || !v.IsValid() { 172 return false, requiredError 173 } 174 return true, nil 175 } 176 177 var fileHeaderType = reflect.TypeOf(multipart.FileHeader{}) 178 179 func (p *paramInfo) bindFileHeaders(info *tagInfo, expr *tagexpr.TagExpr, fileHeaders map[string][]*multipart.FileHeader) (bool, error) { 180 r, ok := fileHeaders[info.paramName] 181 if !ok || len(r) == 0 { 182 if info.required { 183 return false, info.requiredError 184 } 185 return false, nil 186 } 187 v, err := p.getField(expr, true) 188 if err != nil || !v.IsValid() { 189 return true, err 190 } 191 v = ameda.DereferenceValue(v) 192 var elemType reflect.Type 193 isSlice := v.Kind() == reflect.Slice 194 if isSlice { 195 elemType = v.Type().Elem() 196 } else { 197 elemType = v.Type() 198 } 199 var ptrDepth int 200 for elemType.Kind() == reflect.Ptr { 201 elemType = elemType.Elem() 202 ptrDepth++ 203 } 204 if elemType != fileHeaderType { 205 return true, errors.New("parameter type is not (*)multipart.FileHeader struct or slice") 206 } 207 if len(r) == 0 || r[0] == nil { 208 return true, nil 209 } 210 if !isSlice { 211 v.Set(reflect.ValueOf(*r[0])) 212 return true, nil 213 } 214 for _, fileHeader := range r { 215 v.Set(reflect.Append(v, ameda.ReferenceValue(reflect.ValueOf(fileHeader), ptrDepth-1))) 216 } 217 return true, nil 218 } 219 220 func (p *paramInfo) bindMapStrings(info *tagInfo, expr *tagexpr.TagExpr, values map[string][]string) (bool, error) { 221 r, ok := values[info.paramName] 222 if !ok || len(r) == 0 { 223 if info.required { 224 return false, info.requiredError 225 } 226 return false, nil 227 } 228 return true, p.bindStringSlice(info, expr, r) 229 } 230 231 // NOTE: len(a)>0 232 func (p *paramInfo) bindStringSlice(info *tagInfo, expr *tagexpr.TagExpr, a []string) error { 233 v, err := p.getField(expr, true) 234 if err != nil || !v.IsValid() { 235 return err 236 } 237 238 v = ameda.DereferenceValue(v) 239 240 // we have customized unmarshal defined, we should use it firstly 241 if fn, exist := typeUnmarshalFuncs[v.Type()]; exist { 242 vv, err := fn(a[0], p.looseZeroMode) 243 if err == nil { 244 v.Set(vv) 245 return nil 246 } 247 return info.typeError 248 } 249 250 switch v.Kind() { 251 case reflect.String: 252 v.SetString(a[0]) 253 return nil 254 255 case reflect.Bool: 256 var bol bool 257 bol, err = strconv.ParseBool(a[0]) 258 if err == nil || (a[0] == "" && p.looseZeroMode) { 259 v.SetBool(bol) 260 return nil 261 } 262 case reflect.Float32: 263 var f float64 264 f, err = strconv.ParseFloat(a[0], 32) 265 if err == nil || (a[0] == "" && p.looseZeroMode) { 266 v.SetFloat(f) 267 return nil 268 } 269 case reflect.Float64: 270 var f float64 271 f, err = strconv.ParseFloat(a[0], 64) 272 if err == nil || (a[0] == "" && p.looseZeroMode) { 273 v.SetFloat(f) 274 return nil 275 } 276 case reflect.Int64, reflect.Int: 277 var i int64 278 i, err = strconv.ParseInt(a[0], 10, 64) 279 if err == nil || (a[0] == "" && p.looseZeroMode) { 280 v.SetInt(i) 281 return nil 282 } 283 case reflect.Int32: 284 var i int64 285 i, err = strconv.ParseInt(a[0], 10, 32) 286 if err == nil || (a[0] == "" && p.looseZeroMode) { 287 v.SetInt(i) 288 return nil 289 } 290 case reflect.Int16: 291 var i int64 292 i, err = strconv.ParseInt(a[0], 10, 16) 293 if err == nil || (a[0] == "" && p.looseZeroMode) { 294 v.SetInt(i) 295 return nil 296 } 297 case reflect.Int8: 298 var i int64 299 i, err = strconv.ParseInt(a[0], 10, 8) 300 if err == nil || (a[0] == "" && p.looseZeroMode) { 301 v.SetInt(i) 302 return nil 303 } 304 case reflect.Uint64, reflect.Uint: 305 var u uint64 306 u, err = strconv.ParseUint(a[0], 10, 64) 307 if err == nil || (a[0] == "" && p.looseZeroMode) { 308 v.SetUint(u) 309 return nil 310 } 311 case reflect.Uint32: 312 var u uint64 313 u, err = strconv.ParseUint(a[0], 10, 32) 314 if err == nil || (a[0] == "" && p.looseZeroMode) { 315 v.SetUint(u) 316 return nil 317 } 318 case reflect.Uint16: 319 var u uint64 320 u, err = strconv.ParseUint(a[0], 10, 16) 321 if err == nil || (a[0] == "" && p.looseZeroMode) { 322 v.SetUint(u) 323 return nil 324 } 325 case reflect.Uint8: 326 var u uint64 327 u, err = strconv.ParseUint(a[0], 10, 8) 328 if err == nil || (a[0] == "" && p.looseZeroMode) { 329 v.SetUint(u) 330 return nil 331 } 332 case reflect.Slice: 333 var ptrDepth int 334 t := v.Type().Elem() 335 elemKind := t.Kind() 336 for elemKind == reflect.Ptr { 337 t = t.Elem() 338 elemKind = t.Kind() 339 ptrDepth++ 340 } 341 val := reflect.New(v.Type()).Elem() 342 for _, s := range a { 343 var vv reflect.Value 344 vv, err = stringToValue(t, s, p.looseZeroMode) 345 if err != nil { 346 break 347 } 348 val = reflect.Append(val, ameda.ReferenceValue(vv, ptrDepth)) 349 } 350 if err == nil { 351 v.Set(val) 352 return nil 353 } 354 fallthrough 355 default: 356 // no customized unmarshal defined 357 err = unmarshal(ameda.UnsafeStringToBytes(a[0]), v.Addr().Interface()) 358 if err == nil { 359 return nil 360 } 361 } 362 return info.typeError 363 } 364 365 func (p *paramInfo) bindDefaultVal(expr *tagexpr.TagExpr, defaultValue []byte) (bool, error) { 366 if defaultValue == nil { 367 return false, nil 368 } 369 v, err := p.getField(expr, true) 370 if err != nil || !v.IsValid() { 371 return false, err 372 } 373 return true, jsonpkg.Unmarshal(defaultValue, v.Addr().Interface()) 374 } 375 376 // setDefaultVal preprocess the default tags and store the parsed value 377 func (p *paramInfo) setDefaultVal() error { 378 for _, info := range p.tagInfos { 379 if info.paramIn != default_val { 380 continue 381 } 382 383 defaultVal := info.paramName 384 st := ameda.DereferenceType(p.structField.Type) 385 switch st.Kind() { 386 case reflect.String: 387 p.defaultVal, _ = jsonpkg.Marshal(defaultVal) 388 continue 389 case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: 390 // escape single quote and double quote, replace single quote with double quote 391 defaultVal = strings.Replace(defaultVal, `"`, `\"`, -1) 392 defaultVal = strings.Replace(defaultVal, `\'`, specialChar, -1) 393 defaultVal = strings.Replace(defaultVal, `'`, `"`, -1) 394 defaultVal = strings.Replace(defaultVal, specialChar, `'`, -1) 395 } 396 p.defaultVal = ameda.UnsafeStringToBytes(defaultVal) 397 } 398 return nil 399 } 400 401 func stringToValue(elemType reflect.Type, s string, emptyAsZero bool) (v reflect.Value, err error) { 402 v = reflect.New(elemType).Elem() 403 404 // we have customized unmarshal defined, we should use it firstly 405 if fn, exist := typeUnmarshalFuncs[elemType]; exist { 406 vv, err := fn(s, emptyAsZero) 407 if err == nil { 408 v.Set(vv) 409 } 410 return v, err 411 } 412 413 switch elemType.Kind() { 414 case reflect.String: 415 v.SetString(s) 416 case reflect.Bool: 417 var i bool 418 i, err = ameda.StringToBool(s, emptyAsZero) 419 if err == nil { 420 v.SetBool(i) 421 } 422 case reflect.Float32, reflect.Float64: 423 var i float64 424 i, err = ameda.StringToFloat64(s, emptyAsZero) 425 if err == nil { 426 v.SetFloat(i) 427 } 428 case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: 429 var i int64 430 i, err = ameda.StringToInt64(s, emptyAsZero) 431 if err == nil { 432 v.SetInt(i) 433 } 434 case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: 435 var i uint64 436 i, err = ameda.StringToUint64(s, emptyAsZero) 437 if err == nil { 438 v.SetUint(i) 439 } 440 default: 441 // no customized unmarshal defined 442 err = unmarshal(ameda.UnsafeStringToBytes(s), v.Addr().Interface()) 443 } 444 if err != nil { 445 return reflect.Value{}, fmt.Errorf("type mismatch, error=%v", err) 446 } 447 return v, nil 448 }