go-hep.org/x/hep@v0.38.1/csvutil/csvdriver/import.go (about)

     1  // Copyright ©2016 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package csvdriver
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"fmt"
    11  	"io"
    12  	"reflect"
    13  	"strconv"
    14  	"strings"
    15  
    16  	"go-hep.org/x/hep/csvutil"
    17  )
    18  
    19  func (conn *csvConn) importCSV() error {
    20  	fname := conn.f.Name()
    21  	tbl, err := csvutil.Open(fname)
    22  	if err != nil {
    23  		return err
    24  	}
    25  	defer tbl.Close()
    26  	tbl.Reader.Comma = conn.cfg.Comma
    27  	tbl.Reader.Comment = conn.cfg.Comment
    28  
    29  	schema, err := inferSchema(conn, conn.cfg.Header, conn.cfg.Names)
    30  	if err != nil {
    31  		return err
    32  	}
    33  
    34  	tx, err := conn.Begin()
    35  	if err != nil {
    36  		return err
    37  	}
    38  	defer func() {
    39  		_ = tx.Commit()
    40  	}()
    41  
    42  	ctx := context.Background()
    43  
    44  	_, err = conn.ExecContext(ctx, "create table csv ("+schema.Decl()+")", nil)
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	_, err = conn.ExecContext(ctx, "create index csv_id on csv (id());", nil)
    50  	if err != nil {
    51  		return err
    52  	}
    53  
    54  	beg := int64(0)
    55  	if conn.cfg.Header {
    56  		beg++
    57  	}
    58  	rows, err := tbl.ReadRows(beg, -1)
    59  	if err != nil {
    60  		return err
    61  	}
    62  	defer rows.Close()
    63  
    64  	vargs, pargs := schema.Args()
    65  	def := schema.Def()
    66  	insert := "insert into csv values(" + def + ");"
    67  	for rows.Next() {
    68  		err = rows.Scan(pargs...)
    69  		if err != nil {
    70  			return err
    71  		}
    72  		for i, arg := range pargs {
    73  			vargs[i].Value = reflect.ValueOf(arg).Elem().Interface()
    74  		}
    75  		_, err = conn.ExecContext(ctx, insert, vargs)
    76  		if err != nil {
    77  			return err
    78  		}
    79  	}
    80  
    81  	err = rows.Err()
    82  	if err == io.EOF {
    83  		err = nil
    84  	}
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	err = tx.Commit()
    90  	if err != nil {
    91  		return err
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func inferSchema(conn *csvConn, header bool, names []string) (schemaType, error) {
    98  	fname := conn.f.Name()
    99  	tbl, err := csvutil.Open(fname)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	defer tbl.Close()
   104  	tbl.Reader.Comma = conn.cfg.Comma
   105  	tbl.Reader.Comment = conn.cfg.Comment
   106  
   107  	return inferSchemaFromTable(tbl, header, names)
   108  }
   109  
   110  func inferSchemaFromTable(tbl *csvutil.Table, header bool, names []string) (schemaType, error) {
   111  	var (
   112  		beg int64 = 0
   113  		end int64 = 1
   114  	)
   115  	if header {
   116  		end++
   117  	}
   118  	rows, err := tbl.ReadRows(beg, end)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  	defer rows.Close()
   123  
   124  	if header {
   125  		if !rows.Next() {
   126  			return nil, rows.Err()
   127  		}
   128  		if len(names) == 0 {
   129  			names = rows.Fields()
   130  		}
   131  	}
   132  
   133  	if !rows.Next() {
   134  		return nil, rows.Err()
   135  	}
   136  
   137  	return inferSchemaFromFields(rows.Fields(), names)
   138  }
   139  
   140  func inferSchemaFromFields(fields []string, names []string) (schemaType, error) {
   141  	if len(names) == 0 {
   142  		names = make([]string, len(fields))
   143  	}
   144  	schema := make(schemaType, len(fields))
   145  	for i, field := range fields {
   146  		var err error
   147  		name := names[i]
   148  		if name == "" {
   149  			name = fmt.Sprintf("var%d", i+1)
   150  		}
   151  
   152  		schema[i].n = name
   153  		_, err = strconv.ParseInt(field, 10, 64)
   154  		if err == nil {
   155  			schema[i].v = reflect.ValueOf(int64(0))
   156  			continue
   157  		}
   158  
   159  		_, err = strconv.ParseFloat(field, 64)
   160  		if err == nil {
   161  			schema[i].v = reflect.ValueOf(float64(0))
   162  			continue
   163  		}
   164  
   165  		schema[i].v = reflect.ValueOf("")
   166  	}
   167  	return schema, nil
   168  }
   169  
   170  type schemaType []struct {
   171  	v reflect.Value
   172  	n string
   173  }
   174  
   175  func (st *schemaType) Decl() string {
   176  	o := make([]string, 0, len(*st))
   177  	for _, v := range *st {
   178  		n := v.n
   179  		t := v.v.Type().Kind().String()
   180  		o = append(o, n+" "+t)
   181  	}
   182  	return strings.Join(o, ", ")
   183  }
   184  
   185  func (st *schemaType) Args() ([]driver.NamedValue, []any) {
   186  	vargs := make([]driver.NamedValue, len(*st))
   187  	pargs := make([]any, len(*st))
   188  	for i, v := range *st {
   189  		ptr := reflect.New(v.v.Type())
   190  		vargs[i] = driver.NamedValue{
   191  			Name:    v.n,
   192  			Ordinal: i + 1,
   193  			Value:   ptr.Elem().Interface(),
   194  		}
   195  		pargs[i] = ptr.Interface()
   196  	}
   197  	return vargs, pargs
   198  }
   199  
   200  func (st *schemaType) Def() string {
   201  	o := make([]string, len(*st))
   202  	for i := range *st {
   203  		o[i] = fmt.Sprintf("$%d", i+1)
   204  	}
   205  	return strings.Join(o, ", ")
   206  }