github.com/GGP1/kure@v0.8.4/commands/util.go (about)

     1  package cmdutil
     2  
     3  import (
     4  	"crypto/rand"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/GGP1/kure/config"
    13  	"github.com/GGP1/kure/db/bucket"
    14  	"github.com/GGP1/kure/db/card"
    15  	"github.com/GGP1/kure/db/entry"
    16  	"github.com/GGP1/kure/db/file"
    17  	"github.com/GGP1/kure/db/totp"
    18  	"github.com/GGP1/kure/orderedmap"
    19  	"github.com/GGP1/kure/sig"
    20  	"github.com/GGP1/kure/terminal"
    21  
    22  	"github.com/atotto/clipboard"
    23  	"github.com/awnumar/memguard"
    24  	"github.com/pkg/errors"
    25  	"github.com/spf13/cobra"
    26  	"github.com/stretchr/testify/assert"
    27  	bolt "go.etcd.io/bbolt"
    28  )
    29  
    30  var (
    31  	// ErrInvalidLength is returned when generating a password/passphrase and the length passed is < 1.
    32  	ErrInvalidLength = errors.New("invalid length")
    33  	// ErrInvalidName is returned when a name is required and received "" or contains "//".
    34  	ErrInvalidName = errors.New("invalid name")
    35  	// ErrInvalidPath is returned when a path is required and received "".
    36  	ErrInvalidPath = errors.New("invalid path")
    37  )
    38  
    39  const (
    40  	// Card object
    41  	Card object = iota
    42  	// Entry object
    43  	Entry
    44  	// File object
    45  	File
    46  	// TOTP object
    47  	TOTP
    48  
    49  	// Box
    50  	hBar       = "─"
    51  	vBar       = "│"
    52  	upperLeft  = "╭"
    53  	lowerLeft  = "╰"
    54  	upperRight = "╮"
    55  	lowerRight = "╯"
    56  )
    57  
    58  // RunEFunc runs a cobra function returning an error.
    59  type RunEFunc func(cmd *cobra.Command, args []string) error
    60  
    61  type object int
    62  
    63  // BuildBox constructs a responsive box used to display records information.
    64  //
    65  //	┌──── Sample ────┐
    66  //	│ Key  │ Value   │
    67  //	└────────────────┘
    68  func BuildBox(name string, mp *orderedmap.Map) string {
    69  	var sb strings.Builder
    70  
    71  	// Do not use folders as part of the name
    72  	name = filepath.Base(name)
    73  	if !strings.Contains(name, ".") {
    74  		name = strings.Title(name)
    75  	}
    76  
    77  	nameLen := len([]rune(name))
    78  	longestKey := 0
    79  	longestValue := nameLen
    80  
    81  	// Range to take the longest key and value
    82  	// Keys will always be 1 byte characters
    83  	// Values may be 1, 2 or 3 bytes, to take the length use len([]rune(v))
    84  	for _, key := range mp.Keys() {
    85  		value := mp.Get(key) // Get key's value
    86  
    87  		// Take map's longest key
    88  		if len(key) > longestKey {
    89  			longestKey = len(key)
    90  		}
    91  
    92  		// Split each value by a new line (fields like Notes contain multiple lines)
    93  		for _, v := range strings.Split(value, "\n") {
    94  			lenV := len([]rune(v))
    95  
    96  			// Take map's longest value
    97  			if lenV > longestValue {
    98  				longestValue = lenV
    99  			}
   100  		}
   101  	}
   102  
   103  	// -4-: 2 spaces that wrap name and 2 corners
   104  	headerLen := longestKey + longestValue - nameLen + 4
   105  	headerHalfLen := headerLen / 2
   106  
   107  	// Left side header
   108  	sb.WriteString(upperLeft)
   109  	sb.WriteString(strings.Repeat(hBar, headerHalfLen))
   110  
   111  	// Header name
   112  	sb.WriteRune(' ')
   113  	sb.WriteString(name)
   114  	sb.WriteRune(' ')
   115  
   116  	// Adjust the right side of the header if its width is even
   117  	if headerLen%2 == 0 {
   118  		headerHalfLen--
   119  	}
   120  
   121  	// Right side header
   122  	sb.WriteString(strings.Repeat(hBar, headerHalfLen))
   123  	sb.WriteString(upperRight)
   124  	sb.WriteString("\n")
   125  
   126  	// Body
   127  	for _, key := range mp.Keys() {
   128  		value := mp.Get(key) // Get key's value
   129  		// Start
   130  		sb.WriteString(vBar)
   131  
   132  		// Key
   133  		sb.WriteRune(' ')
   134  		sb.WriteString(key)
   135  		sb.WriteRune(' ')
   136  		sb.WriteString(strings.Repeat(" ", longestKey-len(key))) // Padding
   137  
   138  		// Middle
   139  		sb.WriteString(vBar)
   140  
   141  		// Value
   142  		for i, v := range strings.Split(value, "\n") {
   143  			// In case the value contains multi-lines,
   144  			// repeat the process above but do not add the key
   145  			if i >= 1 {
   146  				sb.WriteString("\n")
   147  				sb.WriteString(vBar)
   148  				// -2- represents key leading and trailing spaces
   149  				//   1   2
   150  				// (│ key │), here key = ""
   151  				sb.WriteString(strings.Repeat(" ", longestKey+2)) // Padding
   152  				sb.WriteString(vBar)
   153  			}
   154  
   155  			sb.WriteRune(' ')
   156  			sb.WriteString(v)
   157  			sb.WriteString(strings.Repeat(" ", longestValue-len([]rune(v)))) // Padding
   158  
   159  			// End
   160  			sb.WriteString(" ")
   161  			sb.WriteString(vBar)
   162  		}
   163  		sb.WriteString("\n")
   164  	}
   165  
   166  	// Footer
   167  	// -5- represents the characters that wrap both key and value
   168  	//  1   234     5
   169  	// ( key │ value )
   170  	footerLen := longestKey + longestValue + 5
   171  	sb.WriteString(lowerLeft)
   172  	sb.WriteString(strings.Repeat(hBar, footerLen))
   173  	sb.WriteString(lowerRight)
   174  
   175  	return sb.String()
   176  }
   177  
   178  // Erase overwrites the file content with random bytes and then deletes it.
   179  func Erase(filename string) error {
   180  	f, err := os.Stat(filename)
   181  	if err != nil {
   182  		return errors.Wrap(err, "obtaining file information")
   183  	}
   184  
   185  	buf := make([]byte, f.Size())
   186  	if _, err := rand.Read(buf); err != nil {
   187  		return errors.Wrap(err, "generating random bytes")
   188  	}
   189  
   190  	// WriteFile truncates the file and overwrites it
   191  	if err := os.WriteFile(filename, buf, 0o600); err != nil {
   192  		return errors.Wrap(err, "overwriting file")
   193  	}
   194  
   195  	if err := os.Remove(filename); err != nil {
   196  		return errors.Wrap(err, "removing file")
   197  	}
   198  
   199  	return nil
   200  }
   201  
   202  // Exists checks if name or one of its folders is already being used.
   203  //
   204  // Returns an error if a match was found.
   205  func Exists(db *bolt.DB, name string, obj object) error {
   206  	records, objType, err := listNames(db, obj)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	return exists(records, name, objType)
   212  }
   213  
   214  // FmtExpires returns expires formatted.
   215  func FmtExpires(expires string) (string, error) {
   216  	switch strings.ToLower(expires) {
   217  	case "never", "", " ", "0", "0s":
   218  		return "Never", nil
   219  
   220  	default:
   221  		expires = strings.ReplaceAll(expires, "-", "/")
   222  
   223  		// If the first format fails, try the second
   224  		exp, err := time.Parse("02/01/2006", expires)
   225  		if err != nil {
   226  			exp, err = time.Parse("2006/01/02", expires)
   227  			if err != nil {
   228  				return "", errors.New("\"expires\" field has an invalid format. Valid formats: d/m/y or y/m/d")
   229  			}
   230  		}
   231  
   232  		return exp.Format(time.RFC1123Z), nil
   233  	}
   234  }
   235  
   236  // MustExist returns an error if a record does not exist or if the name is invalid.
   237  func MustExist(db *bolt.DB, obj object, allowDir ...bool) cobra.PositionalArgs {
   238  	return func(cmd *cobra.Command, args []string) error {
   239  		if len(args) == 0 {
   240  			return ErrInvalidName
   241  		}
   242  
   243  		records, objType, err := listNames(db, obj)
   244  		if err != nil {
   245  			return err
   246  		}
   247  
   248  		for _, name := range args {
   249  			if name == "" || strings.Contains(name, "//") {
   250  				return ErrInvalidName
   251  			}
   252  			name = NormalizeName(name, allowDir...)
   253  
   254  			if strings.HasSuffix(name, "/") {
   255  				// Take directories into consideration only when the user
   256  				// is trying to perform an action with one
   257  				if err := exists(records, name, objType); err == nil {
   258  					return errors.Errorf("directory %q does not exist", strings.TrimSuffix(name, "/"))
   259  				}
   260  				return nil
   261  			}
   262  
   263  			exists := false
   264  			for _, record := range records {
   265  				if name == record {
   266  					exists = true
   267  					break
   268  				}
   269  			}
   270  			if !exists {
   271  				return errors.Errorf("%q does not exist", name)
   272  			}
   273  		}
   274  
   275  		return nil
   276  	}
   277  }
   278  
   279  // MustExistLs is like MustExist but it doesn't fail if
   280  // there are no arguments or if the user is using the filter flag.
   281  func MustExistLs(db *bolt.DB, obj object) cobra.PositionalArgs {
   282  	return func(cmd *cobra.Command, args []string) error {
   283  		if len(args) == 0 || cmd.Flags().Changed("filter") {
   284  			return nil
   285  		}
   286  
   287  		// If an empty string is joined in session/it command
   288  		// it returns a 1 item long slice [""]
   289  		if strings.Join(args, "") == "" {
   290  			return nil
   291  		}
   292  
   293  		// Pass on cmd and args
   294  		return MustExist(db, obj)(cmd, args)
   295  	}
   296  }
   297  
   298  // MustNotExist returns an error if the record exists or if the name is invalid.
   299  func MustNotExist(db *bolt.DB, obj object, allowDir ...bool) cobra.PositionalArgs {
   300  	return func(cmd *cobra.Command, args []string) error {
   301  		if len(args) == 0 {
   302  			return ErrInvalidName
   303  		}
   304  
   305  		for _, name := range args {
   306  			if name == "" || strings.Contains(name, "//") {
   307  				return ErrInvalidName
   308  			}
   309  			name = NormalizeName(name, allowDir...)
   310  
   311  			if err := Exists(db, name, obj); err != nil {
   312  				return err
   313  			}
   314  		}
   315  
   316  		return nil
   317  	}
   318  }
   319  
   320  // NormalizeName sanitizes the user input name.
   321  func NormalizeName(name string, allowDir ...bool) string {
   322  	if name == "" {
   323  		return name // Avoid allocations
   324  	}
   325  	if len(allowDir) == 0 {
   326  		return strings.ToLower(strings.TrimSpace(strings.Trim(strings.TrimSpace(name), "/")))
   327  	}
   328  	return strings.ToLower(strings.TrimSpace(name))
   329  }
   330  
   331  // SelectEditor returns the editor to use, if none is found it returns vim.
   332  func SelectEditor() string {
   333  	if def := config.GetString("editor"); def != "" {
   334  		return def
   335  	} else if e := os.Getenv("EDITOR"); e != "" {
   336  		return e
   337  	} else if v := os.Getenv("VISUAL"); v != "" {
   338  		return v
   339  	}
   340  
   341  	return "vim"
   342  }
   343  
   344  // SetContext sets up the testing environment.
   345  //
   346  // It uses t.Cleanup() to close the database connection after the test and
   347  // all its subtests are completed.
   348  func SetContext(t testing.TB) *bolt.DB {
   349  	t.Helper()
   350  
   351  	dbFile, err := os.CreateTemp("", "*")
   352  	assert.NoError(t, err)
   353  
   354  	db, err := bolt.Open(dbFile.Name(), 0o600, &bolt.Options{Timeout: 1 * time.Second})
   355  	assert.NoError(t, err, "Failed connecting to the database")
   356  
   357  	config.Reset()
   358  	// Reduce argon2 parameters to speed up tests
   359  	auth := map[string]interface{}{
   360  		"password":   memguard.NewEnclave([]byte("1")),
   361  		"iterations": 1,
   362  		"memory":     1,
   363  		"threads":    1,
   364  	}
   365  	config.Set("auth", auth)
   366  
   367  	db.Update(func(tx *bolt.Tx) error {
   368  		buckets := bucket.GetNames()
   369  		for _, bucket := range buckets {
   370  			// Ignore errors on purpose
   371  			tx.DeleteBucket(bucket)
   372  			tx.CreateBucketIfNotExists(bucket)
   373  		}
   374  		return nil
   375  	})
   376  
   377  	os.Stdout = os.NewFile(0, "") // Mute stdout
   378  	os.Stderr = os.NewFile(0, "") // Mute stderr
   379  	t.Cleanup(func() {
   380  		assert.NoError(t, db.Close(), "Failed connecting to the database")
   381  	})
   382  
   383  	return db
   384  }
   385  
   386  // WatchFile looks for the file initial state and loops until the first modification.
   387  //
   388  // Preferred over fsnotify since this last returns false events with recently created files.
   389  func WatchFile(filename string, done chan struct{}, errCh chan error) {
   390  	initStat, err := os.Stat(filename)
   391  	if err != nil {
   392  		errCh <- err
   393  		return
   394  	}
   395  
   396  	for {
   397  		stat, err := os.Stat(filename)
   398  		if err != nil {
   399  			errCh <- err
   400  			return
   401  		}
   402  
   403  		if stat.Size() != initStat.Size() || stat.ModTime() != initStat.ModTime() {
   404  			break
   405  		}
   406  
   407  		time.Sleep(300 * time.Millisecond)
   408  	}
   409  
   410  	done <- struct{}{}
   411  }
   412  
   413  // WriteClipboard writes the content to the clipboard and deletes it after
   414  // "t" if "t" is higher than 0 or if there is a default timeout set in the configuration.
   415  // Otherwise it does nothing.
   416  func WriteClipboard(cmd *cobra.Command, d time.Duration, field, content string) error {
   417  	if err := clipboard.WriteAll(content); err != nil {
   418  		return errors.Wrap(err, "writing to clipboard")
   419  	}
   420  	memguard.WipeBytes([]byte(content))
   421  
   422  	// Use the config value if it's specified and the timeout flag wasn't used
   423  	configKey := "clipboard.timeout"
   424  	if config.IsSet(configKey) && !cmd.Flags().Changed("timeout") {
   425  		d = config.GetDuration(configKey)
   426  	}
   427  
   428  	if d <= 0 {
   429  		fmt.Println(field, "copied to clipboard")
   430  		return nil
   431  	}
   432  
   433  	sig.Signal.AddCleanup(func() error { return clipboard.WriteAll("") })
   434  	done := make(chan struct{})
   435  	start := time.Now()
   436  
   437  	go terminal.Ticker(done, true, func() {
   438  		timeLeft := d - time.Since(start)
   439  		fmt.Printf("(%v) %s copied to clipboard", timeLeft.Round(time.Second), field)
   440  	})
   441  
   442  	<-time.After(d)
   443  	done <- struct{}{}
   444  	clipboard.WriteAll("")
   445  
   446  	return nil
   447  }
   448  
   449  func exists(records []string, name, objType string) error {
   450  	if len(records) == 0 {
   451  		return nil
   452  	}
   453  
   454  	found := func(name string) error {
   455  		return errors.Errorf("already exists a folder or %s named %q", objType, name)
   456  	}
   457  	// Remove slash to do the comparison
   458  	name = strings.TrimSuffix(name, "/")
   459  
   460  	for _, record := range records {
   461  		if name == record {
   462  			return found(name)
   463  		}
   464  
   465  		// record = "Padmé/Amidala", name = "Padmé/" should return an error
   466  		if hasPrefix(record, name) {
   467  			return found(name)
   468  		}
   469  
   470  		// name = "Padmé/Amidala", record = "Padmé/" should return an error
   471  		if hasPrefix(name, record) {
   472  			return found(record)
   473  		}
   474  	}
   475  
   476  	return nil
   477  }
   478  
   479  // hasPrefix is a modified version of strings.HasPrefix() that suits this use case, prefix is not modified to save an allocation.
   480  func hasPrefix(s, prefix string) bool {
   481  	prefixLen := len(prefix)
   482  	return len(s) > prefixLen && s[0:prefixLen] == prefix && s[prefixLen] == '/'
   483  }
   484  
   485  // listNames lists all the records depending on the object passed.
   486  // It returns a list and the type of object used.
   487  func listNames(db *bolt.DB, obj object) ([]string, string, error) {
   488  	var (
   489  		err     error
   490  		objType string
   491  		records []string
   492  	)
   493  
   494  	switch obj {
   495  	case Card:
   496  		objType = "card"
   497  		records, err = card.ListNames(db)
   498  
   499  	case Entry:
   500  		objType = "entry"
   501  		records, err = entry.ListNames(db)
   502  
   503  	case File:
   504  		objType = "file"
   505  		records, err = file.ListNames(db)
   506  
   507  	case TOTP:
   508  		objType = "TOTP"
   509  		records, err = totp.ListNames(db)
   510  	}
   511  	if err != nil {
   512  		return nil, "", err
   513  	}
   514  
   515  	return records, objType, nil
   516  }