github.com/vmware/govmomi@v0.37.1/govc/flags/client.go (about)

     1  /*
     2  Copyright (c) 2014-2023 VMware, Inc. All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8  http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package flags
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"errors"
    23  	"flag"
    24  	"fmt"
    25  	"net/url"
    26  	"os"
    27  	"os/signal"
    28  	"path/filepath"
    29  	"strings"
    30  	"syscall"
    31  	"time"
    32  
    33  	"github.com/vmware/govmomi/cns"
    34  	"github.com/vmware/govmomi/pbm"
    35  	"github.com/vmware/govmomi/session"
    36  	"github.com/vmware/govmomi/session/cache"
    37  	"github.com/vmware/govmomi/session/keepalive"
    38  	"github.com/vmware/govmomi/vapi/rest"
    39  	"github.com/vmware/govmomi/vim25"
    40  	"github.com/vmware/govmomi/vim25/soap"
    41  )
    42  
    43  const (
    44  	envURL           = "GOVC_URL"
    45  	envUsername      = "GOVC_USERNAME"
    46  	envPassword      = "GOVC_PASSWORD"
    47  	envCertificate   = "GOVC_CERTIFICATE"
    48  	envPrivateKey    = "GOVC_PRIVATE_KEY"
    49  	envInsecure      = "GOVC_INSECURE"
    50  	envPersist       = "GOVC_PERSIST_SESSION"
    51  	envMinAPIVersion = "GOVC_MIN_API_VERSION"
    52  	envVimNamespace  = "GOVC_VIM_NAMESPACE"
    53  	envVimVersion    = "GOVC_VIM_VERSION"
    54  	envTLSCaCerts    = "GOVC_TLS_CA_CERTS"
    55  	envTLSKnownHosts = "GOVC_TLS_KNOWN_HOSTS"
    56  
    57  	defaultMinVimVersion = "5.5"
    58  )
    59  
    60  const cDescr = "ESX or vCenter URL"
    61  
    62  type ClientFlag struct {
    63  	common
    64  
    65  	*DebugFlag
    66  
    67  	username      string
    68  	password      string
    69  	cert          string
    70  	key           string
    71  	persist       bool
    72  	minAPIVersion string
    73  	vimNamespace  string
    74  	vimVersion    string
    75  	tlsCaCerts    string
    76  	tlsKnownHosts string
    77  	client        *vim25.Client
    78  	restClient    *rest.Client
    79  	Session       cache.Session
    80  }
    81  
    82  var (
    83  	home          = os.Getenv("GOVMOMI_HOME")
    84  	clientFlagKey = flagKey("client")
    85  )
    86  
    87  func init() {
    88  	if home == "" {
    89  		home = filepath.Join(os.Getenv("HOME"), ".govmomi")
    90  	}
    91  }
    92  
    93  func NewClientFlag(ctx context.Context) (*ClientFlag, context.Context) {
    94  	if v := ctx.Value(clientFlagKey); v != nil {
    95  		return v.(*ClientFlag), ctx
    96  	}
    97  
    98  	v := &ClientFlag{}
    99  	v.DebugFlag, ctx = NewDebugFlag(ctx)
   100  	ctx = context.WithValue(ctx, clientFlagKey, v)
   101  	return v, ctx
   102  }
   103  
   104  func (flag *ClientFlag) String() string {
   105  	url := flag.Session.Endpoint()
   106  	if url == nil {
   107  		return ""
   108  	}
   109  
   110  	return url.String()
   111  }
   112  
   113  func (flag *ClientFlag) Set(s string) error {
   114  	var err error
   115  
   116  	flag.Session.URL, err = soap.ParseURL(s)
   117  
   118  	return err
   119  }
   120  
   121  func (flag *ClientFlag) Register(ctx context.Context, f *flag.FlagSet) {
   122  	flag.RegisterOnce(func() {
   123  		flag.DebugFlag.Register(ctx, f)
   124  
   125  		{
   126  			flag.Set(os.Getenv(envURL))
   127  			usage := fmt.Sprintf("%s [%s]", cDescr, envURL)
   128  			f.Var(flag, "u", usage)
   129  		}
   130  
   131  		{
   132  			flag.username = os.Getenv(envUsername)
   133  			flag.password = os.Getenv(envPassword)
   134  		}
   135  
   136  		{
   137  			value := os.Getenv(envCertificate)
   138  			usage := fmt.Sprintf("Certificate [%s]", envCertificate)
   139  			f.StringVar(&flag.cert, "cert", value, usage)
   140  		}
   141  
   142  		{
   143  			value := os.Getenv(envPrivateKey)
   144  			usage := fmt.Sprintf("Private key [%s]", envPrivateKey)
   145  			f.StringVar(&flag.key, "key", value, usage)
   146  		}
   147  
   148  		{
   149  			insecure := false
   150  			switch env := strings.ToLower(os.Getenv(envInsecure)); env {
   151  			case "1", "true":
   152  				insecure = true
   153  			}
   154  
   155  			usage := fmt.Sprintf("Skip verification of server certificate [%s]", envInsecure)
   156  			f.BoolVar(&flag.Session.Insecure, "k", insecure, usage)
   157  		}
   158  
   159  		{
   160  			persist := true
   161  			switch env := strings.ToLower(os.Getenv(envPersist)); env {
   162  			case "0", "false":
   163  				persist = false
   164  			}
   165  
   166  			usage := fmt.Sprintf("Persist session to disk [%s]", envPersist)
   167  			f.BoolVar(&flag.persist, "persist-session", persist, usage)
   168  		}
   169  
   170  		{
   171  			env := os.Getenv(envMinAPIVersion)
   172  			if env == "" {
   173  				env = defaultMinVimVersion
   174  			}
   175  
   176  			flag.minAPIVersion = env
   177  		}
   178  
   179  		{
   180  			value := os.Getenv(envVimNamespace)
   181  			if value == "" {
   182  				value = vim25.Namespace
   183  			}
   184  			usage := fmt.Sprintf("Vim namespace [%s]", envVimNamespace)
   185  			f.StringVar(&flag.vimNamespace, "vim-namespace", value, usage)
   186  		}
   187  
   188  		{
   189  			value := os.Getenv(envVimVersion)
   190  			if value == "" {
   191  				value = vim25.Version
   192  			}
   193  			usage := fmt.Sprintf("Vim version [%s]", envVimVersion)
   194  			f.StringVar(&flag.vimVersion, "vim-version", value, usage)
   195  		}
   196  
   197  		{
   198  			value := os.Getenv(envTLSCaCerts)
   199  			usage := fmt.Sprintf("TLS CA certificates file [%s]", envTLSCaCerts)
   200  			f.StringVar(&flag.tlsCaCerts, "tls-ca-certs", value, usage)
   201  		}
   202  
   203  		{
   204  			value := os.Getenv(envTLSKnownHosts)
   205  			usage := fmt.Sprintf("TLS known hosts file [%s]", envTLSKnownHosts)
   206  			f.StringVar(&flag.tlsKnownHosts, "tls-known-hosts", value, usage)
   207  		}
   208  	})
   209  }
   210  
   211  func (flag *ClientFlag) Process(ctx context.Context) error {
   212  	return flag.ProcessOnce(func() error {
   213  		err := flag.DebugFlag.Process(ctx)
   214  		if err != nil {
   215  			return err
   216  		}
   217  
   218  		if flag.Session.URL == nil {
   219  			return errors.New("specify an " + cDescr)
   220  		}
   221  
   222  		if !flag.persist {
   223  			flag.Session.Passthrough = true
   224  		}
   225  
   226  		flag.username, err = session.Secret(flag.username)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		flag.password, err = session.Secret(flag.password)
   231  		if err != nil {
   232  			return err
   233  		}
   234  
   235  		// Override username if set
   236  		if flag.username != "" {
   237  			var password string
   238  			var ok bool
   239  
   240  			if flag.Session.URL.User != nil {
   241  				password, ok = flag.Session.URL.User.Password()
   242  			}
   243  
   244  			if ok {
   245  				flag.Session.URL.User = url.UserPassword(flag.username, password)
   246  			} else {
   247  				flag.Session.URL.User = url.User(flag.username)
   248  			}
   249  		}
   250  
   251  		// Override password if set
   252  		if flag.password != "" {
   253  			var username string
   254  
   255  			if flag.Session.URL.User != nil {
   256  				username = flag.Session.URL.User.Username()
   257  			}
   258  
   259  			flag.Session.URL.User = url.UserPassword(username, flag.password)
   260  		}
   261  
   262  		return nil
   263  	})
   264  }
   265  
   266  func (flag *ClientFlag) ConfigureTLS(sc *soap.Client) error {
   267  	if flag.cert != "" {
   268  		cert, err := tls.LoadX509KeyPair(flag.cert, flag.key)
   269  		if err != nil {
   270  			return fmt.Errorf("%s=%q %s=%q: %s", envCertificate, flag.cert, envPrivateKey, flag.key, err)
   271  		}
   272  
   273  		sc.SetCertificate(cert)
   274  	}
   275  
   276  	// Set namespace and version
   277  	sc.Namespace = "urn:" + flag.vimNamespace
   278  	sc.Version = flag.vimVersion
   279  
   280  	sc.UserAgent = fmt.Sprintf("govc/%s", strings.TrimPrefix(BuildVersion, "v"))
   281  
   282  	if err := flag.SetRootCAs(sc); err != nil {
   283  		return err
   284  	}
   285  
   286  	if err := sc.LoadThumbprints(flag.tlsKnownHosts); err != nil {
   287  		return err
   288  	}
   289  
   290  	t := sc.DefaultTransport()
   291  	var err error
   292  
   293  	value := os.Getenv("GOVC_TLS_HANDSHAKE_TIMEOUT")
   294  	if value != "" {
   295  		t.TLSHandshakeTimeout, err = time.ParseDuration(value)
   296  		if err != nil {
   297  			return err
   298  		}
   299  	}
   300  
   301  	sc.UseJSON(os.Getenv("GOVC_VI_JSON") != "")
   302  
   303  	return nil
   304  }
   305  
   306  func (flag *ClientFlag) SetRootCAs(c *soap.Client) error {
   307  	if flag.tlsCaCerts != "" {
   308  		return c.SetRootCAs(flag.tlsCaCerts)
   309  	}
   310  	return nil
   311  }
   312  
   313  func isDevelopmentVersion(apiVersion string) bool {
   314  	// Skip version check for development builds which can be in the form of "r4A70F" or "6.5.x"
   315  	return strings.Count(apiVersion, ".") == 0 || strings.HasSuffix(apiVersion, ".x")
   316  }
   317  
   318  // apiVersionValid returns whether or not the API version supported by the
   319  // server the client is connected to is not recent enough.
   320  func apiVersionValid(c *vim25.Client, minVersionString string) error {
   321  	if minVersionString == "-" {
   322  		// Disable version check
   323  		return nil
   324  	}
   325  
   326  	apiVersion := c.ServiceContent.About.ApiVersion
   327  	if isDevelopmentVersion(apiVersion) {
   328  		return nil
   329  	}
   330  
   331  	realVersion, err := ParseVersion(apiVersion)
   332  	if err != nil {
   333  		return fmt.Errorf("error parsing API version %q: %s", apiVersion, err)
   334  	}
   335  
   336  	minVersion, err := ParseVersion(minVersionString)
   337  	if err != nil {
   338  		return fmt.Errorf("error parsing %s=%q: %s", envMinAPIVersion, minVersionString, err)
   339  	}
   340  
   341  	if !minVersion.Lte(realVersion) {
   342  		err = fmt.Errorf("require API version %q, connected to API version %q (set %s to override)",
   343  			minVersionString,
   344  			c.ServiceContent.About.ApiVersion,
   345  			envMinAPIVersion)
   346  		return err
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  func (flag *ClientFlag) RoundTripper(c *soap.Client) soap.RoundTripper {
   353  	// Retry twice when a temporary I/O error occurs.
   354  	// This means a maximum of 3 attempts.
   355  	rt := vim25.Retry(c, vim25.RetryTemporaryNetworkError, 3)
   356  
   357  	switch {
   358  	case flag.dump:
   359  		rt = &dump{roundTripper: rt}
   360  	case flag.verbose:
   361  		rt = &verbose{roundTripper: rt}
   362  	}
   363  
   364  	return rt
   365  }
   366  
   367  func (flag *ClientFlag) Client() (*vim25.Client, error) {
   368  	if flag.client != nil {
   369  		return flag.client, nil
   370  	}
   371  
   372  	c := new(vim25.Client)
   373  	err := flag.Session.Login(context.Background(), c, flag.ConfigureTLS)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  
   378  	// Check that the endpoint has the right API version
   379  	err = apiVersionValid(c, flag.minAPIVersion)
   380  	if err != nil {
   381  		return nil, err
   382  	}
   383  
   384  	if flag.vimVersion == "" {
   385  		err = c.UseServiceVersion()
   386  		if err != nil {
   387  			return nil, err
   388  		}
   389  	}
   390  
   391  	c.RoundTripper = flag.RoundTripper(c.Client)
   392  	flag.client = c
   393  
   394  	return flag.client, nil
   395  }
   396  
   397  func (flag *ClientFlag) RestClient() (*rest.Client, error) {
   398  	if flag.restClient != nil {
   399  		return flag.restClient, nil
   400  	}
   401  
   402  	c := new(rest.Client)
   403  
   404  	err := flag.Session.Login(context.Background(), c, flag.ConfigureTLS)
   405  	if err != nil {
   406  		return nil, err
   407  	}
   408  
   409  	flag.restClient = c
   410  	return flag.restClient, nil
   411  }
   412  
   413  func (flag *ClientFlag) PbmClient() (*pbm.Client, error) {
   414  	vc, err := flag.Client()
   415  	if err != nil {
   416  		return nil, err
   417  	}
   418  	c, err := pbm.NewClient(context.Background(), vc)
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  
   423  	c.RoundTripper = flag.RoundTripper(c.Client)
   424  
   425  	return c, nil
   426  }
   427  
   428  func (flag *ClientFlag) CnsClient() (*cns.Client, error) {
   429  	vc, err := flag.Client()
   430  	if err != nil {
   431  		return nil, err
   432  	}
   433  
   434  	c, err := cns.NewClient(context.Background(), vc)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  
   439  	c.RoundTripper = flag.RoundTripper(c.Client)
   440  
   441  	return c, nil
   442  }
   443  
   444  func (flag *ClientFlag) KeepAlive(client cache.Client) {
   445  	switch c := client.(type) {
   446  	case *vim25.Client:
   447  		keepalive.NewHandlerSOAP(c, 0, nil).Start()
   448  	case *rest.Client:
   449  		keepalive.NewHandlerREST(c, 0, nil).Start()
   450  	default:
   451  		panic(fmt.Sprintf("unsupported client type=%T", client))
   452  	}
   453  }
   454  
   455  func (flag *ClientFlag) Logout(ctx context.Context) error {
   456  	if flag.client != nil {
   457  		_ = flag.Session.Logout(ctx, flag.client)
   458  	}
   459  
   460  	if flag.restClient != nil {
   461  		_ = flag.Session.Logout(ctx, flag.restClient)
   462  	}
   463  
   464  	return nil
   465  }
   466  
   467  // Environ returns the govc environment variables for this connection
   468  func (flag *ClientFlag) Environ(extra bool) []string {
   469  	var env []string
   470  	add := func(k, v string) {
   471  		env = append(env, fmt.Sprintf("%s=%s", k, v))
   472  	}
   473  
   474  	u := *flag.Session.URL
   475  	if u.User != nil {
   476  		add(envUsername, u.User.Username())
   477  
   478  		if p, ok := u.User.Password(); ok {
   479  			add(envPassword, p)
   480  		}
   481  
   482  		u.User = nil
   483  	}
   484  
   485  	if u.Path == vim25.Path {
   486  		u.Path = ""
   487  	}
   488  	u.Fragment = ""
   489  	u.RawQuery = ""
   490  
   491  	add(envURL, strings.TrimPrefix(u.String(), "https://"))
   492  
   493  	keys := []string{
   494  		envCertificate,
   495  		envPrivateKey,
   496  		envInsecure,
   497  		envPersist,
   498  		envMinAPIVersion,
   499  		envVimNamespace,
   500  		envVimVersion,
   501  	}
   502  
   503  	for _, k := range keys {
   504  		if v := os.Getenv(k); v != "" {
   505  			add(k, v)
   506  		}
   507  	}
   508  
   509  	if extra {
   510  		add("GOVC_URL_SCHEME", flag.Session.URL.Scheme)
   511  
   512  		v := strings.SplitN(u.Host, ":", 2)
   513  		add("GOVC_URL_HOST", v[0])
   514  		if len(v) == 2 {
   515  			add("GOVC_URL_PORT", v[1])
   516  		}
   517  
   518  		add("GOVC_URL_PATH", flag.Session.URL.Path)
   519  
   520  		if f := flag.Session.URL.Fragment; f != "" {
   521  			add("GOVC_URL_FRAGMENT", f)
   522  		}
   523  
   524  		if q := flag.Session.URL.RawQuery; q != "" {
   525  			add("GOVC_URL_QUERY", q)
   526  		}
   527  	}
   528  
   529  	return env
   530  }
   531  
   532  // WithCancel calls the given function, returning when complete or canceled via SIGINT.
   533  func (flag *ClientFlag) WithCancel(ctx context.Context, f func(context.Context) error) error {
   534  	sig := make(chan os.Signal, 1)
   535  	signal.Notify(sig, syscall.SIGINT)
   536  
   537  	wctx, cancel := context.WithCancel(ctx)
   538  	defer cancel()
   539  
   540  	done := make(chan bool)
   541  	var werr error
   542  
   543  	go func() {
   544  		defer close(done)
   545  		werr = f(wctx)
   546  	}()
   547  
   548  	select {
   549  	case <-sig:
   550  		cancel()
   551  		<-done // Wait for f() to complete
   552  	case <-done:
   553  	}
   554  
   555  	return werr
   556  }