go.etcd.io/etcd@v3.3.27+incompatible/etcdctl/ctlv2/command/util.go

     1  // Copyright 2015 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package command
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"os"
    27  	"strings"
    28  	"syscall"
    29  	"time"
    30  
    31  	"github.com/coreos/etcd/client"
    32  	"github.com/coreos/etcd/pkg/transport"
    33  
    34  	"github.com/bgentry/speakeasy"
    35  	"github.com/urfave/cli"
    36  )
    37  
    38  var (
    39  	ErrNoAvailSrc = errors.New("no available argument and stdin")
    40  
    41  	// the maximum amount of time a dial will wait for a connection to setup.
    42  	// 30s is long enough for most of the network conditions.
    43  	defaultDialTimeout = 30 * time.Second
    44  )
    45  
    46  func argOrStdin(args []string, stdin io.Reader, i int) (string, error) {
    47  	if i < len(args) {
    48  		return args[i], nil
    49  	}
    50  	bytes, err := ioutil.ReadAll(stdin)
    51  	if string(bytes) == "" || err != nil {
    52  		return "", ErrNoAvailSrc
    53  	}
    54  	return string(bytes), nil
    55  }
    56  
    57  func getPeersFlagValue(c *cli.Context) []string {
    58  	peerstr := c.GlobalString("endpoints")
    59  
    60  	if peerstr == "" {
    61  		peerstr = os.Getenv("ETCDCTL_ENDPOINTS")
    62  	}
    63  
    64  	if peerstr == "" {
    65  		peerstr = c.GlobalString("endpoint")
    66  	}
    67  
    68  	if peerstr == "" {
    69  		peerstr = os.Getenv("ETCDCTL_ENDPOINT")
    70  	}
    71  
    72  	if peerstr == "" {
    73  		peerstr = c.GlobalString("peers")
    74  	}
    75  
    76  	if peerstr == "" {
    77  		peerstr = os.Getenv("ETCDCTL_PEERS")
    78  	}
    79  
    80  	// If we still don't have peers, use a default
    81  	if peerstr == "" {
    82  		peerstr = "http://127.0.0.1:2379,http://127.0.0.1:4001"
    83  	}
    84  
    85  	return strings.Split(peerstr, ",")
    86  }
    87  
    88  func getDomainDiscoveryFlagValue(c *cli.Context) ([]string, error) {
    89  	domainstr, insecure := getDiscoveryDomain(c)
    90  
    91  	// If we still don't have domain discovery, return nothing
    92  	if domainstr == "" {
    93  		return []string{}, nil
    94  	}
    95  
    96  	discoverer := client.NewSRVDiscover()
    97  	eps, err := discoverer.Discover(domainstr)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	if insecure {
   102  		return eps, err
   103  	}
   104  	// strip insecure connections
   105  	ret := []string{}
   106  	for _, ep := range eps {
   107  		if strings.HasPrefix(ep, "http://") {
   108  			fmt.Fprintf(os.Stderr, "ignoring discovered insecure endpoint %q\n", ep)
   109  			continue
   110  		}
   111  		ret = append(ret, ep)
   112  	}
   113  	return ret, err
   114  }
   115  
   116  func getDiscoveryDomain(c *cli.Context) (domainstr string, insecure bool) {
   117  	domainstr = c.GlobalString("discovery-srv")
   118  	// Use an environment variable if nothing was supplied on the
   119  	// command line
   120  	if domainstr == "" {
   121  		domainstr = os.Getenv("ETCDCTL_DISCOVERY_SRV")
   122  	}
   123  	insecure = c.GlobalBool("insecure-discovery") || (os.Getenv("ETCDCTL_INSECURE_DISCOVERY") != "")
   124  	return domainstr, insecure
   125  }
   126  
   127  func getEndpoints(c *cli.Context) ([]string, error) {
   128  	eps, err := getDomainDiscoveryFlagValue(c)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	// If domain discovery returns no endpoints, check peer flag
   134  	if len(eps) == 0 {
   135  		eps = getPeersFlagValue(c)
   136  	}
   137  
   138  	for i, ep := range eps {
   139  		u, err := url.Parse(ep)
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  
   144  		if u.Scheme == "" {
   145  			u.Scheme = "http"
   146  		}
   147  
   148  		eps[i] = u.String()
   149  	}
   150  
   151  	return eps, nil
   152  }
   153  
   154  func getTransport(c *cli.Context) (*http.Transport, error) {
   155  	cafile := c.GlobalString("ca-file")
   156  	certfile := c.GlobalString("cert-file")
   157  	keyfile := c.GlobalString("key-file")
   158  
   159  	// Use an environment variable if nothing was supplied on the
   160  	// command line
   161  	if cafile == "" {
   162  		cafile = os.Getenv("ETCDCTL_CA_FILE")
   163  	}
   164  	if certfile == "" {
   165  		certfile = os.Getenv("ETCDCTL_CERT_FILE")
   166  	}
   167  	if keyfile == "" {
   168  		keyfile = os.Getenv("ETCDCTL_KEY_FILE")
   169  	}
   170  
   171  	discoveryDomain, insecure := getDiscoveryDomain(c)
   172  	if insecure {
   173  		discoveryDomain = ""
   174  	}
   175  	tls := transport.TLSInfo{
   176  		CAFile:     cafile,
   177  		CertFile:   certfile,
   178  		KeyFile:    keyfile,
   179  		ServerName: discoveryDomain,
   180  	}
   181  
   182  	dialTimeout := defaultDialTimeout
   183  	totalTimeout := c.GlobalDuration("total-timeout")
   184  	if totalTimeout != 0 && totalTimeout < dialTimeout {
   185  		dialTimeout = totalTimeout
   186  	}
   187  	return transport.NewTransport(tls, dialTimeout)
   188  }
   189  
   190  func getUsernamePasswordFromFlag(usernameFlag string) (username string, password string, err error) {
   191  	return getUsernamePassword("Password: ", usernameFlag)
   192  }
   193  
   194  func getUsernamePassword(prompt, usernameFlag string) (username string, password string, err error) {
   195  	colon := strings.Index(usernameFlag, ":")
   196  	if colon == -1 {
   197  		username = usernameFlag
   198  		// Prompt for the password.
   199  		password, err = speakeasy.Ask(prompt)
   200  		if err != nil {
   201  			return "", "", err
   202  		}
   203  	} else {
   204  		username = usernameFlag[:colon]
   205  		password = usernameFlag[colon+1:]
   206  	}
   207  	return username, password, nil
   208  }
   209  
   210  func mustNewKeyAPI(c *cli.Context) client.KeysAPI {
   211  	return client.NewKeysAPI(mustNewClient(c))
   212  }
   213  
   214  func mustNewMembersAPI(c *cli.Context) client.MembersAPI {
   215  	return client.NewMembersAPI(mustNewClient(c))
   216  }
   217  
   218  func mustNewClient(c *cli.Context) client.Client {
   219  	hc, err := newClient(c)
   220  	if err != nil {
   221  		fmt.Fprintln(os.Stderr, err.Error())
   222  		os.Exit(1)
   223  	}
   224  
   225  	debug := c.GlobalBool("debug")
   226  	if debug {
   227  		client.EnablecURLDebug()
   228  	}
   229  
   230  	if !c.GlobalBool("no-sync") {
   231  		if debug {
   232  			fmt.Fprintf(os.Stderr, "start to sync cluster using endpoints(%s)\n", strings.Join(hc.Endpoints(), ","))
   233  		}
   234  		ctx, cancel := contextWithTotalTimeout(c)
   235  		err := hc.Sync(ctx)
   236  		cancel()
   237  		if err != nil {
   238  			if err == client.ErrNoEndpoints {
   239  				fmt.Fprintf(os.Stderr, "etcd cluster has no published client endpoints.\n")
   240  				fmt.Fprintf(os.Stderr, "Try '--no-sync' if you want to access non-published client endpoints(%s).\n", strings.Join(hc.Endpoints(), ","))
   241  				handleError(c, ExitServerError, err)
   242  			}
   243  			if isConnectionError(err) {
   244  				handleError(c, ExitBadConnection, err)
   245  			}
   246  		}
   247  		if debug {
   248  			fmt.Fprintf(os.Stderr, "got endpoints(%s) after sync\n", strings.Join(hc.Endpoints(), ","))
   249  		}
   250  	}
   251  
   252  	if debug {
   253  		fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(hc.Endpoints(), ", "))
   254  	}
   255  
   256  	return hc
   257  }
   258  
   259  func isConnectionError(err error) bool {
   260  	switch t := err.(type) {
   261  	case *client.ClusterError:
   262  		for _, cerr := range t.Errors {
   263  			if !isConnectionError(cerr) {
   264  				return false
   265  			}
   266  		}
   267  		return true
   268  	case *net.OpError:
   269  		if t.Op == "dial" || t.Op == "read" {
   270  			return true
   271  		}
   272  		return isConnectionError(t.Err)
   273  	case net.Error:
   274  		if t.Timeout() {
   275  			return true
   276  		}
   277  	case syscall.Errno:
   278  		if t == syscall.ECONNREFUSED {
   279  			return true
   280  		}
   281  	}
   282  	return false
   283  }
   284  
   285  func mustNewClientNoSync(c *cli.Context) client.Client {
   286  	hc, err := newClient(c)
   287  	if err != nil {
   288  		fmt.Fprintln(os.Stderr, err.Error())
   289  		os.Exit(1)
   290  	}
   291  
   292  	if c.GlobalBool("debug") {
   293  		fmt.Fprintf(os.Stderr, "Cluster-Endpoints: %s\n", strings.Join(hc.Endpoints(), ", "))
   294  		client.EnablecURLDebug()
   295  	}
   296  
   297  	return hc
   298  }
   299  
   300  func newClient(c *cli.Context) (client.Client, error) {
   301  	eps, err := getEndpoints(c)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	tr, err := getTransport(c)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  
   311  	cfg := client.Config{
   312  		Transport:               tr,
   313  		Endpoints:               eps,
   314  		HeaderTimeoutPerRequest: c.GlobalDuration("timeout"),
   315  	}
   316  
   317  	uFlag := c.GlobalString("username")
   318  
   319  	if uFlag == "" {
   320  		uFlag = os.Getenv("ETCDCTL_USERNAME")
   321  	}
   322  
   323  	if uFlag != "" {
   324  		username, password, err := getUsernamePasswordFromFlag(uFlag)
   325  		if err != nil {
   326  			return nil, err
   327  		}
   328  		cfg.Username = username
   329  		cfg.Password = password
   330  	}
   331  
   332  	return client.New(cfg)
   333  }
   334  
   335  func contextWithTotalTimeout(c *cli.Context) (context.Context, context.CancelFunc) {
   336  	return context.WithTimeout(context.Background(), c.GlobalDuration("total-timeout"))
   337  }