github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/tsv/reader.go (about)

     1  package tsv
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"sort"
     9  	"strconv"
    10  	"strings"
    11  	"unsafe"
    12  
    13  	"github.com/Schaudge/grailbase/errors"
    14  )
    15  
    16  type columnFormat struct {
    17  	fieldName  string       // Go struct field name.
    18  	columnName string       // expected column name in TSV. Defaults to fieldName unless `tsv:"colname"` tag is set.
    19  	typ        reflect.Type // Go type information of the column.
    20  	kind       reflect.Kind // type of the column.
    21  	fmt        string       // Optional format directive for writing this value.
    22  	index      int          // index of this column in a row, 0-based.
    23  	offset     uintptr      // byte offset of this field within the Go struct.
    24  }
    25  
    26  type rowFormat []columnFormat
    27  
    28  // Reader reads a TSV file. It wraps around the standard csv.Reader and allows
    29  // parsing row contents into a Go struct directly. Thread compatible.
    30  //
    31  // TODO(saito) Support passing a custom bool parser.
    32  //
    33  // TODO(saito) Support a custom "NA" detector.
    34  type Reader struct {
    35  	*csv.Reader
    36  
    37  	// HasHeaderRow should be set to true to indicate that the input contains a
    38  	// single header row that lists column names of the rows that follow.  It must
    39  	// be set before reading any data.
    40  	HasHeaderRow bool
    41  
    42  	// UseHeaderNames causes the reader to set struct fields by matching column
    43  	// names to struct field names (or `tsv` tag). It must be set before reading
    44  	// any data.
    45  	//
    46  	// If not set, struct fields are filled in order, EVEN IF HasHeaderRow=true.
    47  	// If set, all struct fields must have a corresponding column in the file or
    48  	// IgnoreMissingColumns must also be set. An error will be reported through
    49  	// Read().
    50  	//
    51  	// REQUIRES: HasHeaderRow=true
    52  	UseHeaderNames bool
    53  
    54  	// RequireParseAllColumns causes Read() report an error if there are columns
    55  	// not listed in the passed-in struct. It must be set before reading any data.
    56  	//
    57  	// REQUIRES: HasHeaderRow=true
    58  	RequireParseAllColumns bool
    59  
    60  	// IgnoreMissingColumns causes the reader to ignore any struct fields that are
    61  	// not present as columns in the file. It must be set before reading any
    62  	// data.
    63  	//
    64  	// REQUIRES: HasHeaderRow=true AND UseHeaderNames=true
    65  	IgnoreMissingColumns bool
    66  
    67  	nRow int // # of rows read so far, excluding the header.
    68  
    69  	// columnIndex x maps colname -> colindex (0-based). Filled from the header
    70  	// line.
    71  	columnIndex map[string]int
    72  
    73  	cachedRowType   reflect.Type
    74  	cachedRowFormat rowFormat
    75  }
    76  
    77  // NewReader creates a new TSV reader that reads from the given input.
    78  func NewReader(in io.Reader) *Reader {
    79  	r := &Reader{
    80  		Reader: csv.NewReader(in),
    81  	}
    82  	r.Reader.Comma = '\t'
    83  	r.ReuseRecord = true
    84  	return r
    85  }
    86  
    87  // Filter columns from the row format that are not present in the file being read.
    88  func (r *Reader) filterRowFormat(format rowFormat) rowFormat {
    89  	var filtered rowFormat
    90  	for _, f := range format {
    91  		if _, ok := r.columnIndex[f.columnName]; ok {
    92  			filtered = append(filtered, f)
    93  		}
    94  	}
    95  	return filtered
    96  }
    97  
    98  // Validates and canonicalizes the given row format object when column names
    99  // are being used from the header row. This method may modify the input.
   100  func (r *Reader) validateRowFormat(format rowFormat) (rowFormat, error) {
   101  	if r.IgnoreMissingColumns {
   102  		format = r.filterRowFormat(format)
   103  	}
   104  	if r.RequireParseAllColumns && len(format) != len(r.columnIndex) {
   105  		return format, fmt.Errorf("number of columns found in %+v does not match format %v", r.columnIndex, format)
   106  	}
   107  	for i := range format {
   108  		col := &format[i]
   109  		var ok bool
   110  		if col.index, ok = r.columnIndex[col.columnName]; !ok {
   111  			return format, fmt.Errorf("column %s does not appear in the header: %+v", col.columnName, r.columnIndex)
   112  		}
   113  	}
   114  	sort.Slice(format, func(i, j int) bool {
   115  		return format[i].index < format[j].index
   116  	})
   117  	return format, nil
   118  }
   119  
   120  func parseRowFormat(typ reflect.Type) (rowFormat, error) {
   121  	var format rowFormat
   122  	if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct {
   123  		return nil, fmt.Errorf("destination must be a pointer to struct, but found %v", typ)
   124  	}
   125  	typ = typ.Elem()
   126  	nField := typ.NumField()
   127  	for i := 0; i < nField; i++ {
   128  		f := typ.Field(i)
   129  		if f.PkgPath != "" { // Unexported field.
   130  			if tag := f.Tag.Get("tsv"); tag != "" {
   131  				return nil, fmt.Errorf("unexported field '%s' should not have a tsv tag '%s'", f.Name, tag)
   132  			}
   133  			// Unexported embedded (anonymous) struct is OK, but skip other fields.
   134  			if !f.Anonymous {
   135  				continue
   136  			}
   137  		}
   138  		// Fields from embedded structs are parsed recursively.
   139  		if f.Anonymous && f.Type.Kind() == reflect.Struct {
   140  			embeddedFormat, err := parseRowFormat(reflect.PtrTo(f.Type))
   141  			if err != nil {
   142  				return nil, err
   143  			}
   144  			for _, col := range embeddedFormat {
   145  				col.offset += f.Offset  // Shift offsets to be relative to the outer struct.
   146  				col.index = len(format) // Reset column index.
   147  				format = append(format, col)
   148  			}
   149  			continue
   150  		}
   151  		columnName := f.Name
   152  		var fmt string
   153  		if tag := f.Tag.Get("tsv"); tag != "" {
   154  			if tag == "-" {
   155  				continue
   156  			}
   157  			tagArray := strings.Split(tag, ",")
   158  			if tagArray[0] != "" {
   159  				columnName = tagArray[0]
   160  			}
   161  			for _, tag := range tagArray[1:] {
   162  				if strings.HasPrefix(tag, "fmt=") {
   163  					fmt = tag[4:]
   164  				}
   165  			}
   166  		}
   167  		format = append(format, columnFormat{
   168  			fieldName:  f.Name,
   169  			columnName: columnName,
   170  			typ:        f.Type,
   171  			kind:       f.Type.Kind(),
   172  			fmt:        fmt,
   173  			index:      len(format),
   174  			offset:     f.Offset,
   175  		})
   176  	}
   177  	return format, nil
   178  }
   179  
   180  func (r *Reader) wrapError(err error, col columnFormat) error {
   181  	var name string
   182  	if col.columnName != col.fieldName {
   183  		name = fmt.Sprintf("'%s' (Go field '%s')", col.columnName, col.fieldName)
   184  	} else {
   185  		name = fmt.Sprintf("'%s'", col.columnName)
   186  	}
   187  	return errors.E(err, fmt.Sprintf("line %d, column %d, %s", r.nRow, col.index, name))
   188  }
   189  
   190  // fillRow fills Go struct fields from the TSV row.  dest is the pointer to the
   191  // struct, and format defines the struct format.
   192  func (r *Reader) fillRow(val interface{}, row []string) error {
   193  	p := unsafe.Pointer(reflect.ValueOf(val).Pointer())
   194  	if r.RequireParseAllColumns && len(r.cachedRowFormat) != len(row) { // check this for headerless TSVs
   195  		return fmt.Errorf("extra columns found in %+v", r.cachedRowFormat)
   196  	}
   197  
   198  	for _, col := range r.cachedRowFormat {
   199  		if len(row) < col.index {
   200  			return r.wrapError(fmt.Errorf("row has only %d columns", len(row)), col)
   201  		}
   202  		colVal := row[col.index]
   203  		if col.fmt != "" {
   204  			// Not all format directives are recognized while scanning. Try to
   205  			// standardize some of the common options.
   206  			colfmt := col.fmt
   207  			if strings.ContainsAny(colfmt, "efg") {
   208  				// Standardize all base 10 floating point number formats to 'g', and
   209  				// drop precision and width which are not supported while scanning.
   210  				colfmt = "g"
   211  			}
   212  			if len(strings.Fields(colVal)) != 1 {
   213  				// Scanf functions tokenize by space.
   214  				return r.wrapError(fmt.Errorf("value with fmt option can not have whitespace"), col)
   215  			}
   216  			var (
   217  				typ1   = col.typ
   218  				p1     = unsafe.Pointer(uintptr(p) + col.offset)
   219  				v      = reflect.NewAt(typ1, p1).Interface()
   220  				n, err = fmt.Sscanf(colVal, "%"+colfmt, v)
   221  			)
   222  			if err != nil {
   223  				return r.wrapError(err, col)
   224  			}
   225  			if n != 1 {
   226  				return r.wrapError(fmt.Errorf("%d objects scanned for %s; expected 1", n, colVal), col)
   227  			}
   228  			continue
   229  		}
   230  		switch col.kind {
   231  		case reflect.Bool:
   232  			var v bool
   233  			switch colVal {
   234  			case "Y", "yes":
   235  				v = true
   236  			case "N", "no":
   237  				v = false
   238  			default:
   239  				var err error
   240  				if v, err = strconv.ParseBool(colVal); err != nil {
   241  					return r.wrapError(err, col)
   242  				}
   243  			}
   244  			*(*bool)(unsafe.Pointer(uintptr(p) + col.offset)) = v
   245  		case reflect.String:
   246  			*(*string)(unsafe.Pointer(uintptr(p) + col.offset)) = colVal
   247  		case reflect.Int8:
   248  			v, err := strconv.ParseInt(colVal, 0, 8)
   249  			if err != nil {
   250  				return r.wrapError(err, col)
   251  			}
   252  			*(*int8)(unsafe.Pointer(uintptr(p) + col.offset)) = int8(v)
   253  		case reflect.Int16:
   254  			v, err := strconv.ParseInt(colVal, 0, 16)
   255  			if err != nil {
   256  				return r.wrapError(err, col)
   257  			}
   258  			*(*int16)(unsafe.Pointer(uintptr(p) + col.offset)) = int16(v)
   259  		case reflect.Int32:
   260  			v, err := strconv.ParseInt(colVal, 0, 32)
   261  			if err != nil {
   262  				return r.wrapError(err, col)
   263  			}
   264  			*(*int32)(unsafe.Pointer(uintptr(p) + col.offset)) = int32(v)
   265  		case reflect.Int64:
   266  			v, err := strconv.ParseInt(colVal, 0, 64)
   267  			if err != nil {
   268  				return r.wrapError(err, col)
   269  			}
   270  			*(*int64)(unsafe.Pointer(uintptr(p) + col.offset)) = v
   271  		case reflect.Int:
   272  			v, err := strconv.ParseInt(colVal, 0, 64)
   273  			if err != nil {
   274  				return r.wrapError(err, col)
   275  			}
   276  			*(*int)(unsafe.Pointer(uintptr(p) + col.offset)) = int(v)
   277  		case reflect.Uint8:
   278  			v, err := strconv.ParseUint(colVal, 0, 8)
   279  			if err != nil {
   280  				return r.wrapError(err, col)
   281  			}
   282  			*(*uint8)(unsafe.Pointer(uintptr(p) + col.offset)) = uint8(v)
   283  		case reflect.Uint16:
   284  			v, err := strconv.ParseUint(colVal, 0, 16)
   285  			if err != nil {
   286  				return r.wrapError(err, col)
   287  			}
   288  			*(*uint16)(unsafe.Pointer(uintptr(p) + col.offset)) = uint16(v)
   289  		case reflect.Uint32:
   290  			v, err := strconv.ParseUint(colVal, 0, 32)
   291  			if err != nil {
   292  				return r.wrapError(err, col)
   293  
   294  			}
   295  			*(*uint32)(unsafe.Pointer(uintptr(p) + col.offset)) = uint32(v)
   296  		case reflect.Uint64:
   297  			v, err := strconv.ParseUint(colVal, 0, 64)
   298  			if err != nil {
   299  				return r.wrapError(err, col)
   300  
   301  			}
   302  			*(*uint64)(unsafe.Pointer(uintptr(p) + col.offset)) = v
   303  		case reflect.Uint:
   304  			v, err := strconv.ParseUint(colVal, 0, 64)
   305  			if err != nil {
   306  				return r.wrapError(err, col)
   307  			}
   308  			*(*uint)(unsafe.Pointer(uintptr(p) + col.offset)) = uint(v)
   309  
   310  		case reflect.Float32:
   311  			v, err := strconv.ParseFloat(colVal, 32)
   312  			if err != nil {
   313  				return r.wrapError(err, col)
   314  
   315  			}
   316  			*(*float32)(unsafe.Pointer(uintptr(p) + col.offset)) = float32(v)
   317  		case reflect.Float64:
   318  			v, err := strconv.ParseFloat(colVal, 64)
   319  			if err != nil {
   320  				return r.wrapError(err, col)
   321  
   322  			}
   323  			*(*float64)(unsafe.Pointer(uintptr(p) + col.offset)) = v
   324  		default:
   325  			return r.wrapError(fmt.Errorf("unsupported type %v", col.kind), col)
   326  		}
   327  	}
   328  	return nil
   329  }
   330  
   331  // EmptyReadErrStr is the error-string returned by Read() when the file is
   332  // empty, and at least a header line was expected.
   333  const EmptyReadErrStr = "empty file: could not read the header row"
   334  
   335  // Read reads the next TSV row into a go struct.  The argument must be a pointer
   336  // to a struct. It parses each column in the row into the matching struct
   337  // fields.
   338  //
   339  // Example:
   340  //   r := tsv.NewReader(...)
   341  //   ...
   342  //   type row struct {
   343  //     Col0 string
   344  //     Col1 int
   345  //     Float int
   346  //  }
   347  //  var v row
   348  //  err := r.Read(&v)
   349  //
   350  //
   351  // If !Reader.HasHeaderRow or !Reader.UseHeaderNames, the N-th column (base
   352  // zero) will be parsed into the N-th field in the struct.
   353  //
   354  // If Reader.HasHeaderRow and Reader.UseHeaderNames, then the struct's field
   355  // name must match one of the column names listed in the first row in the TSV
   356  // input. The contents of the column with the matching name will be parsed
   357  // into the struct field.
   358  //
   359  // By default, the column name is the struct's field name, but you can override
   360  // it by setting `tsv:"columnname"` tag in the field. The struct tag may also
   361  // take an fmt option to specify how to parse the value using the fmt package.
   362  // This is useful for parsing numbers written in a different base. Note that
   363  // not all verbs are supported with the scanning functions in the fmt package.
   364  // Using the fmt option may lead to slower performance.
   365  // Imagine the following row type:
   366  //
   367  //   type row struct {
   368  //      Chr    string `tsv:"chromo"`
   369  //      Start  int    `tsv:"pos"`
   370  //      Length int
   371  //      Score  int    `tsv:"score,fmt=x"`
   372  //   }
   373  //
   374  // and the following TSV file:
   375  //
   376  //   | chromo | Length | pos | score
   377  //   | chr1   | 1000   | 10  | 0a
   378  //   | chr2   | 950    | 20  | ff
   379  //
   380  // The first Read() will return row{"chr1", 10, 1000, 10}.
   381  //
   382  // The second Read() will return row{"chr2", 20, 950, 15}.
   383  //
   384  // Embedded structs are supported, and the default column name for nested
   385  // fields will be the unqualified name of the field.
   386  func (r *Reader) Read(v interface{}) error {
   387  	if r.nRow == 0 && r.HasHeaderRow {
   388  		headerRow, err := r.Reader.Read()
   389  		if err != nil {
   390  			if err == io.EOF {
   391  				err = errors.E(EmptyReadErrStr)
   392  			}
   393  			return err
   394  		}
   395  		r.nRow++
   396  		r.columnIndex = map[string]int{}
   397  		for i, colName := range headerRow {
   398  			r.columnIndex[colName] = i
   399  		}
   400  	}
   401  	row, err := r.Reader.Read()
   402  	if err != nil {
   403  		return err
   404  	}
   405  	r.nRow++
   406  	typ := reflect.TypeOf(v)
   407  	if typ != r.cachedRowType {
   408  		format, err := parseRowFormat(typ)
   409  		if err != nil {
   410  			return err
   411  		}
   412  		if r.UseHeaderNames {
   413  			format, err = r.validateRowFormat(format)
   414  			if err != nil {
   415  				return err
   416  			}
   417  		}
   418  		r.cachedRowType = typ
   419  		r.cachedRowFormat = format
   420  	}
   421  	return r.fillRow(v, row)
   422  }