go-hep.org/x/hep@v0.38.1/csvutil/csv.go (about)

     1  // Copyright ©2016 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package csvutil exposes functions and types to easily handle CSV files.
     6  // csvutil builds upon the encoding/csv package and provides the Table type.
     7  // A Table can read data from a CSV file into a struct value whose fields are
     8  // the various columns of the CSV file.
     9  // Conversely, a Table can write data into a CSV file from a struct value.
    10  package csvutil // import "go-hep.org/x/hep/csvutil"
    11  
    12  import (
    13  	"bufio"
    14  	"encoding/csv"
    15  	"fmt"
    16  	"io"
    17  	"math"
    18  	"os"
    19  	"reflect"
    20  	"strconv"
    21  	"strings"
    22  )
    23  
    24  func min(a, b int) int {
    25  	if a < b {
    26  		return a
    27  	}
    28  	return b
    29  }
    30  
    31  func formatValue(val any, quotes bool, recBuilder *strings.Builder) error {
    32  	rv := reflect.Indirect(reflect.ValueOf(val))
    33  	rt := rv.Type()
    34  	switch rt.Kind() {
    35  	case reflect.Bool:
    36  		recBuilder.WriteString(strconv.FormatBool(rv.Bool()))
    37  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    38  		recBuilder.WriteString(strconv.FormatInt(rv.Int(), 10))
    39  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    40  		recBuilder.WriteString(strconv.FormatUint(rv.Uint(), 10))
    41  	case reflect.Float32, reflect.Float64:
    42  		recBuilder.WriteString(strconv.FormatFloat(rv.Float(), 'g', -1, rt.Bits()))
    43  	case reflect.String:
    44  		if quotes {
    45  			recBuilder.WriteString("'" + rv.String() + "'")
    46  		} else {
    47  			recBuilder.WriteString(rv.String())
    48  		}
    49  	case reflect.Slice:
    50  		recBuilder.WriteString("[")
    51  		for i := range rv.Len() {
    52  			if i > 0 {
    53  				recBuilder.WriteString(", ")
    54  			}
    55  			err := formatValue(rv.Index(i).Interface(), true, recBuilder)
    56  			if err != nil {
    57  				return err
    58  			}
    59  		}
    60  		recBuilder.WriteString("]")
    61  	default:
    62  		return fmt.Errorf("csvutil: invalid type (%[1]T) %[1]v (kind=%[2]v)", val, rt.Kind())
    63  	}
    64  	return nil
    65  }
    66  
    67  // Open opens a Table in read mode connected to a CSV file.
    68  func Open(fname string) (*Table, error) {
    69  	r, err := os.Open(fname)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	table := &Table{
    74  		Reader: csv.NewReader(bufio.NewReader(r)),
    75  		f:      r,
    76  	}
    77  	return table, err
    78  }
    79  
    80  // Create creates a new CSV file and returns a Table in write mode.
    81  func Create(fname string) (*Table, error) {
    82  	w, err := os.Create(fname)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	table := &Table{
    87  		Writer: csv.NewWriter(bufio.NewWriter(w)),
    88  		f:      w,
    89  	}
    90  	return table, err
    91  }
    92  
    93  // Append opens an already existing CSV file and returns a Table in write mode.
    94  // The file cursor is positioned at the end of the file so new data can be
    95  // appended via the returned Table.
    96  func Append(fname string) (*Table, error) {
    97  	f, err := os.OpenFile(fname, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0666)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	_, err = f.Seek(0, io.SeekEnd)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	table := &Table{
   108  		Writer: csv.NewWriter(bufio.NewWriter(f)),
   109  		f:      f,
   110  	}
   111  	return table, err
   112  }
   113  
   114  // Table provides read- or write-access to a CSV file.
   115  // Table supports reading and writing data to/from a struct value.
   116  type Table struct {
   117  	Reader *csv.Reader // CSV decoder; set by Open (read mode), left nil by Create and Append
   118  	Writer *csv.Writer // CSV encoder; set by Create and Append (write mode), left nil by Open
   119  
   120  	f      *os.File // underlying CSV file; closed and reset to nil by Close
   121  	closed bool     // set to true once Close has closed the underlying file
   122  	err    error    // error recorded by Close (flush or file-close failure)
   123  }
   124  
   125  // Close closes the table and the underlying CSV file.
   126  func (tbl *Table) Close() error {
   127  	if tbl.closed {
   128  		return tbl.err
   129  	}
   130  
   131  	if tbl.Writer != nil {
   132  		tbl.Writer.Flush()
   133  		tbl.err = tbl.Writer.Error()
   134  	}
   135  
   136  	if tbl.f != nil {
   137  		err := tbl.f.Close()
   138  		if err != nil && tbl.err == nil {
   139  			tbl.err = err
   140  		}
   141  		tbl.f = nil
   142  		tbl.closed = true
   143  	}
   144  	return tbl.err
   145  }
   146  
   147  // ReadRows returns a row iterator semantically equivalent to [beg,end).
   148  // If end==-1, the iterator will be configured to read rows until EOF.
   149  func (tbl *Table) ReadRows(beg, end int64) (*Rows, error) {
   150  	inc := int64(1)
   151  	rows := &Rows{
   152  		tbl: tbl,
   153  		i:   0,
   154  		n:   end - beg,
   155  		inc: inc,
   156  		cur: beg - inc,
   157  	}
   158  	if end == -1 {
   159  		rows.n = math.MaxInt64
   160  	}
   161  	if beg > 0 {
   162  		err := rows.skip(beg)
   163  		if err != nil {
   164  			return nil, err
   165  		}
   166  	}
   167  	return rows, nil
   168  }
   169  
   170  // WriteHeader writes a header to the underlying CSV file
   171  func (tbl *Table) WriteHeader(hdr string) error {
   172  	if !strings.HasSuffix(hdr, "\n") {
   173  		hdr += "\n"
   174  	}
   175  	_, err := tbl.f.WriteString(hdr)
   176  	return err
   177  }
   178  
   179  // WriteRow writes the data into the columns at the current row.
   180  func (tbl *Table) WriteRow(args ...any) error {
   181  	var err error
   182  	if tbl.Writer == nil {
   183  		return fmt.Errorf("csvutil: Table is not in write mode")
   184  	}
   185  
   186  	switch len(args) {
   187  	case 0:
   188  		return fmt.Errorf("csvutil: Table.WriteRow needs at least one argument")
   189  
   190  	case 1:
   191  		// maybe special case: struct?
   192  		rv := reflect.Indirect(reflect.ValueOf(args[0]))
   193  		rt := rv.Type()
   194  		switch rt.Kind() {
   195  		case reflect.Struct:
   196  			err = tbl.writeStruct(rv)
   197  			return err
   198  		}
   199  	}
   200  
   201  	err = tbl.write(args...)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	return err
   207  }
   208  
   209  func (tbl *Table) write(args ...any) error {
   210  	rec := make([]string, len(args))
   211  	var recBuilder strings.Builder
   212  	for i, arg := range args {
   213  		recBuilder.Reset()
   214  		err := formatValue(arg, false, &recBuilder)
   215  		if err != nil {
   216  			return err
   217  		}
   218  		rec[i] = recBuilder.String()
   219  	}
   220  	return tbl.Writer.Write(rec)
   221  }
   222  
   223  func (tbl *Table) writeStruct(rv reflect.Value) error {
   224  	rt := rv.Type()
   225  	args := make([]any, rt.NumField())
   226  	for i := range args {
   227  		args[i] = rv.Field(i).Interface()
   228  	}
   229  
   230  	return tbl.write(args...)
   231  }
   232  
   233  // Rows is an iterator over an interval of rows inside a CSV file.
   234  type Rows struct {
   235  	tbl    *Table   // table the records are read from; reset to nil by Close
   236  	i      int64    // number of rows iterated over
   237  	n      int64    // number of rows this iterator iterates over (math.MaxInt64 when reading until EOF)
   238  	inc    int64    // number of rows to increment by at each iteration
   239  	cur    int64    // current row index
   240  	record []string // last read record
   241  	closed bool     // set to true by Close; Next returns false afterwards
   242  	err    error    // last error
   243  }
   244  
   245  // Err returns the error, if any, that was encountered during iteration.
   246  // Err may be called after an explicit or implicit Close.
   247  func (rows *Rows) Err() error {
   248  	return rows.err
   249  }
   250  
   251  // Close closes the Rows, preventing further enumeration.
   252  // Close is idempotent and does not affect the result of Err.
   253  func (rows *Rows) Close() error {
   254  	if rows.closed {
   255  		return nil
   256  	}
   257  	rows.closed = true
   258  	rows.tbl = nil
   259  	return nil
   260  }
   261  
   262  // NumFields returns the number of fields in the current CSV-record.
   263  // NumFields assumes Rows.Next() has been called at least once.
   264  func (rows *Rows) NumFields() int {
   265  	return len(rows.record)
   266  }
   267  
   268  // Fields returns the raw string values of the fields of the current CSV-record.
   269  // Fields assumes Rows.Next() has been called at least once.
   270  func (rows *Rows) Fields() []string {
   271  	fields := make([]string, len(rows.record))
   272  	copy(fields, rows.record)
   273  	return fields
   274  }
   275  
   276  // Scan copies the columns in the current row into the values pointed at by
   277  // dest.
   278  // dest can be either:
   279  // - a pointer to a struct value (whose fields will be filled with column values)
   280  // - a slice of values
   281  func (rows *Rows) Scan(dest ...any) error {
   282  	var err error
   283  	defer func() {
   284  		rows.err = err
   285  	}()
   286  
   287  	switch len(dest) {
   288  	case 0:
   289  		err = fmt.Errorf("csvutil: Rows.Scan needs at least one argument")
   290  		return err
   291  
   292  	case 1:
   293  		// maybe special case: struct?
   294  		rv := reflect.ValueOf(dest[0]).Elem()
   295  		rt := rv.Type()
   296  		switch rt.Kind() {
   297  		case reflect.Struct:
   298  			err = rows.scanStruct(rv)
   299  			return err
   300  		}
   301  	}
   302  
   303  	err = rows.scan(dest...)
   304  	return err
   305  }
   306  
   307  func (rows *Rows) scan(args ...any) error {
   308  	var err error
   309  	n := min(len(rows.record), len(args))
   310  	for i := range n {
   311  		rec := rows.record[i]
   312  		rv := reflect.ValueOf(args[i]).Elem()
   313  		rt := reflect.TypeOf(args[i]).Elem()
   314  		switch rt.Kind() {
   315  		case reflect.Bool:
   316  			v, err := strconv.ParseBool(rec)
   317  			if err != nil {
   318  				return err
   319  			}
   320  			rv.SetBool(v)
   321  
   322  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   323  			v, err := strconv.ParseInt(rec, 10, rt.Bits())
   324  			if err != nil {
   325  				return err
   326  			}
   327  			rv.SetInt(v)
   328  
   329  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   330  			v, err := strconv.ParseUint(rec, 10, rt.Bits())
   331  			if err != nil {
   332  				return err
   333  			}
   334  			rv.SetUint(v)
   335  
   336  		case reflect.Float32, reflect.Float64:
   337  			v, err := strconv.ParseFloat(rec, rt.Bits())
   338  			if err != nil {
   339  				return err
   340  			}
   341  			rv.SetFloat(v)
   342  
   343  		case reflect.String:
   344  			rv.SetString(rec)
   345  
   346  		default:
   347  			return fmt.Errorf("csvutil: invalid type (%T) %q (kind=%v)", rv.Interface(), rec, rt.Kind())
   348  		}
   349  	}
   350  
   351  	return err
   352  }
   353  
   354  func (rows *Rows) scanStruct(rv reflect.Value) error {
   355  	rt := rv.Type()
   356  	args := make([]any, rt.NumField())
   357  	for i := range rt.NumField() {
   358  		args[i] = rv.Field(i).Addr().Interface()
   359  	}
   360  	return rows.scan(args...)
   361  }
   362  
   363  func (rows *Rows) skip(n int64) error {
   364  	var err error
   365  	for i := int64(0); i < n; i++ {
   366  		_, err = rows.tbl.Reader.Read()
   367  		if err != nil {
   368  			return err
   369  		}
   370  		rows.cur++
   371  	}
   372  	return err
   373  }
   374  
   375  // Next prepares the next result row for reading with the Scan method.
   376  // It returns true on success, false if there is no next result row.
   377  // Every call to Scan, even the first one, must be preceded by a call to Next.
   378  func (rows *Rows) Next() bool {
   379  	if rows.closed {
   380  		return false
   381  	}
   382  	if rows.err != nil {
   383  		return false
   384  	}
   385  	next := rows.i < rows.n
   386  	rows.cur += rows.inc
   387  	rows.i += rows.inc
   388  	if !next {
   389  		rows.err = rows.Close()
   390  		return next
   391  	}
   392  
   393  	var err error
   394  	rows.record, err = rows.tbl.Reader.Read()
   395  	if err != nil {
   396  		rows.err = err
   397  		return false
   398  	}
   399  
   400  	return next
   401  }