github.com/cloudwego/hertz@v0.9.3/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package decoder 18 19 import ( 20 "fmt" 21 "reflect" 22 23 "github.com/cloudwego/hertz/pkg/common/hlog" 24 "github.com/cloudwego/hertz/pkg/protocol" 25 "github.com/cloudwego/hertz/pkg/route/param" 26 ) 27 28 type fileTypeDecoder struct { 29 fieldInfo 30 isRepeated bool 31 } 32 33 func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { 34 fieldValue := GetFieldValue(reqValue, d.parentIndex) 35 field := fieldValue.Field(d.index) 36 37 if d.isRepeated { 38 return d.fileSliceDecode(req, params, reqValue) 39 } 40 var fileName string 41 // file_name > form > fieldName 42 for _, tagInfo := range d.tagInfos { 43 if tagInfo.Key == fileNameTag { 44 fileName = tagInfo.Value 45 break 46 } 47 if tagInfo.Key == formTag { 48 fileName = tagInfo.Value 49 } 50 } 51 if len(fileName) == 0 { 52 fileName = d.fieldName 53 } 54 file, err := req.FormFile(fileName) 55 if err != nil { 56 hlog.SystemLogger().Warnf("can not get file '%s' form request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) 57 return nil 58 } 59 if field.Kind() == reflect.Ptr { 60 t := field.Type() 61 var ptrDepth int 62 for t.Kind() == reflect.Ptr { 63 t = t.Elem() 64 ptrDepth++ 65 } 66 v := reflect.New(t).Elem() 67 v.Set(reflect.ValueOf(*file)) 68 field.Set(ReferenceValue(v, ptrDepth)) 69 return nil 70 } 71 72 // Non-pointer elems 73 field.Set(reflect.ValueOf(*file)) 74 75 return nil 76 } 77 78 func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { 79 fieldValue := GetFieldValue(reqValue, d.parentIndex) 80 field := fieldValue.Field(d.index) 81 // 如果没值,需要为其建一个值 82 if field.Kind() == reflect.Ptr { 83 if field.IsNil() { 84 nonNilVal, ptrDepth := GetNonNilReferenceValue(field) 85 field.Set(ReferenceValue(nonNilVal, ptrDepth)) 86 } 87 } 88 var parentPtrDepth int 89 for field.Kind() == reflect.Ptr { 90 field = field.Elem() 91 parentPtrDepth++ 92 } 93 94 var fileName string 95 // file_name > form > fieldName 96 for _, tagInfo := range d.tagInfos { 97 if tagInfo.Key == fileNameTag { 98 fileName = tagInfo.Value 99 break 100 } 101 if tagInfo.Key == formTag { 102 fileName = tagInfo.Value 103 } 104 } 105 if len(fileName) == 0 { 106 fileName = d.fieldName 107 } 108 multipartForm, err := req.MultipartForm() 109 if err != nil { 110 hlog.SystemLogger().Warnf("can not get MultipartForm from request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) 111 return nil 112 } 113 files, exist := multipartForm.File[fileName] 114 if !exist { 115 hlog.SystemLogger().Warnf("the file '%s' is not existed in request, so skip '%s' field binding", fileName, d.fieldName) 116 return nil 117 } 118 119 if field.Kind() == reflect.Array { 120 if len(files) != field.Len() { 121 return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String()) 122 } 123 } else { 124 // slice need creating enough capacity 125 field = reflect.MakeSlice(field.Type(), len(files), len(files)) 126 } 127 128 // handle multiple pointer 129 var ptrDepth int 130 t := d.fieldType.Elem() 131 elemKind := t.Kind() 132 for elemKind == reflect.Ptr { 133 t = t.Elem() 134 elemKind = t.Kind() 135 ptrDepth++ 136 } 137 138 for idx, file := range files { 139 v := reflect.New(t).Elem() 140 v.Set(reflect.ValueOf(*file)) 141 field.Index(idx).Set(ReferenceValue(v, ptrDepth)) 142 } 143 fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) 144 145 return nil 146 } 147 148 func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { 149 fieldType := field.Type 150 for field.Type.Kind() == reflect.Ptr { 151 fieldType = field.Type.Elem() 152 } 153 isRepeated := false 154 if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice { 155 isRepeated = true 156 } 157 158 return []fieldDecoder{&fileTypeDecoder{ 159 fieldInfo: fieldInfo{ 160 index: index, 161 parentIndex: parentIdx, 162 fieldName: field.Name, 163 tagInfos: tagInfos, 164 fieldType: fieldType, 165 config: config, 166 }, 167 isRepeated: isRepeated, 168 }}, nil 169 }