github.com/bytedance/go-tagexpr@v2.7.5-0.20210114074101-de5b8743ad85+incompatible/binding/param_info.go (about) 1 package binding 2 3 import ( 4 jsonpkg "encoding/json" 5 "errors" 6 "net/http" 7 "net/url" 8 "reflect" 9 "strconv" 10 "strings" 11 12 "github.com/henrylee2cn/ameda" 13 "github.com/henrylee2cn/goutil" 14 "github.com/tidwall/gjson" 15 16 "github.com/bytedance/go-tagexpr" 17 ) 18 19 const ( 20 specialChar = "\x07" 21 ) 22 23 type paramInfo struct { 24 fieldSelector string 25 structField reflect.StructField 26 tagInfos []*tagInfo 27 omitIns map[in]bool 28 bindErrFactory func(failField, msg string) error 29 looseZeroMode bool 30 defaultVal []byte 31 } 32 33 func (p *paramInfo) name(_ in) string { 34 var name string 35 for _, info := range p.tagInfos { 36 if info.paramIn == json { 37 name = info.paramName 38 break 39 } 40 } 41 if name == "" { 42 return p.structField.Name 43 } 44 return name 45 } 46 47 func (p *paramInfo) getField(expr *tagexpr.TagExpr, initZero bool) (reflect.Value, error) { 48 fh, found := expr.Field(p.fieldSelector) 49 if found { 50 v := fh.Value(initZero) 51 if v.IsValid() { 52 return v, nil 53 } 54 } 55 return reflect.Value{}, nil 56 } 57 58 func (p *paramInfo) bindRawBody(info *tagInfo, expr *tagexpr.TagExpr, bodyBytes []byte) error { 59 if len(bodyBytes) == 0 { 60 if info.required { 61 return info.requiredError 62 } 63 return nil 64 } 65 v, err := p.getField(expr, true) 66 if err != nil || !v.IsValid() { 67 return err 68 } 69 v = goutil.DereferenceValue(v) 70 switch v.Kind() { 71 case reflect.Slice: 72 if v.Type().Elem().Kind() != reflect.Uint8 { 73 return info.typeError 74 } 75 v.Set(reflect.ValueOf(bodyBytes)) 76 return nil 77 case reflect.String: 78 v.Set(reflect.ValueOf(goutil.BytesToString(bodyBytes))) 79 return nil 80 default: 81 return info.typeError 82 } 83 } 84 85 func (p *paramInfo) bindPath(info *tagInfo, expr *tagexpr.TagExpr, pathParams PathParams) (bool, error) { 86 if pathParams == nil { 87 return false, nil 88 } 89 r, found := pathParams.Get(info.paramName) 90 if !found { 91 if info.required { 92 return false, info.requiredError 93 } 94 return false, nil 95 } 96 return true, p.bindStringSlice(info, expr, []string{r}) 97 } 98 99 func (p *paramInfo) bindQuery(info *tagInfo, expr *tagexpr.TagExpr, queryValues url.Values) (bool, error) { 100 return p.bindMapStrings(info, expr, queryValues) 101 } 102 103 func (p *paramInfo) bindHeader(info *tagInfo, expr *tagexpr.TagExpr, header http.Header) (bool, error) { 104 return p.bindMapStrings(info, expr, header) 105 } 106 107 func (p *paramInfo) bindCookie(info *tagInfo, expr *tagexpr.TagExpr, cookies []*http.Cookie) error { 108 var r []string 109 for _, c := range cookies { 110 if c.Name == info.paramName { 111 r = append(r, c.Value) 112 } 113 } 114 if len(r) == 0 { 115 if info.required { 116 return info.requiredError 117 } 118 return nil 119 } 120 return p.bindStringSlice(info, expr, r) 121 } 122 123 func (p *paramInfo) bindOrRequireBody(info *tagInfo, expr *tagexpr.TagExpr, bodyCodec codec, bodyString string, postForm map[string][]string) (bool, error) { 124 switch bodyCodec { 125 case bodyForm: 126 return p.bindMapStrings(info, expr, postForm) 127 case bodyJSON: 128 return p.checkRequireJSON(info, expr, bodyString, false) 129 case bodyProtobuf: 130 err := p.checkRequireProtobuf(info, expr, false) 131 return err == nil, err 132 default: 133 return false, info.contentTypeError 134 } 135 } 136 137 func (p *paramInfo) checkRequireProtobuf(info *tagInfo, expr *tagexpr.TagExpr, checkOpt bool) error { 138 if checkOpt && !info.required { 139 v, err := p.getField(expr, false) 140 if err != nil || !v.IsValid() { 141 return info.requiredError 142 } 143 } 144 return nil 145 } 146 147 func (p *paramInfo) checkRequireJSON(info *tagInfo, expr *tagexpr.TagExpr, bodyString string, checkOpt bool) (bool, error) { 148 var requiredError error 149 if checkOpt || info.required { // only return error if it's a required field 150 requiredError = info.requiredError 151 } 152 153 if !gjson.Get(bodyString, info.namePath).Exists() { 154 idx := strings.LastIndex(info.namePath, ".") 155 // There should be a superior but it is empty, no error is reported 156 if idx > 0 && !gjson.Get(bodyString, info.namePath[:idx]).Exists() { 157 return true, nil 158 } 159 return false, requiredError 160 } 161 v, err := p.getField(expr, false) 162 if err != nil || !v.IsValid() { 163 return false, requiredError 164 } 165 return true, nil 166 } 167 168 func (p *paramInfo) bindMapStrings(info *tagInfo, expr *tagexpr.TagExpr, values map[string][]string) (bool, error) { 169 r, ok := values[info.paramName] 170 if !ok || len(r) == 0 { 171 if info.required { 172 return false, info.requiredError 173 } 174 return false, nil 175 } 176 return true, p.bindStringSlice(info, expr, r) 177 } 178 179 // NOTE: len(a)>0 180 func (p *paramInfo) bindStringSlice(info *tagInfo, expr *tagexpr.TagExpr, a []string) error { 181 v, err := p.getField(expr, true) 182 if err != nil || !v.IsValid() { 183 return err 184 } 185 186 v = goutil.DereferenceValue(v) 187 switch v.Kind() { 188 case reflect.String: 189 v.SetString(a[0]) 190 return nil 191 192 case reflect.Bool: 193 var bol bool 194 bol, err = strconv.ParseBool(a[0]) 195 if err == nil || (a[0] == "" && p.looseZeroMode) { 196 v.SetBool(bol) 197 return nil 198 } 199 case reflect.Float32: 200 var f float64 201 f, err = strconv.ParseFloat(a[0], 32) 202 if err == nil || (a[0] == "" && p.looseZeroMode) { 203 v.SetFloat(f) 204 return nil 205 } 206 case reflect.Float64: 207 var f float64 208 f, err = strconv.ParseFloat(a[0], 64) 209 if err == nil || (a[0] == "" && p.looseZeroMode) { 210 v.SetFloat(f) 211 return nil 212 } 213 case reflect.Int64, reflect.Int: 214 var i int64 215 i, err = strconv.ParseInt(a[0], 10, 64) 216 if err == nil || (a[0] == "" && p.looseZeroMode) { 217 v.SetInt(i) 218 return nil 219 } 220 case reflect.Int32: 221 var i int64 222 i, err = strconv.ParseInt(a[0], 10, 32) 223 if err == nil || (a[0] == "" && p.looseZeroMode) { 224 v.SetInt(i) 225 return nil 226 } 227 case reflect.Int16: 228 var i int64 229 i, err = strconv.ParseInt(a[0], 10, 16) 230 if err == nil || (a[0] == "" && p.looseZeroMode) { 231 v.SetInt(i) 232 return nil 233 } 234 case reflect.Int8: 235 var i int64 236 i, err = strconv.ParseInt(a[0], 10, 8) 237 if err == nil || (a[0] == "" && p.looseZeroMode) { 238 v.SetInt(i) 239 return nil 240 } 241 case reflect.Uint64, reflect.Uint: 242 var u uint64 243 u, err = strconv.ParseUint(a[0], 10, 64) 244 if err == nil || (a[0] == "" && p.looseZeroMode) { 245 v.SetUint(u) 246 return nil 247 } 248 case reflect.Uint32: 249 var u uint64 250 u, err = strconv.ParseUint(a[0], 10, 32) 251 if err == nil || (a[0] == "" && p.looseZeroMode) { 252 v.SetUint(u) 253 return nil 254 } 255 case reflect.Uint16: 256 var u uint64 257 u, err = strconv.ParseUint(a[0], 10, 16) 258 if err == nil || (a[0] == "" && p.looseZeroMode) { 259 v.SetUint(u) 260 return nil 261 } 262 case reflect.Uint8: 263 var u uint64 264 u, err = strconv.ParseUint(a[0], 10, 8) 265 if err == nil || (a[0] == "" && p.looseZeroMode) { 266 v.SetUint(u) 267 return nil 268 } 269 case reflect.Slice: 270 vv, err := stringsToValue(v.Type().Elem(), a, p.looseZeroMode) 271 if err == nil { 272 v.Set(vv) 273 return nil 274 } 275 fallthrough 276 default: 277 fn := typeUnmarshalFuncs[v.Type()] 278 if fn != nil { 279 vv, err := fn(a[0], p.looseZeroMode) 280 if err == nil { 281 v.Set(vv) 282 return nil 283 } 284 } 285 } 286 return info.typeError 287 } 288 289 func (p *paramInfo) bindDefaultVal(expr *tagexpr.TagExpr, defaultValue []byte) (bool, error) { 290 if defaultValue == nil { 291 return false, nil 292 } 293 v, err := p.getField(expr, true) 294 if err != nil || !v.IsValid() { 295 return false, err 296 } 297 return true, jsonpkg.Unmarshal(defaultValue, v.Addr().Interface()) 298 } 299 300 // setDefaultVal preprocess the default tags and store the parsed value 301 func (p *paramInfo) setDefaultVal() error { 302 for _, info := range p.tagInfos { 303 if info.paramIn != default_val { 304 continue 305 } 306 307 defaultVal := info.paramName 308 st := ameda.DereferenceType(p.structField.Type) 309 switch st.Kind() { 310 case reflect.String: 311 p.defaultVal, _ = jsonpkg.Marshal(defaultVal) 312 continue 313 case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: 314 // escape single quote and double quote, replace single quote with double quote 315 defaultVal = strings.Replace(defaultVal, `"`, `\"`, -1) 316 defaultVal = strings.Replace(defaultVal, `\'`, specialChar, -1) 317 defaultVal = strings.Replace(defaultVal, `'`, `"`, -1) 318 defaultVal = strings.Replace(defaultVal, specialChar, `'`, -1) 319 } 320 p.defaultVal = ameda.UnsafeStringToBytes(defaultVal) 321 } 322 return nil 323 } 324 325 var errMismatch = errors.New("type mismatch") 326 327 func stringsToValue(t reflect.Type, a []string, emptyAsZero bool) (reflect.Value, error) { 328 var i interface{} 329 var err error 330 var ptrDepth int 331 elemKind := t.Kind() 332 for elemKind == reflect.Ptr { 333 t = t.Elem() 334 elemKind = t.Kind() 335 ptrDepth++ 336 } 337 switch elemKind { 338 case reflect.String: 339 i = a 340 case reflect.Bool: 341 i, err = goutil.StringsToBools(a, emptyAsZero) 342 case reflect.Float32: 343 i, err = goutil.StringsToFloat32s(a, emptyAsZero) 344 case reflect.Float64: 345 i, err = goutil.StringsToFloat64s(a, emptyAsZero) 346 case reflect.Int: 347 i, err = goutil.StringsToInts(a, emptyAsZero) 348 case reflect.Int64: 349 i, err = goutil.StringsToInt64s(a, emptyAsZero) 350 case reflect.Int32: 351 i, err = goutil.StringsToInt32s(a, emptyAsZero) 352 case reflect.Int16: 353 i, err = goutil.StringsToInt16s(a, emptyAsZero) 354 case reflect.Int8: 355 i, err = goutil.StringsToInt8s(a, emptyAsZero) 356 case reflect.Uint: 357 i, err = goutil.StringsToUints(a, emptyAsZero) 358 case reflect.Uint64: 359 i, err = goutil.StringsToUint64s(a, emptyAsZero) 360 case reflect.Uint32: 361 i, err = goutil.StringsToUint32s(a, emptyAsZero) 362 case reflect.Uint16: 363 i, err = goutil.StringsToUint16s(a, emptyAsZero) 364 case reflect.Uint8: 365 i, err = goutil.StringsToUint8s(a, emptyAsZero) 366 default: 367 fn := typeUnmarshalFuncs[t] 368 if fn == nil { 369 return reflect.Value{}, errMismatch 370 } 371 v := reflect.New(reflect.SliceOf(t)).Elem() 372 for _, s := range a { 373 vv, err := fn(s, emptyAsZero) 374 if err != nil { 375 return reflect.Value{}, errMismatch 376 } 377 v = reflect.Append(v, vv) 378 } 379 return goutil.ReferenceSlice(v, ptrDepth), nil 380 } 381 if err != nil { 382 return reflect.Value{}, errMismatch 383 } 384 return goutil.ReferenceSlice(reflect.ValueOf(i), ptrDepth), nil 385 }