gitee.com/quant1x/pkg@v0.2.8/gocsv/unmarshaller.go (about)

     1  package gocsv
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"reflect"
     7  )
     8  
// Unmarshaller is a CSV to struct unmarshaller.
//
// It is created by NewUnmarshaller, which reads the header row from the
// csv.Reader and validates it against the value passed as out. Each call to
// Read or ReadUnmatched then converts one further CSV row.
type Unmarshaller struct {
	reader                 *csv.Reader  // CSV source; Read and ReadUnmatched each pull one row per call
	Headers                []string     // header names the current column mapping was validated against
	fieldInfoMap           []*fieldInfo // indexed by CSV column position; nil where no struct field matched
	MismatchedHeaders      []string     // result of mismatchHeaderFields for the struct fields and Headers
	MismatchedStructFields []string     // result of mismatchStructFields for the struct fields and Headers
	outType                reflect.Type // reflect.TypeOf of the out value given to NewUnmarshaller
	out                    interface{}  // the validated out value; reused by RenormalizeHeaders
}
    19  
    20  // NewUnmarshaller creates an unmarshaller from a csv.Reader and a struct.
    21  func NewUnmarshaller(reader *csv.Reader, out interface{}) (*Unmarshaller, error) {
    22  	headers, err := reader.Read()
    23  	if err != nil {
    24  		return nil, err
    25  	}
    26  	headers = normalizeHeaders(headers)
    27  
    28  	um := &Unmarshaller{reader: reader, outType: reflect.TypeOf(out)}
    29  	err = validate(um, out, headers)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	return um, nil
    34  }
    35  
    36  // Read returns an interface{} whose runtime type is the same as the struct that
    37  // was used to create the Unmarshaller.
    38  func (um *Unmarshaller) Read() (interface{}, error) {
    39  	row, err := um.reader.Read()
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	return um.unmarshalRow(row, nil)
    44  }
    45  
    46  // ReadUnmatched is same as Read(), but returns a map of the columns that didn't match a field in the struct
    47  func (um *Unmarshaller) ReadUnmatched() (interface{}, map[string]string, error) {
    48  	row, err := um.reader.Read()
    49  	if err != nil {
    50  		return nil, nil, err
    51  	}
    52  	unmatched := make(map[string]string)
    53  	value, err := um.unmarshalRow(row, unmatched)
    54  	return value, unmatched, err
    55  }
    56  
    57  // validate ensures that a struct was used to create the Unmarshaller, and validates
    58  // CSV headers against the CSV tags in the struct.
    59  func validate(um *Unmarshaller, s interface{}, headers []string) error {
    60  	concreteType := reflect.TypeOf(s)
    61  	if concreteType.Kind() == reflect.Ptr {
    62  		concreteType = concreteType.Elem()
    63  	}
    64  	if err := ensureOutInnerType(concreteType); err != nil {
    65  		return err
    66  	}
    67  	structInfo := getStructInfo(concreteType) // Get struct info to get CSV annotations.
    68  	if len(structInfo.Fields) == 0 {
    69  		return ErrNoStructTags
    70  	}
    71  	csvHeadersLabels := make([]*fieldInfo, len(headers)) // Used to store the corresponding header <-> position in CSV
    72  	headerCount := map[string]int{}
    73  	for i, csvColumnHeader := range headers {
    74  		curHeaderCount := headerCount[csvColumnHeader]
    75  		if fieldInfo := getCSVFieldPosition(csvColumnHeader, structInfo, curHeaderCount); fieldInfo != nil {
    76  			csvHeadersLabels[i] = fieldInfo
    77  			if ShouldAlignDuplicateHeadersWithStructFieldOrder {
    78  				curHeaderCount++
    79  				headerCount[csvColumnHeader] = curHeaderCount
    80  			}
    81  		}
    82  	}
    83  
    84  	if FailIfDoubleHeaderNames {
    85  		if err := maybeDoubleHeaderNames(headers); err != nil {
    86  			return err
    87  		}
    88  	}
    89  
    90  	um.Headers = headers
    91  	um.fieldInfoMap = csvHeadersLabels
    92  	um.MismatchedHeaders = mismatchHeaderFields(structInfo.Fields, headers)
    93  	um.MismatchedStructFields = mismatchStructFields(structInfo.Fields, headers)
    94  	um.out = s
    95  	return nil
    96  }
    97  
    98  // unmarshalRow converts a CSV row to a struct, based on CSV struct tags.
    99  // If unmatched is non nil, it is populated with any columns that don't map to a struct field
   100  func (um *Unmarshaller) unmarshalRow(row []string, unmatched map[string]string) (interface{}, error) {
   101  	isPointer := false
   102  	concreteOutType := um.outType
   103  	if um.outType.Kind() == reflect.Ptr {
   104  		isPointer = true
   105  		concreteOutType = concreteOutType.Elem()
   106  	}
   107  	outValue := createNewOutInner(isPointer, concreteOutType)
   108  	for j, csvColumnContent := range row {
   109  		if j < len(um.fieldInfoMap) && um.fieldInfoMap[j] != nil {
   110  			fieldInfo := um.fieldInfoMap[j]
   111  			if err := setInnerField(&outValue, isPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct
   112  				return nil, fmt.Errorf("cannot assign field at %v to %s through index chain %v: %v", j, outValue.Type(), fieldInfo.IndexChain, err)
   113  			}
   114  		} else if unmatched != nil {
   115  			unmatched[um.Headers[j]] = csvColumnContent
   116  		}
   117  	}
   118  	return outValue.Interface(), nil
   119  }
   120  
   121  // RenormalizeHeaders will remap the header names based on the headerNormalizer.
   122  // This can be used to map a CSV to a struct where the CSV header names do not match in the file but a mapping is known
   123  func (um *Unmarshaller) RenormalizeHeaders(headerNormalizer func([]string) []string) error {
   124  	headers := um.Headers
   125  	if headerNormalizer != nil {
   126  		headers = headerNormalizer(headers)
   127  	}
   128  	err := validate(um, um.out, headers)
   129  	if err != nil {
   130  		return err
   131  	}
   132  
   133  	return nil
   134  }