github.com/cilium/statedb@v0.3.2/script.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package statedb
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"iter"
    14  	"maps"
    15  	"os"
    16  	"regexp"
    17  	"slices"
    18  	"strings"
    19  	"time"
    20  
    21  	"github.com/cilium/hive"
    22  	"github.com/cilium/hive/script"
    23  	"github.com/liggitt/tabwriter"
    24  	"golang.org/x/time/rate"
    25  	"gopkg.in/yaml.v3"
    26  )
    27  
    28  func ScriptCommands(db *DB) hive.ScriptCmdOut {
    29  	subCmds := map[string]script.Cmd{
    30  		"tables":      TablesCmd(db),
    31  		"show":        ShowCmd(db),
    32  		"cmp":         CompareCmd(db),
    33  		"insert":      InsertCmd(db),
    34  		"delete":      DeleteCmd(db),
    35  		"get":         GetCmd(db),
    36  		"prefix":      PrefixCmd(db),
    37  		"list":        ListCmd(db),
    38  		"lowerbound":  LowerBoundCmd(db),
    39  		"watch":       WatchCmd(db),
    40  		"initialized": InitializedCmd(db),
    41  	}
    42  	subCmdsList := strings.Join(slices.Collect(maps.Keys(subCmds)), ", ")
    43  	return hive.NewScriptCmd(
    44  		"db",
    45  		script.Command(
    46  			script.CmdUsage{
    47  				Summary: "Inspect and manipulate StateDB",
    48  				Args:    "cmd args...",
    49  				Detail: []string{
    50  					"Supported commands: " + subCmdsList,
    51  				},
    52  			},
    53  			func(s *script.State, args ...string) (script.WaitFunc, error) {
    54  				if len(args) < 1 {
    55  					return nil, fmt.Errorf("expected command (%s)", subCmdsList)
    56  				}
    57  				cmd, ok := subCmds[args[0]]
    58  				if !ok {
    59  					return nil, fmt.Errorf("command not found, expected one of %s", subCmdsList)
    60  				}
    61  				wf, err := cmd.Run(s, args[1:]...)
    62  				if errors.Is(err, errUsage) {
    63  					s.Logf("usage: db %s %s\n", args[0], cmd.Usage().Args)
    64  				}
    65  				return wf, err
    66  			},
    67  		),
    68  	)
    69  }
    70  
    71  var errUsage = errors.New("bad arguments")
    72  
    73  func TablesCmd(db *DB) script.Cmd {
    74  	return script.Command(
    75  		script.CmdUsage{
    76  			Summary: "Show StateDB tables",
    77  			Args:    "table",
    78  		},
    79  		func(s *script.State, args ...string) (script.WaitFunc, error) {
    80  			txn := db.ReadTxn()
    81  			tbls := db.GetTables(txn)
    82  			w := newTabWriter(s.LogWriter())
    83  			fmt.Fprintf(w, "Name\tObject count\tDeleted objects\tIndexes\tInitializers\tGo type\tLast WriteTxn\n")
    84  			for _, tbl := range tbls {
    85  				idxs := strings.Join(tbl.Indexes(), ", ")
    86  				fmt.Fprintf(w, "%s\t%d\t%d\t%s\t%v\t%T\t%s\n",
    87  					tbl.Name(), tbl.NumObjects(txn), tbl.numDeletedObjects(txn), idxs, tbl.PendingInitializers(txn), tbl.proto(), tbl.getAcquiredInfo())
    88  			}
    89  			w.Flush()
    90  			return nil, nil
    91  		},
    92  	)
    93  }
    94  
    95  func newCmdFlagSet() *flag.FlagSet {
    96  	return &flag.FlagSet{
    97  		// Disable showing the normal usage.
    98  		Usage: func() {},
    99  	}
   100  }
   101  
   102  func InitializedCmd(db *DB) script.Cmd {
   103  	return script.Command(
   104  		script.CmdUsage{
   105  			Summary: "Wait until all or specific tables have been initialized",
   106  			Args:    "(-timeout=<duration>) table...",
   107  		},
   108  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   109  			txn := db.ReadTxn()
   110  			allTbls := db.GetTables(txn)
   111  			tbls := allTbls
   112  
   113  			flags := newCmdFlagSet()
   114  			timeout := flags.Duration("timeout", 5*time.Second, "Maximum amount of time to wait for the table contents to match")
   115  			if err := flags.Parse(args); err != nil {
   116  				return nil, fmt.Errorf("%w: %s", errUsage, err)
   117  			}
   118  			timeoutChan := time.After(*timeout)
   119  			args = flags.Args()
   120  
   121  			if len(args) > 0 {
   122  				// Specific tables requested, look them up.
   123  				tbls = make([]TableMeta, 0, len(args))
   124  				for _, tableName := range args {
   125  					found := false
   126  					for _, tbl := range allTbls {
   127  						if tableName == tbl.Name() {
   128  							tbls = append(tbls, tbl)
   129  							found = true
   130  							break
   131  						}
   132  					}
   133  					if !found {
   134  						return nil, fmt.Errorf("table %q not found", tableName)
   135  					}
   136  				}
   137  			}
   138  
   139  			for _, tbl := range tbls {
   140  				init, watch := tbl.Initialized(txn)
   141  				if init {
   142  					s.Logf("%s initialized\n", tbl.Name())
   143  					continue
   144  				}
   145  				s.Logf("Waiting for %s to initialize (%v)...\n", tbl.Name(), tbl.PendingInitializers(txn))
   146  				select {
   147  				case <-s.Context().Done():
   148  					return nil, s.Context().Err()
   149  				case <-timeoutChan:
   150  					return nil, fmt.Errorf("timed out")
   151  				case <-watch:
   152  					s.Logf("%s initialized\n", tbl.Name())
   153  				}
   154  			}
   155  			return nil, nil
   156  		},
   157  	)
   158  }
   159  
   160  func ShowCmd(db *DB) script.Cmd {
   161  	return script.Command(
   162  		script.CmdUsage{
   163  			Summary: "Show table",
   164  			Args:    "(-o=<file>) (-columns=col1,...) (-format={table,yaml,json}) table",
   165  		},
   166  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   167  			flags := newCmdFlagSet()
   168  			file := flags.String("o", "", "File to write to instead of stdout")
   169  			columns := flags.String("columns", "", "Comma-separated list of columns to write")
   170  			format := flags.String("format", "table", "Format to write in (table, yaml, json)")
   171  			if err := flags.Parse(args); err != nil {
   172  				return nil, fmt.Errorf("%w: %s", errUsage, err)
   173  			}
   174  
   175  			var cols []string
   176  			if len(*columns) > 0 {
   177  				cols = strings.Split(*columns, ",")
   178  			}
   179  
   180  			args = flags.Args()
   181  			if len(args) < 1 {
   182  				return nil, fmt.Errorf("%w: missing table name", errUsage)
   183  			}
   184  			tableName := args[0]
   185  			return func(*script.State) (stdout, stderr string, err error) {
   186  				var buf strings.Builder
   187  				var w io.Writer
   188  				if *file == "" {
   189  					w = &buf
   190  				} else {
   191  					f, err := os.OpenFile(s.Path(*file), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   192  					if err != nil {
   193  						return "", "", fmt.Errorf("OpenFile(%s): %w", *file, err)
   194  					}
   195  					defer f.Close()
   196  					w = f
   197  				}
   198  				tbl, txn, err := getTable(db, tableName)
   199  				if err != nil {
   200  					return "", "", err
   201  				}
   202  				err = writeObjects(tbl, tbl.All(txn), w, cols, *format)
   203  				return buf.String(), "", err
   204  			}, nil
   205  		})
   206  }
   207  
   208  func CompareCmd(db *DB) script.Cmd {
   209  	return script.Command(
   210  		script.CmdUsage{
   211  			Summary: "Compare table",
   212  			Args:    "(-timeout=<dur>) (-grep=<pattern>) table file",
   213  		},
   214  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   215  			flags := newCmdFlagSet()
   216  			timeout := flags.Duration("timeout", time.Second, "Maximum amount of time to wait for the table contents to match")
   217  			grep := flags.String("grep", "", "Grep the result rows and only compare matching ones")
   218  			err := flags.Parse(args)
   219  			args = flags.Args()
   220  			if err != nil || len(args) != 2 {
   221  				return nil, fmt.Errorf("%w: %s", errUsage, err)
   222  			}
   223  
   224  			var grepRe *regexp.Regexp
   225  			if *grep != "" {
   226  				grepRe, err = regexp.Compile(*grep)
   227  				if err != nil {
   228  					return nil, fmt.Errorf("bad grep: %w", err)
   229  				}
   230  			}
   231  
   232  			tableName := args[0]
   233  
   234  			txn := db.ReadTxn()
   235  			meta := db.GetTable(txn, tableName)
   236  			if meta == nil {
   237  				return nil, fmt.Errorf("table %q not found", tableName)
   238  			}
   239  			tbl := AnyTable{Meta: meta}
   240  			header := tbl.TableHeader()
   241  
   242  			data, err := os.ReadFile(s.Path(args[1]))
   243  			if err != nil {
   244  				return nil, fmt.Errorf("ReadFile(%s): %w", args[1], err)
   245  			}
   246  			lines := strings.Split(string(data), "\n")
   247  			lines = slices.DeleteFunc(lines, func(line string) bool {
   248  				return strings.TrimSpace(line) == ""
   249  			})
   250  			if len(lines) < 1 {
   251  				return nil, fmt.Errorf("%q missing header line, e.g. %q", args[1], strings.Join(header, " "))
   252  			}
   253  
   254  			columnNames, columnPositions := splitHeaderLine(lines[0])
   255  			columnIndexes, err := getColumnIndexes(columnNames, header)
   256  			if err != nil {
   257  				return nil, err
   258  			}
   259  			lines = lines[1:]
   260  			origLines := lines
   261  			timeoutChan := time.After(*timeout)
   262  
   263  			for {
   264  				lines = origLines
   265  
   266  				// Create the diff between 'lines' and the rows in the table.
   267  				equal := true
   268  				var diff bytes.Buffer
   269  				w := newTabWriter(&diff)
   270  				fmt.Fprintf(w, "  %s\n", joinByPositions(columnNames, columnPositions))
   271  
   272  				objs, watch := tbl.AllWatch(db.ReadTxn())
   273  				for obj := range objs {
   274  					rowRaw := takeColumns(obj.(TableWritable).TableRow(), columnIndexes)
   275  					row := joinByPositions(rowRaw, columnPositions)
   276  					if grepRe != nil && !grepRe.Match([]byte(row)) {
   277  						continue
   278  					}
   279  
   280  					if len(lines) == 0 {
   281  						equal = false
   282  						fmt.Fprintf(w, "- %s\n", row)
   283  						continue
   284  					}
   285  					line := lines[0]
   286  					splitLine := splitByPositions(line, columnPositions)
   287  
   288  					if slices.Equal(rowRaw, splitLine) {
   289  						fmt.Fprintf(w, "  %s\n", row)
   290  					} else {
   291  						fmt.Fprintf(w, "- %s\n", row)
   292  						fmt.Fprintf(w, "+ %s\n", line)
   293  						equal = false
   294  					}
   295  					lines = lines[1:]
   296  				}
   297  				for _, line := range lines {
   298  					fmt.Fprintf(w, "+ %s\n", line)
   299  					equal = false
   300  				}
   301  				if equal {
   302  					return nil, nil
   303  				}
   304  				w.Flush()
   305  
   306  				select {
   307  				case <-s.Context().Done():
   308  					return nil, s.Context().Err()
   309  
   310  				case <-timeoutChan:
   311  					return nil, fmt.Errorf("table mismatch:\n%s", diff.String())
   312  
   313  				case <-watch:
   314  				}
   315  			}
   316  		})
   317  }
   318  
   319  func InsertCmd(db *DB) script.Cmd {
   320  	return script.Command(
   321  		script.CmdUsage{
   322  			Summary: "Insert object into a table",
   323  			Args:    "table path...",
   324  		},
   325  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   326  			return insertOrDelete(true, db, s, args...)
   327  		},
   328  	)
   329  }
   330  
   331  func DeleteCmd(db *DB) script.Cmd {
   332  	return script.Command(
   333  		script.CmdUsage{
   334  			Summary: "Delete an object from the table",
   335  			Args:    "table path...",
   336  		},
   337  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   338  			return insertOrDelete(false, db, s, args...)
   339  		},
   340  	)
   341  }
   342  
   343  func getTable(db *DB, tableName string) (*AnyTable, ReadTxn, error) {
   344  	txn := db.ReadTxn()
   345  	meta := db.GetTable(txn, tableName)
   346  	if meta == nil {
   347  		return nil, nil, fmt.Errorf("table %q not found", tableName)
   348  	}
   349  	return &AnyTable{Meta: meta}, txn, nil
   350  }
   351  
   352  func insertOrDelete(insert bool, db *DB, s *script.State, args ...string) (script.WaitFunc, error) {
   353  	if len(args) < 2 {
   354  		return nil, fmt.Errorf("%w: expected table and path(s)", errUsage)
   355  	}
   356  
   357  	tbl, _, err := getTable(db, args[0])
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  
   362  	wtxn := db.WriteTxn(tbl.Meta)
   363  	defer wtxn.Commit()
   364  
   365  	for _, arg := range args[1:] {
   366  		data, err := os.ReadFile(s.Path(arg))
   367  		if err != nil {
   368  			return nil, fmt.Errorf("ReadFile(%s): %w", arg, err)
   369  		}
   370  		parts := strings.Split(string(data), "---")
   371  		for _, part := range parts {
   372  			obj, err := tbl.UnmarshalYAML([]byte(part))
   373  			if err != nil {
   374  				return nil, fmt.Errorf("Unmarshal(%s): %w", arg, err)
   375  			}
   376  			if insert {
   377  				_, _, err = tbl.Insert(wtxn, obj)
   378  				if err != nil {
   379  					return nil, fmt.Errorf("Insert(%s): %w", arg, err)
   380  				}
   381  			} else {
   382  				_, _, err = tbl.Delete(wtxn, obj)
   383  				if err != nil {
   384  					return nil, fmt.Errorf("Delete(%s): %w", arg, err)
   385  				}
   386  
   387  			}
   388  		}
   389  	}
   390  	return nil, nil
   391  }
   392  
   393  func PrefixCmd(db *DB) script.Cmd {
   394  	return queryCmd(db, queryCmdPrefix, "Query table by prefix")
   395  }
   396  
   397  func LowerBoundCmd(db *DB) script.Cmd {
   398  	return queryCmd(db, queryCmdLowerBound, "Query table by lower bound search")
   399  }
   400  
   401  func ListCmd(db *DB) script.Cmd {
   402  	return queryCmd(db, queryCmdList, "List objects in the table")
   403  }
   404  
   405  func GetCmd(db *DB) script.Cmd {
   406  	return queryCmd(db, queryCmdGet, "Get the first matching object")
   407  }
   408  
   409  const (
   410  	queryCmdList = iota
   411  	queryCmdPrefix
   412  	queryCmdLowerBound
   413  	queryCmdGet
   414  )
   415  
   416  func queryCmd(db *DB, query int, summary string) script.Cmd {
   417  	return script.Command(
   418  		script.CmdUsage{
   419  			Summary: summary,
   420  			Args:    "(-o=<file>) (-columns=col1,...) (-format={table*,yaml,json}) (-index=<index>) table key",
   421  		},
   422  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   423  			return runQueryCmd(query, db, s, args)
   424  		},
   425  	)
   426  }
   427  
   428  func runQueryCmd(query int, db *DB, s *script.State, args []string) (script.WaitFunc, error) {
   429  	flags := newCmdFlagSet()
   430  	file := flags.String("o", "", "File to write results to instead of stdout")
   431  	index := flags.String("index", "", "Index to query")
   432  	format := flags.String("format", "table", "Format to write in (table, yaml, json)")
   433  	columns := flags.String("columns", "", "Comma-separated list of columns to write")
   434  	delete := flags.Bool("delete", false, "Delete all matching objects")
   435  	if err := flags.Parse(args); err != nil {
   436  		return nil, fmt.Errorf("%w: %s", errUsage, err)
   437  	}
   438  
   439  	var cols []string
   440  	if len(*columns) > 0 {
   441  		cols = strings.Split(*columns, ",")
   442  	}
   443  
   444  	args = flags.Args()
   445  	if len(args) < 2 {
   446  		return nil, fmt.Errorf("%w: expected table and key", errUsage)
   447  	}
   448  
   449  	return func(*script.State) (stdout, stderr string, err error) {
   450  		tbl, txn, err := getTable(db, args[0])
   451  		if err != nil {
   452  			return "", "", err
   453  		}
   454  
   455  		var buf strings.Builder
   456  		var w io.Writer
   457  		if *file == "" {
   458  			w = &buf
   459  		} else {
   460  			f, err := os.OpenFile(s.Path(*file), os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   461  			if err != nil {
   462  				return "", "", fmt.Errorf("OpenFile(%s): %s", *file, err)
   463  			}
   464  			defer f.Close()
   465  			w = f
   466  		}
   467  
   468  		var it iter.Seq2[any, uint64]
   469  		switch query {
   470  		case queryCmdList:
   471  			it, err = tbl.List(txn, *index, args[1])
   472  		case queryCmdLowerBound:
   473  			it, err = tbl.LowerBound(txn, *index, args[1])
   474  		case queryCmdPrefix:
   475  			it, err = tbl.Prefix(txn, *index, args[1])
   476  		case queryCmdGet:
   477  			it, err = tbl.List(txn, *index, args[1])
   478  			if err == nil {
   479  				it = firstOfSeq2(it)
   480  			}
   481  		default:
   482  			panic("unknown query enum")
   483  		}
   484  		if err != nil {
   485  			return "", "", fmt.Errorf("query: %w", err)
   486  		}
   487  
   488  		err = writeObjects(tbl, it, w, cols, *format)
   489  		if err != nil {
   490  			return "", "", err
   491  		}
   492  
   493  		if *delete {
   494  			wtxn := db.WriteTxn(tbl.Meta)
   495  			count := 0
   496  			for obj := range it {
   497  				_, hadOld, err := tbl.Delete(wtxn, obj)
   498  				if err != nil {
   499  					wtxn.Abort()
   500  					return "", "", err
   501  				}
   502  				if hadOld {
   503  					count++
   504  				}
   505  			}
   506  			s.Logf("Deleted %d objects\n", count)
   507  			wtxn.Commit()
   508  		}
   509  
   510  		return buf.String(), "", err
   511  	}, nil
   512  }
   513  
   514  func WatchCmd(db *DB) script.Cmd {
   515  	return script.Command(
   516  		script.CmdUsage{
   517  			Summary: "Watch a table for changes",
   518  			Args:    "table",
   519  		},
   520  		func(s *script.State, args ...string) (script.WaitFunc, error) {
   521  			if len(args) < 1 {
   522  				return nil, fmt.Errorf("expected table name")
   523  			}
   524  
   525  			tbl, _, err := getTable(db, args[0])
   526  			if err != nil {
   527  				return nil, err
   528  			}
   529  			wtxn := db.WriteTxn(tbl.Meta)
   530  			iter, err := tbl.Changes(wtxn)
   531  			wtxn.Commit()
   532  			if err != nil {
   533  				return nil, err
   534  			}
   535  
   536  			header := tbl.TableHeader()
   537  			if header == nil {
   538  				return nil, fmt.Errorf("objects in table %q not TableWritable", tbl.Meta.Name())
   539  			}
   540  			tw := newTabWriter(&strikethroughWriter{w: s.LogWriter()})
   541  			fmt.Fprintf(tw, "%s\n", strings.Join(header, "\t"))
   542  
   543  			limiter := rate.NewLimiter(10.0, 1)
   544  			for {
   545  				if err := limiter.Wait(s.Context()); err != nil {
   546  					break
   547  				}
   548  				changes, watch := iter.nextAny(db.ReadTxn())
   549  				for change := range changes {
   550  					row := change.Object.(TableWritable).TableRow()
   551  					if change.Deleted {
   552  						fmt.Fprintf(tw, "%s (deleted)%s", strings.Join(row, "\t"), magicStrikethroughNewline)
   553  					} else {
   554  						fmt.Fprintf(tw, "%s\n", strings.Join(row, "\t"))
   555  					}
   556  				}
   557  				tw.Flush()
   558  				if err := s.FlushLog(); err != nil {
   559  					return nil, err
   560  				}
   561  				select {
   562  				case <-watch:
   563  				case <-s.Context().Done():
   564  					return nil, nil
   565  				}
   566  			}
   567  			return nil, nil
   568  
   569  		},
   570  	)
   571  }
   572  
   573  func firstOfSeq2[A, B any](it iter.Seq2[A, B]) iter.Seq2[A, B] {
   574  	return func(yield func(a A, b B) bool) {
   575  		for a, b := range it {
   576  			yield(a, b)
   577  			break
   578  		}
   579  	}
   580  }
   581  
   582  func writeObjects(tbl *AnyTable, it iter.Seq2[any, Revision], w io.Writer, columns []string, format string) error {
   583  	if len(columns) > 0 && format != "table" {
   584  		return fmt.Errorf("-columns not supported with non-table formats")
   585  	}
   586  	switch format {
   587  	case "yaml":
   588  		sep := []byte("---\n")
   589  		first := true
   590  		for obj := range it {
   591  			if !first {
   592  				w.Write(sep)
   593  			}
   594  			first = false
   595  
   596  			out, err := yaml.Marshal(obj)
   597  			if err != nil {
   598  				return fmt.Errorf("yaml.Marshal: %w", err)
   599  			}
   600  			if _, err := w.Write(out); err != nil {
   601  				return err
   602  			}
   603  		}
   604  		return nil
   605  	case "json":
   606  		sep := []byte("\n")
   607  		first := true
   608  		for obj := range it {
   609  			if !first {
   610  				w.Write(sep)
   611  			}
   612  			first = false
   613  
   614  			out, err := json.Marshal(obj)
   615  			if err != nil {
   616  				return fmt.Errorf("json.Marshal: %w", err)
   617  			}
   618  			if _, err := w.Write(out); err != nil {
   619  				return err
   620  			}
   621  		}
   622  		return nil
   623  	case "table":
   624  		header := tbl.TableHeader()
   625  		if header == nil {
   626  			return fmt.Errorf("objects in table %q not TableWritable", tbl.Meta.Name())
   627  		}
   628  
   629  		var idxs []int
   630  		var err error
   631  		if len(columns) > 0 {
   632  			idxs, err = getColumnIndexes(columns, header)
   633  			header = columns
   634  		} else {
   635  			idxs, err = getColumnIndexes(header, header)
   636  		}
   637  		if err != nil {
   638  			return err
   639  		}
   640  		tw := newTabWriter(w)
   641  		fmt.Fprintf(tw, "%s\n", strings.Join(header, "\t"))
   642  
   643  		for obj := range it {
   644  			row := takeColumns(obj.(TableWritable).TableRow(), idxs)
   645  			fmt.Fprintf(tw, "%s\n", strings.Join(row, "\t"))
   646  		}
   647  		return tw.Flush()
   648  	}
   649  	return fmt.Errorf("unknown format %q, expected table, yaml or json", format)
   650  }
   651  
   652  func takeColumns[T any](xs []T, idxs []int) (out []T) {
   653  	for _, idx := range idxs {
   654  		out = append(out, xs[idx])
   655  	}
   656  	return
   657  }
   658  
   659  func getColumnIndexes(names []string, header []string) ([]int, error) {
   660  	columnIndexes := make([]int, 0, len(header))
   661  loop:
   662  	for _, name := range names {
   663  		for i, name2 := range header {
   664  			if strings.EqualFold(name, name2) {
   665  				columnIndexes = append(columnIndexes, i)
   666  				continue loop
   667  			}
   668  		}
   669  		return nil, fmt.Errorf("column %q not part of %v", name, header)
   670  	}
   671  	return columnIndexes, nil
   672  }
   673  
   674  // splitHeaderLine takes a header of column names separated by any
   675  // number of whitespaces and returns the names and their starting positions.
   676  // e.g. "Foo  Bar Baz" would result in ([Foo,Bar,Baz],[0,5,9]).
   677  // With this information we can take a row in the database and format it
   678  // the same way as our test data.
   679  func splitHeaderLine(line string) (names []string, pos []int) {
   680  	start := 0
   681  	skip := true
   682  	for i, r := range line {
   683  		switch r {
   684  		case ' ', '\t':
   685  			if !skip {
   686  				names = append(names, line[start:i])
   687  				pos = append(pos, start)
   688  				start = -1
   689  			}
   690  			skip = true
   691  		default:
   692  			skip = false
   693  			if start == -1 {
   694  				start = i
   695  			}
   696  		}
   697  	}
   698  	if start >= 0 && start < len(line) {
   699  		names = append(names, line[start:])
   700  		pos = append(pos, start)
   701  	}
   702  	return
   703  }
   704  
   705  // splitByPositions takes a "row" line and the positions of the header columns
   706  // and extracts the values.
   707  // e.g. if we have the positions [0,5,9] (from header "Foo  Bar Baz") and
   708  // line is "1    a   b", then we'd extract [1,a,b].
   709  // The whitespace on the right of the start position (e.g. "1  \t") is trimmed.
   710  // This of course requires that the table is properly formatted in a way that the
   711  // header columns are indented to fit the data exactly.
   712  func splitByPositions(line string, positions []int) []string {
   713  	out := make([]string, 0, len(positions))
   714  	start := 0
   715  	for _, pos := range positions[1:] {
   716  		if start >= len(line) {
   717  			out = append(out, "")
   718  			start = len(line)
   719  			continue
   720  		}
   721  		out = append(out, strings.TrimRight(line[start:min(pos, len(line))], " \t"))
   722  		start = pos
   723  	}
   724  	out = append(out, strings.TrimRight(line[min(start, len(line)):], " \t"))
   725  	return out
   726  }
   727  
   728  // joinByPositions is the reverse of splitByPositions, it takes the columns of a
   729  // row and the starting positions of each and joins into a single line.
   730  // e.g. [1,a,b] and positions [0,5,9] expands to "1    a   b".
   731  // NOTE: This does not deal well with mixing tabs and spaces. The test input
   732  // data should preferably just use spaces.
   733  func joinByPositions(row []string, positions []int) string {
   734  	var w strings.Builder
   735  	prev := 0
   736  	for i, pos := range positions {
   737  		for pad := pos - prev; pad > 0; pad-- {
   738  			w.WriteByte(' ')
   739  		}
   740  		w.WriteString(row[i])
   741  		prev = pos + len(row[i])
   742  	}
   743  	return w.String()
   744  }
   745  
   746  // strikethroughWriter writes a line of text that is striken through
   747  // if the line contains the magic character at the end before \n.
   748  // This is used to strike through a tab-formatted line without messing
   749  // up with the widths of the cells.
   750  type strikethroughWriter struct {
   751  	buf           []byte
   752  	strikethrough bool
   753  	w             io.Writer
   754  }
   755  
   756  var (
   757  	// Magic character to use at the end of the line to denote that this should be
   758  	// striken through.
   759  	// This is to avoid messing up the width calculations in the tab writer, which
   760  	// would happen if ANSI codes were used directly.
   761  	magicStrikethrough        = byte('\xfe')
   762  	magicStrikethroughNewline = "\xfe\n"
   763  )
   764  
   765  func stripTrailingWhitespace(buf []byte) []byte {
   766  	idx := bytes.LastIndexFunc(
   767  		buf,
   768  		func(r rune) bool {
   769  			return r != ' ' && r != '\t'
   770  		},
   771  	)
   772  	if idx > 0 {
   773  		return buf[:idx+1]
   774  	}
   775  	return buf
   776  }
   777  
   778  func (s *strikethroughWriter) Write(p []byte) (n int, err error) {
   779  	write := func(bs []byte) {
   780  		if err == nil {
   781  			_, e := s.w.Write(bs)
   782  			if e != nil {
   783  				err = e
   784  			}
   785  		}
   786  	}
   787  	for _, c := range p {
   788  		switch c {
   789  		case '\n':
   790  			s.buf = stripTrailingWhitespace(s.buf)
   791  
   792  			if s.strikethrough {
   793  				write(beginStrikethrough)
   794  				write(s.buf)
   795  				write(endStrikethrough)
   796  			} else {
   797  				write(s.buf)
   798  			}
   799  			write(newline)
   800  
   801  			s.buf = s.buf[:0] // reset len for reuse.
   802  			s.strikethrough = false
   803  
   804  			if err != nil {
   805  				return 0, err
   806  			}
   807  
   808  		case magicStrikethrough:
   809  			s.strikethrough = true
   810  
   811  		default:
   812  			s.buf = append(s.buf, c)
   813  		}
   814  	}
   815  	return len(p), nil
   816  }
   817  
   818  var (
   819  	// Use color red and the strikethrough escape
   820  	beginStrikethrough = []byte("\033[9m\033[31m")
   821  	endStrikethrough   = []byte("\033[0m")
   822  	newline            = []byte("\n")
   823  )
   824  
   825  var _ io.Writer = &strikethroughWriter{}
   826  
   827  func newTabWriter(out io.Writer) *tabwriter.Writer {
   828  	const (
   829  		minWidth = 5
   830  		width    = 4
   831  		padding  = 3
   832  		padChar  = ' '
   833  		flags    = tabwriter.RememberWidths
   834  	)
   835  	return tabwriter.NewWriter(out, minWidth, width, padding, padChar, flags)
   836  }