github.com/lmorg/murex@v0.0.0-20240217211045-e081c89cd4ef/builtins/optional/select/tables.go (about)

     1  package sqlselect
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"os"
     7  
     8  	"github.com/lmorg/murex/builtins/core/open"
     9  	"github.com/lmorg/murex/builtins/pipes/streams"
    10  	"github.com/lmorg/murex/debug"
    11  	"github.com/lmorg/murex/lang"
    12  	"github.com/lmorg/murex/lang/types"
    13  	"github.com/lmorg/murex/utils/humannumbers"
    14  )
    15  
    16  func loadTables(p *lang.Process, fromFile string, pipes, vars []string, parameters string, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings, confPrintHeadings bool, confDataType string) error {
    17  	var (
    18  		v      interface{}
    19  		dt     string
    20  		err    error
    21  		tables []string
    22  	)
    23  
    24  	db, err := createDb()
    25  	if err != nil {
    26  		return err
    27  	}
    28  
    29  	switch {
    30  	case len(pipes) > 0:
    31  		dt = confDataType
    32  		debug.Json("select pipes", pipes)
    33  		debug.Log(fromFile, parameters)
    34  		tables = pipes
    35  		for i := range pipes {
    36  			v, err = readPipe(p, pipes[i])
    37  			if err != nil {
    38  				return err
    39  			}
    40  
    41  			err = createTable(p, db, pipes[i], v, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings)
    42  			if err != nil {
    43  				return err
    44  			}
    45  		}
    46  
    47  	case len(vars) > 0:
    48  		dt = confDataType
    49  		debug.Json("select vars", vars)
    50  		debug.Log(fromFile, parameters)
    51  		tables = vars
    52  		for i := range vars {
    53  			v, err = readVariable(p, vars[i])
    54  			if err != nil {
    55  				return err
    56  			}
    57  
    58  			err = createTable(p, db, vars[i], v, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings)
    59  			if err != nil {
    60  				return err
    61  			}
    62  		}
    63  
    64  	default:
    65  		v, dt, err = readFile(p, fromFile)
    66  		if err != nil {
    67  			return err
    68  		}
    69  
    70  		err = createTable(p, db, "main", v, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings)
    71  		if err != nil {
    72  			return err
    73  		}
    74  	}
    75  
    76  	p.Stdout.SetDataType(dt)
    77  	return runQuery(p, db, dt, tables, parameters, confPrintHeadings)
    78  }
    79  
    80  func readPipe(p *lang.Process, name string) (interface{}, error) {
    81  	pipe, err := lang.GlobalPipes.Get(name)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	fork := p.Fork(0)
    87  	fork.Process.Stdin = pipe
    88  
    89  	dt := pipe.GetDataType()
    90  	v, err := lang.UnmarshalData(fork.Process, dt)
    91  	if err != nil {
    92  		return nil, fmt.Errorf("unable to unmarshal named pipe '%s': %s", name, err.Error())
    93  	}
    94  
    95  	return v, nil
    96  }
    97  
    98  func readVariable(p *lang.Process, name string) (interface{}, error) {
    99  	s, err := p.Variables.GetString(name)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	dt := p.Variables.GetDataType(name)
   104  
   105  	fork := p.Fork(lang.F_CREATE_STDIN)
   106  	fork.Process.Stdin.SetDataType(dt)
   107  	fork.Process.Stdin.Write([]byte(s))
   108  
   109  	v, err := lang.UnmarshalData(fork.Process, dt)
   110  	if err != nil {
   111  		return nil, fmt.Errorf("unable to unmarshal variable '%s': %s", name, err.Error())
   112  	}
   113  
   114  	return v, nil
   115  }
   116  
   117  func readFile(p *lang.Process, fromFile string) (interface{}, string, error) {
   118  	var (
   119  		v   interface{}
   120  		dt  string
   121  		err error
   122  	)
   123  
   124  	if p.IsMethod {
   125  		dt = p.Stdin.GetDataType()
   126  
   127  		v, err = lang.UnmarshalData(p, dt)
   128  		if err != nil {
   129  			return nil, "", fmt.Errorf("unable to unmarshal STDIN: %s", err.Error())
   130  		}
   131  
   132  	} else {
   133  		closers, err := open.OpenFile(p, &fromFile, &dt)
   134  		if err != nil {
   135  			return nil, "", err
   136  		}
   137  
   138  		f, err := os.Open(fromFile)
   139  		if err != nil {
   140  			return nil, "", err
   141  		}
   142  
   143  		fork := p.Fork(0)
   144  		fork.Process.Stdin = streams.NewReadCloser(f)
   145  
   146  		v, err = lang.UnmarshalData(fork.Process, dt)
   147  		if err != nil {
   148  			return nil, "", fmt.Errorf("unable to unmarshal %s: %s", fromFile, err.Error())
   149  		}
   150  
   151  		err = open.CloseFiles(closers)
   152  		if err != nil {
   153  			return nil, "", err
   154  		}
   155  	}
   156  
   157  	return v, dt, nil
   158  }
   159  
   160  func createTable(p *lang.Process, db *sql.DB, name string, v interface{}, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings bool) error {
   161  	debug.Log("Creating table:", name)
   162  	switch v := v.(type) {
   163  	case [][]string:
   164  		return createTable_SliceSliceString(p, db, name, v, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings)
   165  
   166  	case []interface{}:
   167  		table := make([][]string, len(v)+1)
   168  		i := 1
   169  		err := types.MapToTable(v, func(s []string) error {
   170  			table[i] = s
   171  			i++
   172  			return nil
   173  		})
   174  		if err != nil {
   175  			return err
   176  		}
   177  		return createTable_SliceSliceString(p, db, name, table, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings)
   178  
   179  	default:
   180  		return fmt.Errorf("unable to convert the following data structure into a table '%s': %T", name, v)
   181  	}
   182  }
   183  
   184  func createTable_SliceSliceString(p *lang.Process, db *sql.DB, name string, v [][]string, confFailColMismatch, confMergeTrailingColumns, confTableIncHeadings bool) error {
   185  	if len(v) == 0 {
   186  		return fmt.Errorf("no table found")
   187  	}
   188  
   189  	var (
   190  		tx       *sql.Tx
   191  		err      error
   192  		headings []string
   193  		nRow     int
   194  	)
   195  
   196  	if confTableIncHeadings {
   197  		headings = make([]string, len(v[0]))
   198  		for i := range headings {
   199  			headings[i] = fmt.Sprint(v[0][i])
   200  		}
   201  		tx, err = openTable(db, name, headings)
   202  		if err != nil {
   203  			return err
   204  		}
   205  		nRow = 1
   206  
   207  	} else {
   208  		headings = make([]string, len(v[0]))
   209  		for i := range headings {
   210  			headings[i] = humannumbers.ColumnLetter(i)
   211  		}
   212  		tx, err = openTable(db, name, headings)
   213  		if err != nil {
   214  			return err
   215  		}
   216  
   217  		slice := stringToInterfaceTrim(v[0], len(v))
   218  		err = insertRecords(tx, name, slice)
   219  		if err != nil {
   220  			return fmt.Errorf("unable to insert headings into sqlite3: %s", err.Error())
   221  		}
   222  		nRow = 1
   223  	}
   224  
   225  	for ; nRow < len(v); nRow++ {
   226  		if p.HasCancelled() {
   227  			return fmt.Errorf("cancelled")
   228  		}
   229  
   230  		if len(v[nRow]) != len(headings) && confFailColMismatch {
   231  			return fmt.Errorf("table rows contain a different number of columns to table headings\n%d: %s", nRow, v[nRow])
   232  		}
   233  
   234  		if confMergeTrailingColumns {
   235  			slice := stringToInterfaceMerge(v[nRow], len(headings))
   236  			err = insertRecords(tx, name, slice)
   237  			if err != nil {
   238  				return fmt.Errorf("%s\n%d: %s", err.Error(), nRow, v[nRow])
   239  			}
   240  		} else {
   241  			slice := stringToInterfaceTrim(v[nRow], len(headings))
   242  			err = insertRecords(tx, name, slice)
   243  			if err != nil {
   244  				return fmt.Errorf("%s\n%d: %s", err.Error(), nRow, v[nRow][:len(headings)-1])
   245  			}
   246  		}
   247  	}
   248  
   249  	err = tx.Commit()
   250  	if err != nil {
   251  		return fmt.Errorf("unable to commit sqlite3 transaction: %s", err.Error())
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  func runQuery(p *lang.Process, db *sql.DB, dt string, tables []string, parameters string, confPrintHeadings bool) error {
   258  	query := createQueryString(tables, parameters)
   259  	debug.Log(query)
   260  
   261  	rows, err := db.QueryContext(p.Context, query)
   262  	if err != nil {
   263  		return fmt.Errorf("cannot query table: %s\nSQL: %s", err.Error(), query)
   264  	}
   265  
   266  	r, err := rows.Columns()
   267  	if err != nil {
   268  		return fmt.Errorf("cannot query rows: %s", err.Error())
   269  	}
   270  
   271  	var (
   272  		table [][]string
   273  		nRow  int
   274  	)
   275  
   276  	if confPrintHeadings {
   277  		table = [][]string{r}
   278  		nRow++
   279  	}
   280  
   281  	for rows.Next() {
   282  		table = append(table, make([]string, len(r)))
   283  		slice := stringToInterfacePtr(&table[nRow], len(r))
   284  
   285  		err = rows.Scan(slice...)
   286  		if err != nil {
   287  			return err
   288  		}
   289  
   290  		nRow++
   291  	}
   292  	if err := rows.Err(); err != nil {
   293  		return fmt.Errorf("cannot retrieve rows: %s", err.Error())
   294  	}
   295  
   296  	b, err := lang.MarshalData(p, dt, table)
   297  	if err != nil {
   298  		return fmt.Errorf("unable to marshal STDOUT: %s", err.Error())
   299  	}
   300  
   301  	_, err = p.Stdout.Write(b)
   302  	return err
   303  }