github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/ext/csv/csv.go (about)

     1  // Package csv provides a CSV virtual table.
     2  //
     3  // The CSV virtual table reads RFC 4180 formatted comma-separated values,
     4  // and returns that content as if it were rows and columns of an SQL table.
     5  //
     6  // https://sqlite.org/csv.html
     7  package csv
     8  
     9  import (
    10  	"bufio"
    11  	"encoding/csv"
    12  	"fmt"
    13  	"io"
    14  	"io/fs"
    15  	"strings"
    16  
    17  	"github.com/ncruces/go-sqlite3"
    18  	"github.com/ncruces/go-sqlite3/util/osutil"
    19  	"github.com/ncruces/go-sqlite3/util/vtabutil"
    20  )
    21  
    22  // Register registers the CSV virtual table.
    23  // If a filename is specified, [os.Open] is used to open the file.
    24  func Register(db *sqlite3.Conn) {
    25  	RegisterFS(db, osutil.FS{})
    26  }
    27  
    28  // RegisterFS registers the CSV virtual table.
    29  // If a filename is specified, fsys is used to open the file.
    30  func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
    31  	declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
    32  		var (
    33  			filename string
    34  			data     string
    35  			schema   string
    36  			header   bool
    37  			columns  int  = -1
    38  			comma    rune = ','
    39  
    40  			done = map[string]struct{}{}
    41  		)
    42  
    43  		for _, arg := range arg {
    44  			key, val := vtabutil.NamedArg(arg)
    45  			if _, ok := done[key]; ok {
    46  				return nil, fmt.Errorf("csv: more than one %q parameter", key)
    47  			}
    48  			switch key {
    49  			case "filename":
    50  				filename = vtabutil.Unquote(val)
    51  			case "data":
    52  				data = vtabutil.Unquote(val)
    53  			case "schema":
    54  				schema = vtabutil.Unquote(val)
    55  			case "header":
    56  				header, err = boolArg(key, val)
    57  			case "columns":
    58  				columns, err = uintArg(key, val)
    59  			case "comma":
    60  				comma, err = runeArg(key, val)
    61  			default:
    62  				return nil, fmt.Errorf("csv: unknown %q parameter", key)
    63  			}
    64  			if err != nil {
    65  				return nil, err
    66  			}
    67  			done[key] = struct{}{}
    68  		}
    69  
    70  		if (filename == "") == (data == "") {
    71  			return nil, fmt.Errorf(`csv: must specify either "filename" or "data" but not both`)
    72  		}
    73  
    74  		table := &table{
    75  			fsys:   fsys,
    76  			name:   filename,
    77  			data:   data,
    78  			comma:  comma,
    79  			header: header,
    80  		}
    81  
    82  		if schema == "" {
    83  			var row []string
    84  			if header || columns < 0 {
    85  				csv, c, err := table.newReader()
    86  				defer c.Close()
    87  				if err != nil {
    88  					return nil, err
    89  				}
    90  				row, err = csv.Read()
    91  				if err != nil {
    92  					return nil, err
    93  				}
    94  			}
    95  			schema = getSchema(header, columns, row)
    96  		}
    97  
    98  		err = db.DeclareVTab(schema)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		err = db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		return table, nil
   107  	}
   108  
   109  	sqlite3.CreateModule(db, "csv", declare, declare)
   110  }
   111  
   112  type table struct {
   113  	fsys   fs.FS
   114  	name   string
   115  	data   string
   116  	comma  rune
   117  	header bool
   118  }
   119  
   120  func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
   121  	idx.EstimatedCost = 1e6
   122  	return nil
   123  }
   124  
   125  func (t *table) Open() (sqlite3.VTabCursor, error) {
   126  	return &cursor{table: t}, nil
   127  }
   128  
   129  func (t *table) Rename(new string) error {
   130  	return nil
   131  }
   132  
   133  func (t *table) Integrity(schema, table string, flags int) error {
   134  	if flags&1 != 0 {
   135  		return nil
   136  	}
   137  	csv, c, err := t.newReader()
   138  	if err != nil {
   139  		return err
   140  	}
   141  	defer c.Close()
   142  	_, err = csv.ReadAll()
   143  	return err
   144  }
   145  
   146  func (t *table) newReader() (*csv.Reader, io.Closer, error) {
   147  	var r io.Reader
   148  	var c io.Closer
   149  	if t.name != "" {
   150  		f, err := t.fsys.Open(t.name)
   151  		if err != nil {
   152  			return nil, f, err
   153  		}
   154  
   155  		buf := bufio.NewReader(f)
   156  		bom, err := buf.Peek(3)
   157  		if err != nil {
   158  			return nil, f, err
   159  		}
   160  		if string(bom) == "\xEF\xBB\xBF" {
   161  			buf.Discard(3)
   162  		}
   163  
   164  		r = buf
   165  		c = f
   166  	} else {
   167  		r = strings.NewReader(t.data)
   168  		c = io.NopCloser(r)
   169  	}
   170  
   171  	csv := csv.NewReader(r)
   172  	csv.ReuseRecord = true
   173  	csv.Comma = t.comma
   174  	return csv, c, nil
   175  }
   176  
   177  type cursor struct {
   178  	table  *table
   179  	closer io.Closer
   180  	csv    *csv.Reader
   181  	row    []string
   182  	rowID  int64
   183  }
   184  
   185  func (c *cursor) Close() (err error) {
   186  	if c.closer != nil {
   187  		err = c.closer.Close()
   188  		c.closer = nil
   189  	}
   190  	return err
   191  }
   192  
   193  func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
   194  	err := c.Close()
   195  	if err != nil {
   196  		return err
   197  	}
   198  
   199  	c.csv, c.closer, err = c.table.newReader()
   200  	if err != nil {
   201  		return err
   202  	}
   203  	if c.table.header {
   204  		c.Next() // skip header
   205  	}
   206  	c.rowID = 0
   207  	return c.Next()
   208  }
   209  
   210  func (c *cursor) Next() (err error) {
   211  	c.rowID++
   212  	c.row, err = c.csv.Read()
   213  	if err != io.EOF {
   214  		return err
   215  	}
   216  	return nil
   217  }
   218  
   219  func (c *cursor) EOF() bool {
   220  	return c.row == nil
   221  }
   222  
   223  func (c *cursor) RowID() (int64, error) {
   224  	return c.rowID, nil
   225  }
   226  
   227  func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
   228  	if col < len(c.row) {
   229  		ctx.ResultText(c.row[col])
   230  	}
   231  	return nil
   232  }