go-hep.org/x/hep@v0.38.1/csvutil/csvdriver/driver.go (about)

     1  // Copyright ©2016 The go-hep Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package csvdriver registers a database/sql/driver.Driver implementation for CSV files.
     6  package csvdriver // import "go-hep.org/x/hep/csvutil/csvdriver"
     7  
     8  import (
     9  	"context"
    10  	"database/sql"
    11  	"database/sql/driver"
    12  	"encoding/json"
    13  	"fmt"
    14  	"io"
    15  	"net/http"
    16  	"os"
    17  	"strings"
    18  	"sync"
    19  
    20  	_ "modernc.org/ql/driver"
    21  )
    22  
    23  var (
    24  	_ driver.Driver             = (*csvDriver)(nil)
    25  	_ drvConn                   = (*csvConn)(nil)
    26  	_ driver.ExecerContext      = (*csvConn)(nil)
    27  	_ driver.QueryerContext     = (*csvConn)(nil)
    28  	_ driver.ConnBeginTx        = (*csvConn)(nil)
    29  	_ driver.ConnPrepareContext = (*csvConn)(nil)
    30  	_ driver.Tx                 = (*csvConn)(nil)
    31  )
    32  
    33  type drvConn interface {
    34  	driver.Conn
    35  	driver.ConnBeginTx
    36  	driver.ConnPrepareContext
    37  }
    38  
    39  // Conn describes how a connection to the CSV-driver should be established.
    40  type Conn struct {
    41  	File    string      `json:"file"`    // name of the file to be open
    42  	Mode    int         `json:"mode"`    // r/w mode (default: read-only)
    43  	Perm    os.FileMode `json:"perm"`    // file permissions
    44  	Comma   rune        `json:"comma"`   // field delimiter (default: ',')
    45  	Comment rune        `json:"comment"` // comment character for start of line (default: '#')
    46  	Header  bool        `json:"header"`  // whether the CSV-file has a column header
    47  	Names   []string    `json:"names"`   // column names
    48  }
    49  
    50  func (c *Conn) setDefaults() {
    51  	if c.Mode == 0 {
    52  		c.Mode = os.O_RDONLY
    53  		c.Perm = 0
    54  	}
    55  	if c.Comma == 0 {
    56  		c.Comma = ','
    57  	}
    58  	if c.Comment == 0 {
    59  		c.Comment = '#'
    60  	}
    61  }
    62  
    63  func (c Conn) toJSON() (string, error) {
    64  	c.setDefaults()
    65  	buf, err := json.Marshal(c)
    66  	if err != nil {
    67  		return "", err
    68  	}
    69  	return string(buf), err
    70  }
    71  
    72  // Open opens a database connection with the CSV driver.
    73  func (c Conn) Open() (*sql.DB, error) {
    74  	c.setDefaults()
    75  	str, err := c.toJSON()
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return sql.Open("csv", str)
    80  }
    81  
    82  // Open is a CSV-driver helper function for sql.Open.
    83  //
    84  // It opens a database connection to csvdriver.
    85  func Open(name string) (*sql.DB, error) {
    86  	c := Conn{File: name, Mode: os.O_RDONLY, Perm: 0}
    87  	return c.Open()
    88  }
    89  
    90  // Create is a CSV-driver helper function for sql.Open.
    91  //
    92  // It creates a new CSV file, connected via the csvdriver.
    93  func Create(name string) (*sql.DB, error) {
    94  	c := Conn{
    95  		File: name,
    96  		Mode: os.O_RDWR | os.O_CREATE | os.O_TRUNC,
    97  		Perm: 0666,
    98  	}
    99  	return c.Open()
   100  }
   101  
   102  // csvDriver implements the interface required by database/sql/driver.
   103  type csvDriver struct {
   104  	dbs map[string]*csvConn
   105  	mu  sync.Mutex
   106  }
   107  
   108  // Open returns a new connection to the database.
   109  // The name is a string in a driver-specific format.
   110  //
   111  // Open may return a cached connection (one previously
   112  // closed), but doing so is unnecessary; the sql package
   113  // maintains a pool of idle connections for efficient re-use.
   114  //
   115  // The returned connection is only used by one goroutine at a
   116  // time.
   117  func (drv *csvDriver) Open(cfg string) (driver.Conn, error) {
   118  	c := Conn{}
   119  	if strings.HasPrefix(cfg, "{") {
   120  		err := json.Unmarshal([]byte(cfg), &c)
   121  		if err != nil {
   122  			return nil, err
   123  		}
   124  	} else {
   125  		c.File = cfg
   126  		c.setDefaults()
   127  	}
   128  
   129  	doImport := false
   130  	_, err := os.Lstat(c.File)
   131  	if err == nil {
   132  		doImport = true
   133  	}
   134  
   135  	drv.mu.Lock()
   136  	defer drv.mu.Unlock()
   137  	if drv.dbs == nil {
   138  		drv.dbs = make(map[string]*csvConn)
   139  	}
   140  	conn := drv.dbs[c.File]
   141  	if conn == nil {
   142  		var f *os.File
   143  		switch {
   144  		case strings.HasPrefix(c.File, "http://"), strings.HasPrefix(c.File, "https://"):
   145  			// FIXME(sbinet: check that c.Mode makes sense (ie: only reading)
   146  			resp, err := http.Get(c.File)
   147  			if err != nil {
   148  				return nil, err
   149  			}
   150  			defer resp.Body.Close()
   151  			// FIXME(sbinet): devise a mechanism to remove that temporary file
   152  			// when we close the connection.
   153  			f, err = os.CreateTemp("", "csvdriver-")
   154  			if err != nil {
   155  				return nil, err
   156  			}
   157  			_, err = io.CopyBuffer(f, resp.Body, make([]byte, 16*1024*1024))
   158  			if err != nil {
   159  				return nil, err
   160  			}
   161  			_, err = f.Seek(0, io.SeekStart)
   162  			if err != nil {
   163  				return nil, err
   164  			}
   165  			doImport = true
   166  		default:
   167  			// local file
   168  			f, err = os.OpenFile(c.File, c.Mode, c.Perm)
   169  			if err != nil {
   170  				return nil, err
   171  			}
   172  		}
   173  		conn = &csvConn{
   174  			f:    f,
   175  			cfg:  c,
   176  			drv:  drv,
   177  			refs: 0,
   178  		}
   179  
   180  		err = conn.initDB()
   181  		if err != nil {
   182  			return nil, err
   183  		}
   184  
   185  		if doImport {
   186  			err = conn.importCSV()
   187  			if err != nil {
   188  				return nil, err
   189  			}
   190  		}
   191  		drv.dbs[c.File] = conn
   192  	}
   193  	conn.refs++
   194  
   195  	return conn, err
   196  }
   197  
   198  type csvConn struct {
   199  	f    *os.File
   200  	cfg  Conn
   201  	drv  *csvDriver
   202  	refs int
   203  
   204  	conn  drvConn
   205  	exec  driver.ExecerContext
   206  	query driver.QueryerContext
   207  	tx    driver.Tx
   208  }
   209  
   210  func (conn *csvConn) initDB() error {
   211  	c, err := qlopen(conn.cfg.File)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	conn.conn = c.(drvConn)
   217  	conn.exec = c.(driver.ExecerContext)
   218  	conn.query = c.(driver.QueryerContext)
   219  	return nil
   220  }
   221  
   222  // PrepareContext returns a prepared statement, bound to this connection.
   223  func (conn *csvConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   224  	return conn.conn.PrepareContext(ctx, query)
   225  }
   226  
   227  // Prepare returns a prepared statement, bound to this connection.
   228  func (conn *csvConn) Prepare(query string) (driver.Stmt, error) {
   229  	return conn.conn.PrepareContext(context.Background(), query)
   230  }
   231  
   232  // Close invalidates and potentially stops any current
   233  // prepared statements and transactions, marking this
   234  // connection as no longer in use.
   235  //
   236  // Because the sql package maintains a free pool of
   237  // connections and only calls Close when there's a surplus of
   238  // idle connections, it shouldn't be necessary for drivers to
   239  // do their own connection caching.
   240  func (conn *csvConn) Close() error {
   241  	if conn.refs > 1 {
   242  		conn.refs--
   243  		return nil
   244  	}
   245  	var err error
   246  	defer conn.f.Close()
   247  
   248  	// FIXME(sbinet) write-back to file if needed.
   249  	// err = conn.exportCSV()
   250  
   251  	err = conn.conn.Close()
   252  	if err != nil {
   253  		return err
   254  	}
   255  
   256  	err = conn.f.Close()
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	conn.drv.mu.Lock()
   262  	if conn.refs == 1 {
   263  		delete(conn.drv.dbs, conn.f.Name())
   264  	}
   265  	conn.refs = 0
   266  	conn.drv.mu.Unlock()
   267  
   268  	return err
   269  }
   270  
   271  // BeginTx starts and returns a new transaction.
   272  func (conn *csvConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
   273  	tx, err := conn.conn.BeginTx(ctx, opts)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  	conn.tx = tx
   278  	return tx, err
   279  }
   280  
   281  // Begin starts and returns a new transaction.
   282  func (conn *csvConn) Begin() (driver.Tx, error) {
   283  	return conn.BeginTx(context.Background(), driver.TxOptions{})
   284  }
   285  
   286  func (conn *csvConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   287  	return conn.exec.ExecContext(ctx, query, args)
   288  }
   289  
   290  func (conn *csvConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   291  	rows, err := conn.query.QueryContext(ctx, query, args)
   292  	if err != nil {
   293  		return nil, err
   294  	}
   295  	return rows, err
   296  }
   297  
   298  func (conn *csvConn) Commit() error {
   299  	if conn.tx == nil {
   300  		return fmt.Errorf("csvdriver: commit while not in transaction")
   301  	}
   302  	err := conn.tx.Commit()
   303  	conn.tx = nil
   304  	return err
   305  }
   306  
   307  func (conn *csvConn) Rollback() error {
   308  	if conn.tx == nil {
   309  		return fmt.Errorf("csvdriver: rollback while not in transaction")
   310  	}
   311  	err := conn.tx.Rollback()
   312  	conn.tx = nil
   313  	return err
   314  }
   315  
   316  func qlopen(name string) (driver.Conn, error) {
   317  	conn, err := qldrv.Open("memory://" + name)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	return conn, nil
   323  }
   324  
   325  var (
   326  	qldrv driver.Driver
   327  )
   328  
   329  func init() {
   330  	sql.Register("csv", &csvDriver{})
   331  
   332  	db, err := sql.Open("ql", "memory:///dev/null")
   333  	if err != nil {
   334  		panic(err)
   335  	}
   336  	defer db.Close()
   337  	qldrv = db.Driver()
   338  }