github.com/slackhq/nebula@v1.9.0/sshd/command.go (about)

     1  package sshd
     2  
     3  import (
     4  	"errors"
     5  	"flag"
     6  	"fmt"
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/armon/go-radix"
    11  )
    12  
// CommandFlags is called before a command's help text is shown or before the
// command executes, and builds that command's flag definitions.
// It returns the *flag.FlagSet used to parse the command's arguments and a
// pointer to the struct that will hold the parsed flag values; that second value
// is what the CommandCallback later receives as fs. Returning a nil FlagSet
// disables flag parsing and flag help output for the command.
type CommandFlags func() (*flag.FlagSet, interface{})
    16  
// CommandCallback is the function called when your command should execute.
// fs is the pointer to the struct provided by the Command.Flags callback, if
// there was one (nil otherwise). -h and -help are reserved and handled
// automatically for you.
// a holds any arguments not consumed by flag parsing; if no Command.Flags was
// available, it holds every argument passed in.
// w is the writer to use when sending messages back to the client.
// If the callback returns an error it is logged locally; the callback itself
// should handle messaging errors to the user where appropriate.
type CommandCallback func(fs interface{}, a []string, w StringWriter) error
    25  
// Command describes a single command that can be registered with and run by
// this package.
type Command struct {
	// Name is the command name as displayed in help output.
	Name string

	// ShortDescription is the one-line summary printed next to Name in help
	// output and command listings.
	ShortDescription string

	// Help is optional longer help text, printed indented beneath the summary
	// when help is requested for this specific command.
	Help string

	// Flags optionally builds the command's flag set; leave nil for commands
	// that take no flags.
	Flags CommandFlags

	// Callback performs the command. execCommand invokes it unconditionally, so
	// it should always be non-nil.
	Callback CommandCallback
}
    33  
    34  func execCommand(c *Command, args []string, w StringWriter) error {
    35  	var (
    36  		fl *flag.FlagSet
    37  		fs interface{}
    38  	)
    39  
    40  	if c.Flags != nil {
    41  		fl, fs = c.Flags()
    42  		if fl != nil {
    43  			// SetOutput() here in case fl.Parse dumps usage.
    44  			fl.SetOutput(w.GetWriter())
    45  			err := fl.Parse(args)
    46  			if err != nil {
    47  				// fl.Parse has dumped error information to the user via the w writer.
    48  				return err
    49  			}
    50  			args = fl.Args()
    51  		}
    52  	}
    53  
    54  	return c.Callback(fs, args, w)
    55  }
    56  
    57  func dumpCommands(c *radix.Tree, w StringWriter) {
    58  	err := w.WriteLine("Available commands:")
    59  	if err != nil {
    60  		//TODO: log
    61  		return
    62  	}
    63  
    64  	cmds := make([]string, 0)
    65  	for _, l := range allCommands(c) {
    66  		cmds = append(cmds, fmt.Sprintf("%s - %s", l.Name, l.ShortDescription))
    67  	}
    68  
    69  	sort.Strings(cmds)
    70  	err = w.Write(strings.Join(cmds, "\n") + "\n\n")
    71  	if err != nil {
    72  		//TODO: log
    73  	}
    74  }
    75  
    76  func lookupCommand(c *radix.Tree, sCmd string) (*Command, error) {
    77  	cmd, ok := c.Get(sCmd)
    78  	if !ok {
    79  		return nil, nil
    80  	}
    81  
    82  	command, ok := cmd.(*Command)
    83  	if !ok {
    84  		return nil, errors.New("failed to cast command")
    85  	}
    86  
    87  	return command, nil
    88  }
    89  
    90  func matchCommand(c *radix.Tree, cmd string) []string {
    91  	cmds := make([]string, 0)
    92  	c.WalkPrefix(cmd, func(found string, v interface{}) bool {
    93  		cmds = append(cmds, found)
    94  		return false
    95  	})
    96  	sort.Strings(cmds)
    97  	return cmds
    98  }
    99  
   100  func allCommands(c *radix.Tree) []*Command {
   101  	cmds := make([]*Command, 0)
   102  	c.WalkPrefix("", func(found string, v interface{}) bool {
   103  		cmd, ok := v.(*Command)
   104  		if ok {
   105  			cmds = append(cmds, cmd)
   106  		}
   107  		return false
   108  	})
   109  	return cmds
   110  }
   111  
   112  func helpCallback(commands *radix.Tree, a []string, w StringWriter) (err error) {
   113  	// Just typed help
   114  	if len(a) == 0 {
   115  		dumpCommands(commands, w)
   116  		return nil
   117  	}
   118  
   119  	// We are printing a specific commands help text
   120  	cmd, err := lookupCommand(commands, a[0])
   121  	if err != nil {
   122  		//TODO: handle error
   123  		//TODO: message the user
   124  		return
   125  	}
   126  
   127  	if cmd != nil {
   128  		err = w.WriteLine(fmt.Sprintf("%s - %s", cmd.Name, cmd.ShortDescription))
   129  		if err != nil {
   130  			return err
   131  		}
   132  
   133  		if cmd.Help != "" {
   134  			err = w.WriteLine(fmt.Sprintf("  %s", cmd.Help))
   135  			if err != nil {
   136  				return err
   137  			}
   138  		}
   139  
   140  		if cmd.Flags != nil {
   141  			fs, _ := cmd.Flags()
   142  			if fs != nil {
   143  				fs.SetOutput(w.GetWriter())
   144  				fs.PrintDefaults()
   145  			}
   146  		}
   147  
   148  		return nil
   149  	}
   150  
   151  	err = w.WriteLine("Command not available " + a[0])
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  func checkHelpArgs(args []string) bool {
   160  	for _, a := range args {
   161  		if a == "-h" || a == "-help" {
   162  			return true
   163  		}
   164  	}
   165  
   166  	return false
   167  }