github.com/slackhq/nebula@v1.9.0/ssh.go (about)

     1  package nebula
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"flag"
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"reflect"
    12  	"runtime"
    13  	"runtime/pprof"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/sirupsen/logrus"
    19  	"github.com/slackhq/nebula/config"
    20  	"github.com/slackhq/nebula/header"
    21  	"github.com/slackhq/nebula/iputil"
    22  	"github.com/slackhq/nebula/sshd"
    23  	"github.com/slackhq/nebula/udp"
    24  )
    25  
    26  type sshListHostMapFlags struct {
    27  	Json    bool
    28  	Pretty  bool
    29  	ByIndex bool
    30  }
    31  
    32  type sshPrintCertFlags struct {
    33  	Json   bool
    34  	Pretty bool
    35  	Raw    bool
    36  }
    37  
// sshPrintTunnelFlags holds the parsed flags of the print-tunnel and
// print-relays ssh commands.
type sshPrintTunnelFlags struct {
	// Pretty indents the json output.
	Pretty bool
}
    41  
// sshChangeRemoteFlags holds the parsed flags of the change-remote ssh command.
type sshChangeRemoteFlags struct {
	// Address is the new remote address, given as ip:port.
	Address string
}
    45  
// sshCloseTunnelFlags holds the parsed flags of the close-tunnel ssh command.
type sshCloseTunnelFlags struct {
	// LocalOnly disables notifying the remote that the tunnel is shutting down.
	LocalOnly bool
}
    49  
// sshCreateTunnelFlags holds the parsed flags of the create-tunnel ssh command.
type sshCreateTunnelFlags struct {
	// Address optionally provides a real remote address, given as ip:port.
	Address string
}
    53  
    54  type sshDeviceInfoFlags struct {
    55  	Json   bool
    56  	Pretty bool
    57  }
    58  
    59  func wireSSHReload(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) {
    60  	c.RegisterReloadCallback(func(c *config.C) {
    61  		if c.GetBool("sshd.enabled", false) {
    62  			sshRun, err := configSSH(l, ssh, c)
    63  			if err != nil {
    64  				l.WithError(err).Error("Failed to reconfigure the sshd")
    65  				ssh.Stop()
    66  			}
    67  			if sshRun != nil {
    68  				go sshRun()
    69  			}
    70  		} else {
    71  			ssh.Stop()
    72  		}
    73  	})
    74  }
    75  
    76  // configSSH reads the ssh info out of the passed-in Config and
    77  // updates the passed-in SSHServer. On success, it returns a function
    78  // that callers may invoke to run the configured ssh server. On
    79  // failure, it returns nil, error.
    80  func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), error) {
    81  	//TODO conntrack list
    82  	//TODO print firewall rules or hash?
    83  
    84  	listen := c.GetString("sshd.listen", "")
    85  	if listen == "" {
    86  		return nil, fmt.Errorf("sshd.listen must be provided")
    87  	}
    88  
    89  	_, port, err := net.SplitHostPort(listen)
    90  	if err != nil {
    91  		return nil, fmt.Errorf("invalid sshd.listen address: %s", err)
    92  	}
    93  	if port == "22" {
    94  		return nil, fmt.Errorf("sshd.listen can not use port 22")
    95  	}
    96  
    97  	//TODO: no good way to reload this right now
    98  	hostKeyPathOrKey := c.GetString("sshd.host_key", "")
    99  	if hostKeyPathOrKey == "" {
   100  		return nil, fmt.Errorf("sshd.host_key must be provided")
   101  	}
   102  
   103  	var hostKeyBytes []byte
   104  	if strings.Contains(hostKeyPathOrKey, "-----BEGIN") {
   105  		hostKeyBytes = []byte(hostKeyPathOrKey)
   106  	} else {
   107  		hostKeyBytes, err = os.ReadFile(hostKeyPathOrKey)
   108  		if err != nil {
   109  			return nil, fmt.Errorf("error while loading sshd.host_key file: %s", err)
   110  		}
   111  	}
   112  
   113  	err = ssh.SetHostKey(hostKeyBytes)
   114  	if err != nil {
   115  		return nil, fmt.Errorf("error while adding sshd.host_key: %s", err)
   116  	}
   117  
   118  	// Clear existing trusted CAs and authorized keys
   119  	ssh.ClearTrustedCAs()
   120  	ssh.ClearAuthorizedKeys()
   121  
   122  	rawCAs := c.GetStringSlice("sshd.trusted_cas", []string{})
   123  	for _, caAuthorizedKey := range rawCAs {
   124  		err := ssh.AddTrustedCA(caAuthorizedKey)
   125  		if err != nil {
   126  			l.WithError(err).WithField("sshCA", caAuthorizedKey).Warn("SSH CA had an error, ignoring")
   127  			continue
   128  		}
   129  	}
   130  
   131  	rawKeys := c.Get("sshd.authorized_users")
   132  	keys, ok := rawKeys.([]interface{})
   133  	if ok {
   134  		for _, rk := range keys {
   135  			kDef, ok := rk.(map[interface{}]interface{})
   136  			if !ok {
   137  				l.WithField("sshKeyConfig", rk).Warn("Authorized user had an error, ignoring")
   138  				continue
   139  			}
   140  
   141  			user, ok := kDef["user"].(string)
   142  			if !ok {
   143  				l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the user field")
   144  				continue
   145  			}
   146  
   147  			k := kDef["keys"]
   148  			switch v := k.(type) {
   149  			case string:
   150  				err := ssh.AddAuthorizedKey(user, v)
   151  				if err != nil {
   152  					l.WithError(err).WithField("sshKeyConfig", rk).WithField("sshKey", v).Warn("Failed to authorize key")
   153  					continue
   154  				}
   155  
   156  			case []interface{}:
   157  				for _, subK := range v {
   158  					sk, ok := subK.(string)
   159  					if !ok {
   160  						l.WithField("sshKeyConfig", rk).WithField("sshKey", subK).Warn("Did not understand ssh key")
   161  						continue
   162  					}
   163  
   164  					err := ssh.AddAuthorizedKey(user, sk)
   165  					if err != nil {
   166  						l.WithError(err).WithField("sshKeyConfig", sk).Warn("Failed to authorize key")
   167  						continue
   168  					}
   169  				}
   170  
   171  			default:
   172  				l.WithField("sshKeyConfig", rk).Warn("Authorized user is missing the keys field or was not understood")
   173  			}
   174  		}
   175  	} else {
   176  		l.Info("no ssh users to authorize")
   177  	}
   178  
   179  	var runner func()
   180  	if c.GetBool("sshd.enabled", false) {
   181  		ssh.Stop()
   182  		runner = func() {
   183  			if err := ssh.Run(listen); err != nil {
   184  				l.WithField("err", err).Warn("Failed to run the SSH server")
   185  			}
   186  		}
   187  	} else {
   188  		ssh.Stop()
   189  	}
   190  
   191  	return runner, nil
   192  }
   193  
   194  func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
   195  	ssh.RegisterCommand(&sshd.Command{
   196  		Name:             "list-hostmap",
   197  		ShortDescription: "List all known previously connected hosts",
   198  		Flags: func() (*flag.FlagSet, interface{}) {
   199  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   200  			s := sshListHostMapFlags{}
   201  			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
   202  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
   203  			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
   204  			return fl, &s
   205  		},
   206  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   207  			return sshListHostMap(f.hostMap, fs, w)
   208  		},
   209  	})
   210  
   211  	ssh.RegisterCommand(&sshd.Command{
   212  		Name:             "list-pending-hostmap",
   213  		ShortDescription: "List all handshaking hosts",
   214  		Flags: func() (*flag.FlagSet, interface{}) {
   215  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   216  			s := sshListHostMapFlags{}
   217  			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
   218  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
   219  			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
   220  			return fl, &s
   221  		},
   222  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   223  			return sshListHostMap(f.handshakeManager, fs, w)
   224  		},
   225  	})
   226  
   227  	ssh.RegisterCommand(&sshd.Command{
   228  		Name:             "list-lighthouse-addrmap",
   229  		ShortDescription: "List all lighthouse map entries",
   230  		Flags: func() (*flag.FlagSet, interface{}) {
   231  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   232  			s := sshListHostMapFlags{}
   233  			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
   234  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
   235  			return fl, &s
   236  		},
   237  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   238  			return sshListLighthouseMap(f.lightHouse, fs, w)
   239  		},
   240  	})
   241  
   242  	ssh.RegisterCommand(&sshd.Command{
   243  		Name:             "reload",
   244  		ShortDescription: "Reloads configuration from disk, same as sending HUP to the process",
   245  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   246  			return sshReload(c, w)
   247  		},
   248  	})
   249  
   250  	ssh.RegisterCommand(&sshd.Command{
   251  		Name:             "start-cpu-profile",
   252  		ShortDescription: "Starts a cpu profile and write output to the provided file, ex: `cpu-profile.pb.gz`",
   253  		Callback:         sshStartCpuProfile,
   254  	})
   255  
   256  	ssh.RegisterCommand(&sshd.Command{
   257  		Name:             "stop-cpu-profile",
   258  		ShortDescription: "Stops a cpu profile and writes output to the previously provided file",
   259  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   260  			pprof.StopCPUProfile()
   261  			return w.WriteLine("If a CPU profile was running it is now stopped")
   262  		},
   263  	})
   264  
   265  	ssh.RegisterCommand(&sshd.Command{
   266  		Name:             "save-heap-profile",
   267  		ShortDescription: "Saves a heap profile to the provided path, ex: `heap-profile.pb.gz`",
   268  		Callback:         sshGetHeapProfile,
   269  	})
   270  
   271  	ssh.RegisterCommand(&sshd.Command{
   272  		Name:             "mutex-profile-fraction",
   273  		ShortDescription: "Gets or sets runtime.SetMutexProfileFraction",
   274  		Callback:         sshMutexProfileFraction,
   275  	})
   276  
   277  	ssh.RegisterCommand(&sshd.Command{
   278  		Name:             "save-mutex-profile",
   279  		ShortDescription: "Saves a mutex profile to the provided path, ex: `mutex-profile.pb.gz`",
   280  		Callback:         sshGetMutexProfile,
   281  	})
   282  
   283  	ssh.RegisterCommand(&sshd.Command{
   284  		Name:             "log-level",
   285  		ShortDescription: "Gets or sets the current log level",
   286  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   287  			return sshLogLevel(l, fs, a, w)
   288  		},
   289  	})
   290  
   291  	ssh.RegisterCommand(&sshd.Command{
   292  		Name:             "log-format",
   293  		ShortDescription: "Gets or sets the current log format",
   294  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   295  			return sshLogFormat(l, fs, a, w)
   296  		},
   297  	})
   298  
   299  	ssh.RegisterCommand(&sshd.Command{
   300  		Name:             "version",
   301  		ShortDescription: "Prints the currently running version of nebula",
   302  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   303  			return sshVersion(f, fs, a, w)
   304  		},
   305  	})
   306  
   307  	ssh.RegisterCommand(&sshd.Command{
   308  		Name:             "device-info",
   309  		ShortDescription: "Prints information about the network device.",
   310  		Flags: func() (*flag.FlagSet, interface{}) {
   311  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   312  			s := sshDeviceInfoFlags{}
   313  			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
   314  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
   315  			return fl, &s
   316  		},
   317  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   318  			return sshDeviceInfo(f, fs, w)
   319  		},
   320  	})
   321  
   322  	ssh.RegisterCommand(&sshd.Command{
   323  		Name:             "print-cert",
   324  		ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip",
   325  		Flags: func() (*flag.FlagSet, interface{}) {
   326  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   327  			s := sshPrintCertFlags{}
   328  			fl.BoolVar(&s.Json, "json", false, "outputs as json")
   329  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
   330  			fl.BoolVar(&s.Raw, "raw", false, "raw prints the PEM encoded certificate, not compatible with -json or -pretty")
   331  			return fl, &s
   332  		},
   333  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   334  			return sshPrintCert(f, fs, a, w)
   335  		},
   336  	})
   337  
   338  	ssh.RegisterCommand(&sshd.Command{
   339  		Name:             "print-tunnel",
   340  		ShortDescription: "Prints json details about a tunnel for the provided vpn ip",
   341  		Flags: func() (*flag.FlagSet, interface{}) {
   342  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   343  			s := sshPrintTunnelFlags{}
   344  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
   345  			return fl, &s
   346  		},
   347  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   348  			return sshPrintTunnel(f, fs, a, w)
   349  		},
   350  	})
   351  
   352  	ssh.RegisterCommand(&sshd.Command{
   353  		Name:             "print-relays",
   354  		ShortDescription: "Prints json details about all relay info",
   355  		Flags: func() (*flag.FlagSet, interface{}) {
   356  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   357  			s := sshPrintTunnelFlags{}
   358  			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json")
   359  			return fl, &s
   360  		},
   361  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   362  			return sshPrintRelays(f, fs, a, w)
   363  		},
   364  	})
   365  
   366  	ssh.RegisterCommand(&sshd.Command{
   367  		Name:             "change-remote",
   368  		ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip",
   369  		Flags: func() (*flag.FlagSet, interface{}) {
   370  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   371  			s := sshChangeRemoteFlags{}
   372  			fl.StringVar(&s.Address, "address", "", "The new remote address, ip:port")
   373  			return fl, &s
   374  		},
   375  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   376  			return sshChangeRemote(f, fs, a, w)
   377  		},
   378  	})
   379  
   380  	ssh.RegisterCommand(&sshd.Command{
   381  		Name:             "close-tunnel",
   382  		ShortDescription: "Closes a tunnel for the provided vpn ip",
   383  		Flags: func() (*flag.FlagSet, interface{}) {
   384  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   385  			s := sshCloseTunnelFlags{}
   386  			fl.BoolVar(&s.LocalOnly, "local-only", false, "Disables notifying the remote that the tunnel is shutting down")
   387  			return fl, &s
   388  		},
   389  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   390  			return sshCloseTunnel(f, fs, a, w)
   391  		},
   392  	})
   393  
   394  	ssh.RegisterCommand(&sshd.Command{
   395  		Name:             "create-tunnel",
   396  		ShortDescription: "Creates a tunnel for the provided vpn ip and address",
   397  		Help:             "The lighthouses will be queried for real addresses but you can provide one as well.",
   398  		Flags: func() (*flag.FlagSet, interface{}) {
   399  			fl := flag.NewFlagSet("", flag.ContinueOnError)
   400  			s := sshCreateTunnelFlags{}
   401  			fl.StringVar(&s.Address, "address", "", "Optionally provide a real remote address, ip:port ")
   402  			return fl, &s
   403  		},
   404  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   405  			return sshCreateTunnel(f, fs, a, w)
   406  		},
   407  	})
   408  
   409  	ssh.RegisterCommand(&sshd.Command{
   410  		Name:             "query-lighthouse",
   411  		ShortDescription: "Query the lighthouses for the provided vpn ip",
   412  		Help:             "This command is asynchronous. Only currently known udp ips will be printed.",
   413  		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
   414  			return sshQueryLighthouse(f, fs, a, w)
   415  		},
   416  	})
   417  }
   418  
   419  func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
   420  	fs, ok := a.(*sshListHostMapFlags)
   421  	if !ok {
   422  		//TODO: error
   423  		return nil
   424  	}
   425  
   426  	var hm []ControlHostInfo
   427  	if fs.ByIndex {
   428  		hm = listHostMapIndexes(hl)
   429  	} else {
   430  		hm = listHostMapHosts(hl)
   431  	}
   432  
   433  	sort.Slice(hm, func(i, j int) bool {
   434  		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
   435  	})
   436  
   437  	if fs.Json || fs.Pretty {
   438  		js := json.NewEncoder(w.GetWriter())
   439  		if fs.Pretty {
   440  			js.SetIndent("", "    ")
   441  		}
   442  
   443  		err := js.Encode(hm)
   444  		if err != nil {
   445  			//TODO
   446  			return nil
   447  		}
   448  
   449  	} else {
   450  		for _, v := range hm {
   451  			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs))
   452  			if err != nil {
   453  				return err
   454  			}
   455  		}
   456  	}
   457  
   458  	return nil
   459  }
   460  
   461  func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWriter) error {
   462  	fs, ok := a.(*sshListHostMapFlags)
   463  	if !ok {
   464  		//TODO: error
   465  		return nil
   466  	}
   467  
   468  	type lighthouseInfo struct {
   469  		VpnIp string    `json:"vpnIp"`
   470  		Addrs *CacheMap `json:"addrs"`
   471  	}
   472  
   473  	lightHouse.RLock()
   474  	addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
   475  	x := 0
   476  	for k, v := range lightHouse.addrMap {
   477  		addrMap[x] = lighthouseInfo{
   478  			VpnIp: k.String(),
   479  			Addrs: v.CopyCache(),
   480  		}
   481  		x++
   482  	}
   483  	lightHouse.RUnlock()
   484  
   485  	sort.Slice(addrMap, func(i, j int) bool {
   486  		return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0
   487  	})
   488  
   489  	if fs.Json || fs.Pretty {
   490  		js := json.NewEncoder(w.GetWriter())
   491  		if fs.Pretty {
   492  			js.SetIndent("", "    ")
   493  		}
   494  
   495  		err := js.Encode(addrMap)
   496  		if err != nil {
   497  			//TODO
   498  			return nil
   499  		}
   500  
   501  	} else {
   502  		for _, v := range addrMap {
   503  			b, err := json.Marshal(v.Addrs)
   504  			if err != nil {
   505  				return err
   506  			}
   507  			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b)))
   508  			if err != nil {
   509  				return err
   510  			}
   511  		}
   512  	}
   513  
   514  	return nil
   515  }
   516  
   517  func sshStartCpuProfile(fs interface{}, a []string, w sshd.StringWriter) error {
   518  	if len(a) == 0 {
   519  		err := w.WriteLine("No path to write profile provided")
   520  		return err
   521  	}
   522  
   523  	file, err := os.Create(a[0])
   524  	if err != nil {
   525  		err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
   526  		return err
   527  	}
   528  
   529  	err = pprof.StartCPUProfile(file)
   530  	if err != nil {
   531  		err = w.WriteLine(fmt.Sprintf("Unable to start cpu profile: %s", err))
   532  		return err
   533  	}
   534  
   535  	err = w.WriteLine(fmt.Sprintf("Started cpu profile, issue stop-cpu-profile to write the output to %s", a))
   536  	return err
   537  }
   538  
   539  func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   540  	return w.WriteLine(fmt.Sprintf("%s", ifce.version))
   541  }
   542  
   543  func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   544  	if len(a) == 0 {
   545  		return w.WriteLine("No vpn ip was provided")
   546  	}
   547  
   548  	parsedIp := net.ParseIP(a[0])
   549  	if parsedIp == nil {
   550  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   551  	}
   552  
   553  	vpnIp := iputil.Ip2VpnIp(parsedIp)
   554  	if vpnIp == 0 {
   555  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   556  	}
   557  
   558  	var cm *CacheMap
   559  	rl := ifce.lightHouse.Query(vpnIp)
   560  	if rl != nil {
   561  		cm = rl.CopyCache()
   562  	}
   563  	return json.NewEncoder(w.GetWriter()).Encode(cm)
   564  }
   565  
   566  func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   567  	flags, ok := fs.(*sshCloseTunnelFlags)
   568  	if !ok {
   569  		//TODO: error
   570  		return nil
   571  	}
   572  
   573  	if len(a) == 0 {
   574  		return w.WriteLine("No vpn ip was provided")
   575  	}
   576  
   577  	parsedIp := net.ParseIP(a[0])
   578  	if parsedIp == nil {
   579  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   580  	}
   581  
   582  	vpnIp := iputil.Ip2VpnIp(parsedIp)
   583  	if vpnIp == 0 {
   584  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   585  	}
   586  
   587  	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
   588  	if hostInfo == nil {
   589  		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
   590  	}
   591  
   592  	if !flags.LocalOnly {
   593  		ifce.send(
   594  			header.CloseTunnel,
   595  			0,
   596  			hostInfo.ConnectionState,
   597  			hostInfo,
   598  			[]byte{},
   599  			make([]byte, 12, 12),
   600  			make([]byte, mtu),
   601  		)
   602  	}
   603  
   604  	ifce.closeTunnel(hostInfo)
   605  	return w.WriteLine("Closed")
   606  }
   607  
   608  func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   609  	flags, ok := fs.(*sshCreateTunnelFlags)
   610  	if !ok {
   611  		//TODO: error
   612  		return nil
   613  	}
   614  
   615  	if len(a) == 0 {
   616  		return w.WriteLine("No vpn ip was provided")
   617  	}
   618  
   619  	parsedIp := net.ParseIP(a[0])
   620  	if parsedIp == nil {
   621  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   622  	}
   623  
   624  	vpnIp := iputil.Ip2VpnIp(parsedIp)
   625  	if vpnIp == 0 {
   626  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   627  	}
   628  
   629  	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
   630  	if hostInfo != nil {
   631  		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
   632  	}
   633  
   634  	hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
   635  	if hostInfo != nil {
   636  		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
   637  	}
   638  
   639  	var addr *udp.Addr
   640  	if flags.Address != "" {
   641  		addr = udp.NewAddrFromString(flags.Address)
   642  		if addr == nil {
   643  			return w.WriteLine("Address could not be parsed")
   644  		}
   645  	}
   646  
   647  	hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil)
   648  	if addr != nil {
   649  		hostInfo.SetRemote(addr)
   650  	}
   651  
   652  	return w.WriteLine("Created")
   653  }
   654  
   655  func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   656  	flags, ok := fs.(*sshChangeRemoteFlags)
   657  	if !ok {
   658  		//TODO: error
   659  		return nil
   660  	}
   661  
   662  	if len(a) == 0 {
   663  		return w.WriteLine("No vpn ip was provided")
   664  	}
   665  
   666  	if flags.Address == "" {
   667  		return w.WriteLine("No address was provided")
   668  	}
   669  
   670  	addr := udp.NewAddrFromString(flags.Address)
   671  	if addr == nil {
   672  		return w.WriteLine("Address could not be parsed")
   673  	}
   674  
   675  	parsedIp := net.ParseIP(a[0])
   676  	if parsedIp == nil {
   677  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   678  	}
   679  
   680  	vpnIp := iputil.Ip2VpnIp(parsedIp)
   681  	if vpnIp == 0 {
   682  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   683  	}
   684  
   685  	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
   686  	if hostInfo == nil {
   687  		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
   688  	}
   689  
   690  	hostInfo.SetRemote(addr)
   691  	return w.WriteLine("Changed")
   692  }
   693  
   694  func sshGetHeapProfile(fs interface{}, a []string, w sshd.StringWriter) error {
   695  	if len(a) == 0 {
   696  		return w.WriteLine("No path to write profile provided")
   697  	}
   698  
   699  	file, err := os.Create(a[0])
   700  	if err != nil {
   701  		err = w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
   702  		return err
   703  	}
   704  
   705  	err = pprof.WriteHeapProfile(file)
   706  	if err != nil {
   707  		err = w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err))
   708  		return err
   709  	}
   710  
   711  	err = w.WriteLine(fmt.Sprintf("Mem profile created at %s", a))
   712  	return err
   713  }
   714  
   715  func sshMutexProfileFraction(fs interface{}, a []string, w sshd.StringWriter) error {
   716  	if len(a) == 0 {
   717  		rate := runtime.SetMutexProfileFraction(-1)
   718  		return w.WriteLine(fmt.Sprintf("Current value: %d", rate))
   719  	}
   720  
   721  	newRate, err := strconv.Atoi(a[0])
   722  	if err != nil {
   723  		return w.WriteLine(fmt.Sprintf("Invalid argument: %s", a[0]))
   724  	}
   725  
   726  	oldRate := runtime.SetMutexProfileFraction(newRate)
   727  	return w.WriteLine(fmt.Sprintf("New value: %d. Old value: %d", newRate, oldRate))
   728  }
   729  
   730  func sshGetMutexProfile(fs interface{}, a []string, w sshd.StringWriter) error {
   731  	if len(a) == 0 {
   732  		return w.WriteLine("No path to write profile provided")
   733  	}
   734  
   735  	file, err := os.Create(a[0])
   736  	if err != nil {
   737  		return w.WriteLine(fmt.Sprintf("Unable to create profile file: %s", err))
   738  	}
   739  	defer file.Close()
   740  
   741  	mutexProfile := pprof.Lookup("mutex")
   742  	if mutexProfile == nil {
   743  		return w.WriteLine("Unable to get pprof.Lookup(\"mutex\")")
   744  	}
   745  
   746  	err = mutexProfile.WriteTo(file, 0)
   747  	if err != nil {
   748  		return w.WriteLine(fmt.Sprintf("Unable to write profile: %s", err))
   749  	}
   750  
   751  	return w.WriteLine(fmt.Sprintf("Mutex profile created at %s", a))
   752  }
   753  
   754  func sshLogLevel(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
   755  	if len(a) == 0 {
   756  		return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
   757  	}
   758  
   759  	level, err := logrus.ParseLevel(a[0])
   760  	if err != nil {
   761  		return w.WriteLine(fmt.Sprintf("Unknown log level %s. Possible log levels: %s", a, logrus.AllLevels))
   762  	}
   763  
   764  	l.SetLevel(level)
   765  	return w.WriteLine(fmt.Sprintf("Log level is: %s", l.Level))
   766  }
   767  
   768  func sshLogFormat(l *logrus.Logger, fs interface{}, a []string, w sshd.StringWriter) error {
   769  	if len(a) == 0 {
   770  		return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
   771  	}
   772  
   773  	logFormat := strings.ToLower(a[0])
   774  	switch logFormat {
   775  	case "text":
   776  		l.Formatter = &logrus.TextFormatter{}
   777  	case "json":
   778  		l.Formatter = &logrus.JSONFormatter{}
   779  	default:
   780  		return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
   781  	}
   782  
   783  	return w.WriteLine(fmt.Sprintf("Log format is: %s", reflect.TypeOf(l.Formatter)))
   784  }
   785  
   786  func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   787  	args, ok := fs.(*sshPrintCertFlags)
   788  	if !ok {
   789  		//TODO: error
   790  		return nil
   791  	}
   792  
   793  	cert := ifce.pki.GetCertState().Certificate
   794  	if len(a) > 0 {
   795  		parsedIp := net.ParseIP(a[0])
   796  		if parsedIp == nil {
   797  			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   798  		}
   799  
   800  		vpnIp := iputil.Ip2VpnIp(parsedIp)
   801  		if vpnIp == 0 {
   802  			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   803  		}
   804  
   805  		hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
   806  		if hostInfo == nil {
   807  			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
   808  		}
   809  
   810  		cert = hostInfo.GetCert()
   811  	}
   812  
   813  	if args.Json || args.Pretty {
   814  		b, err := cert.MarshalJSON()
   815  		if err != nil {
   816  			//TODO: handle it
   817  			return nil
   818  		}
   819  
   820  		if args.Pretty {
   821  			buf := new(bytes.Buffer)
   822  			err := json.Indent(buf, b, "", "    ")
   823  			b = buf.Bytes()
   824  			if err != nil {
   825  				//TODO: handle it
   826  				return nil
   827  			}
   828  		}
   829  
   830  		return w.WriteBytes(b)
   831  	}
   832  
   833  	if args.Raw {
   834  		b, err := cert.MarshalToPEM()
   835  		if err != nil {
   836  			//TODO: handle it
   837  			return nil
   838  		}
   839  
   840  		return w.WriteBytes(b)
   841  	}
   842  
   843  	return w.WriteLine(cert.String())
   844  }
   845  
   846  func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   847  	args, ok := fs.(*sshPrintTunnelFlags)
   848  	if !ok {
   849  		//TODO: error
   850  		w.WriteLine(fmt.Sprintf("sshPrintRelays failed to convert args type"))
   851  		return nil
   852  	}
   853  
   854  	relays := map[uint32]*HostInfo{}
   855  	ifce.hostMap.Lock()
   856  	for k, v := range ifce.hostMap.Relays {
   857  		relays[k] = v
   858  	}
   859  	ifce.hostMap.Unlock()
   860  
   861  	type RelayFor struct {
   862  		Error          error
   863  		Type           string
   864  		State          string
   865  		PeerIp         iputil.VpnIp
   866  		LocalIndex     uint32
   867  		RemoteIndex    uint32
   868  		RelayedThrough []iputil.VpnIp
   869  	}
   870  
   871  	type RelayOutput struct {
   872  		NebulaIp    iputil.VpnIp
   873  		RelayForIps []RelayFor
   874  	}
   875  
   876  	type CmdOutput struct {
   877  		Relays []*RelayOutput
   878  	}
   879  
   880  	co := CmdOutput{}
   881  
   882  	enc := json.NewEncoder(w.GetWriter())
   883  
   884  	if args.Pretty {
   885  		enc.SetIndent("", "    ")
   886  	}
   887  
   888  	for k, v := range relays {
   889  		ro := RelayOutput{NebulaIp: v.vpnIp}
   890  		co.Relays = append(co.Relays, &ro)
   891  		relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
   892  		if relayHI == nil {
   893  			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
   894  			continue
   895  		}
   896  		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
   897  			rf := RelayFor{Error: nil}
   898  			r, ok := relayHI.relayState.GetRelayForByIp(vpnIp)
   899  			if ok {
   900  				t := ""
   901  				switch r.Type {
   902  				case ForwardingType:
   903  					t = "forwarding"
   904  				case TerminalType:
   905  					t = "terminal"
   906  				default:
   907  					t = "unknown"
   908  				}
   909  
   910  				s := ""
   911  				switch r.State {
   912  				case Requested:
   913  					s = "requested"
   914  				case Established:
   915  					s = "established"
   916  				default:
   917  					s = "unknown"
   918  				}
   919  
   920  				rf.LocalIndex = r.LocalIndex
   921  				rf.RemoteIndex = r.RemoteIndex
   922  				rf.PeerIp = r.PeerIp
   923  				rf.Type = t
   924  				rf.State = s
   925  				if rf.LocalIndex != k {
   926  					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
   927  				}
   928  			}
   929  			relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
   930  			if relayedHI != nil {
   931  				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
   932  			}
   933  
   934  			ro.RelayForIps = append(ro.RelayForIps, rf)
   935  		}
   936  	}
   937  	err := enc.Encode(co)
   938  	if err != nil {
   939  		return err
   940  	}
   941  	return nil
   942  }
   943  
   944  func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
   945  	args, ok := fs.(*sshPrintTunnelFlags)
   946  	if !ok {
   947  		//TODO: error
   948  		return nil
   949  	}
   950  
   951  	if len(a) == 0 {
   952  		return w.WriteLine("No vpn ip was provided")
   953  	}
   954  
   955  	parsedIp := net.ParseIP(a[0])
   956  	if parsedIp == nil {
   957  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   958  	}
   959  
   960  	vpnIp := iputil.Ip2VpnIp(parsedIp)
   961  	if vpnIp == 0 {
   962  		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
   963  	}
   964  
   965  	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
   966  	if hostInfo == nil {
   967  		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
   968  	}
   969  
   970  	enc := json.NewEncoder(w.GetWriter())
   971  	if args.Pretty {
   972  		enc.SetIndent("", "    ")
   973  	}
   974  
   975  	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
   976  }
   977  
   978  func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error {
   979  
   980  	data := struct {
   981  		Name string `json:"name"`
   982  		Cidr string `json:"cidr"`
   983  	}{
   984  		Name: ifce.inside.Name(),
   985  		Cidr: ifce.inside.Cidr().String(),
   986  	}
   987  
   988  	flags, ok := fs.(*sshDeviceInfoFlags)
   989  	if !ok {
   990  		return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs)
   991  	}
   992  
   993  	if flags.Json || flags.Pretty {
   994  		js := json.NewEncoder(w.GetWriter())
   995  		if flags.Pretty {
   996  			js.SetIndent("", "    ")
   997  		}
   998  
   999  		return js.Encode(data)
  1000  	} else {
  1001  		return w.WriteLine(fmt.Sprintf("name=%v cidr=%v", data.Name, data.Cidr))
  1002  	}
  1003  }
  1004  
  1005  func sshReload(c *config.C, w sshd.StringWriter) error {
  1006  	err := w.WriteLine("Reloading config")
  1007  	c.ReloadConfig()
  1008  	return err
  1009  }