github.com/cloudwego/hertz@v0.9.3/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package decoder
    18  
    19  import (
    20  	"fmt"
    21  	"reflect"
    22  
    23  	"github.com/cloudwego/hertz/pkg/common/hlog"
    24  	"github.com/cloudwego/hertz/pkg/protocol"
    25  	"github.com/cloudwego/hertz/pkg/route/param"
    26  )
    27  
    28  type fileTypeDecoder struct {
    29  	fieldInfo
    30  	isRepeated bool
    31  }
    32  
    33  func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error {
    34  	fieldValue := GetFieldValue(reqValue, d.parentIndex)
    35  	field := fieldValue.Field(d.index)
    36  
    37  	if d.isRepeated {
    38  		return d.fileSliceDecode(req, params, reqValue)
    39  	}
    40  	var fileName string
    41  	// file_name > form > fieldName
    42  	for _, tagInfo := range d.tagInfos {
    43  		if tagInfo.Key == fileNameTag {
    44  			fileName = tagInfo.Value
    45  			break
    46  		}
    47  		if tagInfo.Key == formTag {
    48  			fileName = tagInfo.Value
    49  		}
    50  	}
    51  	if len(fileName) == 0 {
    52  		fileName = d.fieldName
    53  	}
    54  	file, err := req.FormFile(fileName)
    55  	if err != nil {
    56  		hlog.SystemLogger().Warnf("can not get file '%s' form request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName)
    57  		return nil
    58  	}
    59  	if field.Kind() == reflect.Ptr {
    60  		t := field.Type()
    61  		var ptrDepth int
    62  		for t.Kind() == reflect.Ptr {
    63  			t = t.Elem()
    64  			ptrDepth++
    65  		}
    66  		v := reflect.New(t).Elem()
    67  		v.Set(reflect.ValueOf(*file))
    68  		field.Set(ReferenceValue(v, ptrDepth))
    69  		return nil
    70  	}
    71  
    72  	// Non-pointer elems
    73  	field.Set(reflect.ValueOf(*file))
    74  
    75  	return nil
    76  }
    77  
    78  func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Params, reqValue reflect.Value) error {
    79  	fieldValue := GetFieldValue(reqValue, d.parentIndex)
    80  	field := fieldValue.Field(d.index)
    81  	// 如果没值,需要为其建一个值
    82  	if field.Kind() == reflect.Ptr {
    83  		if field.IsNil() {
    84  			nonNilVal, ptrDepth := GetNonNilReferenceValue(field)
    85  			field.Set(ReferenceValue(nonNilVal, ptrDepth))
    86  		}
    87  	}
    88  	var parentPtrDepth int
    89  	for field.Kind() == reflect.Ptr {
    90  		field = field.Elem()
    91  		parentPtrDepth++
    92  	}
    93  
    94  	var fileName string
    95  	// file_name > form > fieldName
    96  	for _, tagInfo := range d.tagInfos {
    97  		if tagInfo.Key == fileNameTag {
    98  			fileName = tagInfo.Value
    99  			break
   100  		}
   101  		if tagInfo.Key == formTag {
   102  			fileName = tagInfo.Value
   103  		}
   104  	}
   105  	if len(fileName) == 0 {
   106  		fileName = d.fieldName
   107  	}
   108  	multipartForm, err := req.MultipartForm()
   109  	if err != nil {
   110  		hlog.SystemLogger().Warnf("can not get MultipartForm from request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName)
   111  		return nil
   112  	}
   113  	files, exist := multipartForm.File[fileName]
   114  	if !exist {
   115  		hlog.SystemLogger().Warnf("the file '%s' is not existed in request, so skip '%s' field binding", fileName, d.fieldName)
   116  		return nil
   117  	}
   118  
   119  	if field.Kind() == reflect.Array {
   120  		if len(files) != field.Len() {
   121  			return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String())
   122  		}
   123  	} else {
   124  		// slice need creating enough capacity
   125  		field = reflect.MakeSlice(field.Type(), len(files), len(files))
   126  	}
   127  
   128  	// handle multiple pointer
   129  	var ptrDepth int
   130  	t := d.fieldType.Elem()
   131  	elemKind := t.Kind()
   132  	for elemKind == reflect.Ptr {
   133  		t = t.Elem()
   134  		elemKind = t.Kind()
   135  		ptrDepth++
   136  	}
   137  
   138  	for idx, file := range files {
   139  		v := reflect.New(t).Elem()
   140  		v.Set(reflect.ValueOf(*file))
   141  		field.Index(idx).Set(ReferenceValue(v, ptrDepth))
   142  	}
   143  	fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth))
   144  
   145  	return nil
   146  }
   147  
   148  func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) {
   149  	fieldType := field.Type
   150  	for field.Type.Kind() == reflect.Ptr {
   151  		fieldType = field.Type.Elem()
   152  	}
   153  	isRepeated := false
   154  	if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice {
   155  		isRepeated = true
   156  	}
   157  
   158  	return []fieldDecoder{&fileTypeDecoder{
   159  		fieldInfo: fieldInfo{
   160  			index:       index,
   161  			parentIndex: parentIdx,
   162  			fieldName:   field.Name,
   163  			tagInfos:    tagInfos,
   164  			fieldType:   fieldType,
   165  			config:      config,
   166  		},
   167  		isRepeated: isRepeated,
   168  	}}, nil
   169  }