github.com/GGP1/kure@v0.8.4/commands/util.go (about) 1 package cmdutil 2 3 import ( 4 "crypto/rand" 5 "fmt" 6 "os" 7 "path/filepath" 8 "strings" 9 "testing" 10 "time" 11 12 "github.com/GGP1/kure/config" 13 "github.com/GGP1/kure/db/bucket" 14 "github.com/GGP1/kure/db/card" 15 "github.com/GGP1/kure/db/entry" 16 "github.com/GGP1/kure/db/file" 17 "github.com/GGP1/kure/db/totp" 18 "github.com/GGP1/kure/orderedmap" 19 "github.com/GGP1/kure/sig" 20 "github.com/GGP1/kure/terminal" 21 22 "github.com/atotto/clipboard" 23 "github.com/awnumar/memguard" 24 "github.com/pkg/errors" 25 "github.com/spf13/cobra" 26 "github.com/stretchr/testify/assert" 27 bolt "go.etcd.io/bbolt" 28 ) 29 30 var ( 31 // ErrInvalidLength is returned when generating a password/passphrase and the length passed is < 1. 32 ErrInvalidLength = errors.New("invalid length") 33 // ErrInvalidName is returned when a name is required and received "" or contains "//". 34 ErrInvalidName = errors.New("invalid name") 35 // ErrInvalidPath is returned when a path is required and received "". 36 ErrInvalidPath = errors.New("invalid path") 37 ) 38 39 const ( 40 // Card object 41 Card object = iota 42 // Entry object 43 Entry 44 // File object 45 File 46 // TOTP object 47 TOTP 48 49 // Box 50 hBar = "─" 51 vBar = "│" 52 upperLeft = "╭" 53 lowerLeft = "╰" 54 upperRight = "╮" 55 lowerRight = "╯" 56 ) 57 58 // RunEFunc runs a cobra function returning an error. 59 type RunEFunc func(cmd *cobra.Command, args []string) error 60 61 type object int 62 63 // BuildBox constructs a responsive box used to display records information. 
64 // 65 // ┌──── Sample ────┐ 66 // │ Key │ Value │ 67 // └────────────────┘ 68 func BuildBox(name string, mp *orderedmap.Map) string { 69 var sb strings.Builder 70 71 // Do not use folders as part of the name 72 name = filepath.Base(name) 73 if !strings.Contains(name, ".") { 74 name = strings.Title(name) 75 } 76 77 nameLen := len([]rune(name)) 78 longestKey := 0 79 longestValue := nameLen 80 81 // Range to take the longest key and value 82 // Keys will always be 1 byte characters 83 // Values may be 1, 2 or 3 bytes, to take the length use len([]rune(v)) 84 for _, key := range mp.Keys() { 85 value := mp.Get(key) // Get key's value 86 87 // Take map's longest key 88 if len(key) > longestKey { 89 longestKey = len(key) 90 } 91 92 // Split each value by a new line (fields like Notes contain multiple lines) 93 for _, v := range strings.Split(value, "\n") { 94 lenV := len([]rune(v)) 95 96 // Take map's longest value 97 if lenV > longestValue { 98 longestValue = lenV 99 } 100 } 101 } 102 103 // -4-: 2 spaces that wrap name and 2 corners 104 headerLen := longestKey + longestValue - nameLen + 4 105 headerHalfLen := headerLen / 2 106 107 // Left side header 108 sb.WriteString(upperLeft) 109 sb.WriteString(strings.Repeat(hBar, headerHalfLen)) 110 111 // Header name 112 sb.WriteRune(' ') 113 sb.WriteString(name) 114 sb.WriteRune(' ') 115 116 // Adjust the right side of the header if its width is even 117 if headerLen%2 == 0 { 118 headerHalfLen-- 119 } 120 121 // Right side header 122 sb.WriteString(strings.Repeat(hBar, headerHalfLen)) 123 sb.WriteString(upperRight) 124 sb.WriteString("\n") 125 126 // Body 127 for _, key := range mp.Keys() { 128 value := mp.Get(key) // Get key's value 129 // Start 130 sb.WriteString(vBar) 131 132 // Key 133 sb.WriteRune(' ') 134 sb.WriteString(key) 135 sb.WriteRune(' ') 136 sb.WriteString(strings.Repeat(" ", longestKey-len(key))) // Padding 137 138 // Middle 139 sb.WriteString(vBar) 140 141 // Value 142 for i, v := range strings.Split(value, 
"\n") { 143 // In case the value contains multi-lines, 144 // repeat the process above but do not add the key 145 if i >= 1 { 146 sb.WriteString("\n") 147 sb.WriteString(vBar) 148 // -2- represents key leading and trailing spaces 149 // 1 2 150 // (│ key │), here key = "" 151 sb.WriteString(strings.Repeat(" ", longestKey+2)) // Padding 152 sb.WriteString(vBar) 153 } 154 155 sb.WriteRune(' ') 156 sb.WriteString(v) 157 sb.WriteString(strings.Repeat(" ", longestValue-len([]rune(v)))) // Padding 158 159 // End 160 sb.WriteString(" ") 161 sb.WriteString(vBar) 162 } 163 sb.WriteString("\n") 164 } 165 166 // Footer 167 // -5- represents the characters that wrap both key and value 168 // 1 234 5 169 // ( key │ value ) 170 footerLen := longestKey + longestValue + 5 171 sb.WriteString(lowerLeft) 172 sb.WriteString(strings.Repeat(hBar, footerLen)) 173 sb.WriteString(lowerRight) 174 175 return sb.String() 176 } 177 178 // Erase overwrites the file content with random bytes and then deletes it. 179 func Erase(filename string) error { 180 f, err := os.Stat(filename) 181 if err != nil { 182 return errors.Wrap(err, "obtaining file information") 183 } 184 185 buf := make([]byte, f.Size()) 186 if _, err := rand.Read(buf); err != nil { 187 return errors.Wrap(err, "generating random bytes") 188 } 189 190 // WriteFile truncates the file and overwrites it 191 if err := os.WriteFile(filename, buf, 0o600); err != nil { 192 return errors.Wrap(err, "overwriting file") 193 } 194 195 if err := os.Remove(filename); err != nil { 196 return errors.Wrap(err, "removing file") 197 } 198 199 return nil 200 } 201 202 // Exists checks if name or one of its folders is already being used. 203 // 204 // Returns an error if a match was found. 205 func Exists(db *bolt.DB, name string, obj object) error { 206 records, objType, err := listNames(db, obj) 207 if err != nil { 208 return err 209 } 210 211 return exists(records, name, objType) 212 } 213 214 // FmtExpires returns expires formatted. 
//
// Dates are accepted as d/m/y or y/m/d, using either "/" or "-" as the
// separator; day and month may omit the leading zero (e.g. "5-3-2030").
// Surrounding whitespace is ignored and the keywords are matched
// case-insensitively.
func FmtExpires(expires string) (string, error) {
	expires = strings.TrimSpace(expires)
	switch strings.ToLower(expires) {
	case "never", "", "0", "0s":
		return "Never", nil
	}

	expires = strings.ReplaceAll(expires, "-", "/")

	// "2" and "1" accept both one and two digit days/months, unlike "02"/"01".
	// Try d/m/y first and fall back to y/m/d.
	exp, err := time.Parse("2/1/2006", expires)
	if err != nil {
		exp, err = time.Parse("2006/1/2", expires)
		if err != nil {
			return "", errors.New("\"expires\" field has an invalid format. Valid formats: d/m/y or y/m/d")
		}
	}

	return exp.Format(time.RFC1123Z), nil
}

// MustExist returns a cobra.PositionalArgs validator that fails if no
// argument is given, if any argument is an invalid name (empty or containing
// "//"), or if any argument does not match an existing record of kind obj.
//
// When directories are allowed (allowDir is true), arguments ending in "/"
// are validated with exists instead of requiring an exact record match.
func MustExist(db *bolt.DB, obj object, allowDir ...bool) cobra.PositionalArgs {
	return func(cmd *cobra.Command, args []string) error {
		if len(args) == 0 {
			return ErrInvalidName
		}

		records, objType, err := listNames(db, obj)
		if err != nil {
			return err
		}

		for _, name := range args {
			if name == "" || strings.Contains(name, "//") {
				return ErrInvalidName
			}
			name = NormalizeName(name, allowDir...)

			if strings.HasSuffix(name, "/") {
				// Take directories into consideration only when the user
				// is trying to perform an action with one. exists reports an
				// error when something is found, so nil means nothing matched.
				if err := exists(records, name, objType); err == nil {
					return errors.Errorf("directory %q does not exist", strings.TrimSuffix(name, "/"))
				}
				// Keep validating the remaining arguments
				continue
			}

			found := false
			for _, record := range records {
				if name == record {
					found = true
					break
				}
			}
			if !found {
				return errors.Errorf("%q does not exist", name)
			}
		}

		return nil
	}
}

// MustExistLs is like MustExist (without allowing directories) but it doesn't
// fail if there are no arguments or if the user is using the filter flag.
func MustExistLs(db *bolt.DB, obj object) cobra.PositionalArgs {
	return func(cmd *cobra.Command, args []string) error {
		if len(args) == 0 || cmd.Flags().Changed("filter") {
			return nil
		}

		// If an empty string is joined in session/it command
		// it returns a 1 item long slice [""]
		if strings.Join(args, "") == "" {
			return nil
		}

		// Pass on cmd and args
		return MustExist(db, obj)(cmd, args)
	}
}

// MustNotExist returns a cobra.PositionalArgs validator that fails if no
// argument is given, if any argument is an invalid name (empty or containing
// "//"), or if any argument collides with an existing record or folder of
// kind obj.
func MustNotExist(db *bolt.DB, obj object, allowDir ...bool) cobra.PositionalArgs {
	return func(cmd *cobra.Command, args []string) error {
		if len(args) == 0 {
			return ErrInvalidName
		}

		// List the records once instead of once per argument
		records, objType, err := listNames(db, obj)
		if err != nil {
			return err
		}

		for _, name := range args {
			if name == "" || strings.Contains(name, "//") {
				return ErrInvalidName
			}

			if err := exists(records, NormalizeName(name, allowDir...), objType); err != nil {
				return err
			}
		}

		return nil
	}
}

// NormalizeName sanitizes the user input name: it trims surrounding spaces
// and lowercases it. Leading and trailing slashes are removed as well unless
// allowDir is passed with its first value set to true.
func NormalizeName(name string, allowDir ...bool) string {
	if name == "" {
		return name
	}
	if len(allowDir) > 0 && allowDir[0] {
		return strings.ToLower(strings.TrimSpace(name))
	}
	return strings.ToLower(strings.TrimSpace(strings.Trim(strings.TrimSpace(name), "/")))
}

// SelectEditor returns the editor to use: the "editor" configuration value,
// then the EDITOR and VISUAL environment variables, in that order. If none is
// set it returns "vim".
func SelectEditor() string {
	if editor := config.GetString("editor"); editor != "" {
		return editor
	}
	if editor := os.Getenv("EDITOR"); editor != "" {
		return editor
	}
	if editor := os.Getenv("VISUAL"); editor != "" {
		return editor
	}
	return "vim"
}

// SetContext sets up the testing environment.
//
// It uses t.Cleanup() to close the database connection after the test and
// all its subtests are completed.
//
// The database file lives in a temporary directory that is removed after the
// test. The "auth" configuration is reset with minimal argon2 parameters,
// every bucket is recreated empty, and os.Stdout/os.Stderr are redirected to
// the null device.
func SetContext(t testing.TB) *bolt.DB {
	t.Helper()

	// t.TempDir registers the directory removal before the database close
	// below; cleanups run last-in first-out, so the file is closed first.
	dbPath := filepath.Join(t.TempDir(), "kure.db")

	db, err := bolt.Open(dbPath, 0o600, &bolt.Options{Timeout: 1 * time.Second})
	if !assert.NoError(t, err, "Failed connecting to the database") {
		t.FailNow()
	}
	t.Cleanup(func() {
		assert.NoError(t, db.Close(), "Failed closing the database")
	})

	config.Reset()
	// Reduce argon2 parameters to speed up tests
	auth := map[string]interface{}{
		"password":   memguard.NewEnclave([]byte("1")),
		"iterations": 1,
		"memory":     1,
		"threads":    1,
	}
	config.Set("auth", auth)

	err = db.Update(func(tx *bolt.Tx) error {
		for _, name := range bucket.GetNames() {
			// Start from empty buckets; deleting a bucket that does not
			// exist yet fails and is safe to ignore.
			_ = tx.DeleteBucket(name)
			if _, err := tx.CreateBucketIfNotExists(name); err != nil {
				return err
			}
		}
		return nil
	})
	if !assert.NoError(t, err, "Failed creating the buckets") {
		t.FailNow()
	}

	// Mute stdout and stderr by pointing them to the null device. The handle
	// is intentionally left open since it stays installed after the test.
	devNull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
	if !assert.NoError(t, err, "Failed opening the null device") {
		t.FailNow()
	}
	os.Stdout = devNull
	os.Stderr = devNull

	return db
}

// WatchFile polls filename (roughly every 300 milliseconds) and sends on done
// once its size or modification time differs from the ones observed when it
// was called. Any os.Stat error is sent on errCh and ends the watch.
//
// Preferred over fsnotify since the latter reports false events for recently
// created files.
func WatchFile(filename string, done chan struct{}, errCh chan error) {
	initStat, err := os.Stat(filename)
	if err != nil {
		errCh <- err
		return
	}

	for {
		stat, err := os.Stat(filename)
		if err != nil {
			errCh <- err
			return
		}

		// time.Time values must be compared with Equal, not ==
		if stat.Size() != initStat.Size() || !stat.ModTime().Equal(initStat.ModTime()) {
			break
		}

		time.Sleep(300 * time.Millisecond)
	}

	done <- struct{}{}
}

// WriteClipboard writes content to the clipboard and, when a timeout greater
// than 0 applies, clears it once that timeout has elapsed. The timeout is d,
// unless the "clipboard.timeout" configuration key is set and the "timeout"
// flag was not used.
//
// While waiting, a ticker displays the time left, and a cleanup that clears
// the clipboard is registered with sig.Signal (presumably run if the process
// is interrupted). Without a timeout the content is left in the clipboard.
func WriteClipboard(cmd *cobra.Command, d time.Duration, field, content string) error {
	if err := clipboard.WriteAll(content); err != nil {
		return errors.Wrap(err, "writing to clipboard")
	}
	// NOTE: content cannot be wiped from memory here: Go strings are
	// immutable and converting it to a []byte would only wipe a copy.

	// Use the config value if it's specified and the timeout flag wasn't used
	configKey := "clipboard.timeout"
	if config.IsSet(configKey) && !cmd.Flags().Changed("timeout") {
		d = config.GetDuration(configKey)
	}

	if d <= 0 {
		fmt.Println(field, "copied to clipboard")
		return nil
	}

	sig.Signal.AddCleanup(func() error { return clipboard.WriteAll("") })
	done := make(chan struct{})
	start := time.Now()

	go terminal.Ticker(done, true, func() {
		timeLeft := d - time.Since(start)
		fmt.Printf("(%v) %s copied to clipboard", timeLeft.Round(time.Second), field)
	})

	time.Sleep(d)
	done <- struct{}{}

	// Report a failed clear: the secret would otherwise stay in the clipboard
	if err := clipboard.WriteAll(""); err != nil {
		return errors.Wrap(err, "clearing clipboard")
	}

	return nil
}

// exists returns an error if name (ignoring a trailing slash) is already used
// by one of records: an exact match, a record nested under name, or name
// nested under a record all count as a collision.
func exists(records []string, name, objType string) error {
	// Remove slash to do the comparison
	name = strings.TrimSuffix(name, "/")

	for _, record := range records {
		switch {
		case record == name, hasPrefix(record, name):
			// record = "Padmé/Amidala", name = "Padmé/" should return an error
			return errors.Errorf("already exists a folder or %s named %q", objType, name)
		case hasPrefix(name, record):
			// name = "Padmé/Amidala", record = "Padmé/" should return an error
			return errors.Errorf("already exists a folder or %s named %q", objType, record)
		}
	}

	return nil
}

// hasPrefix reports whether s is nested under the folder prefix, that is,
// whether s starts with prefix immediately followed by a '/'. Unlike
// strings.HasPrefix(s, prefix+"/"), it does not build a new string.
480 func hasPrefix(s, prefix string) bool { 481 prefixLen := len(prefix) 482 return len(s) > prefixLen && s[0:prefixLen] == prefix && s[prefixLen] == '/' 483 } 484 485 // listNames lists all the records depending on the object passed. 486 // It returns a list and the type of object used. 487 func listNames(db *bolt.DB, obj object) ([]string, string, error) { 488 var ( 489 err error 490 objType string 491 records []string 492 ) 493 494 switch obj { 495 case Card: 496 objType = "card" 497 records, err = card.ListNames(db) 498 499 case Entry: 500 objType = "entry" 501 records, err = entry.ListNames(db) 502 503 case File: 504 objType = "file" 505 records, err = file.ListNames(db) 506 507 case TOTP: 508 objType = "TOTP" 509 records, err = totp.ListNames(db) 510 } 511 if err != nil { 512 return nil, "", err 513 } 514 515 return records, objType, nil 516 }