vitess.io/vitess@v0.16.2/go/vt/topo/zk2topo/zk_conn.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package zk2topo
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"crypto/x509"
    23  	"fmt"
    24  	"math/rand"
    25  	"net"
    26  	"os"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/spf13/pflag"
    32  	"github.com/z-division/go-zookeeper/zk"
    33  
    34  	"vitess.io/vitess/go/sync2"
    35  	"vitess.io/vitess/go/vt/log"
    36  	"vitess.io/vitess/go/vt/servenv"
    37  )
    38  
    39  const (
    40  	// maxAttempts is how many times we retry queries.  At 2 for
    41  	// now, so if a query fails because the session expired, we
    42  	// just try to reconnect once and go on.
    43  	maxAttempts = 2
    44  
    45  	// PermDirectory are default permissions for a node.
    46  	PermDirectory = zk.PermAdmin | zk.PermCreate | zk.PermDelete | zk.PermRead | zk.PermWrite
    47  
    48  	// PermFile allows a zk node to emulate file behavior by
    49  	// disallowing child nodes.
    50  	PermFile = zk.PermAdmin | zk.PermRead | zk.PermWrite
    51  )
    52  
    53  var (
    54  	maxConcurrency = 64
    55  	baseTimeout    = 30 * time.Second
    56  
    57  	certPath, keyPath, caPath, authFile string
    58  )
    59  
// init hands registerFlags to servenv so that the -topo_zk_* flags are
// registered for the binaries servenv.RegisterFlagsForTopoBinaries targets
// (presumably those that talk to a topo server — see servenv).
func init() {
	servenv.RegisterFlagsForTopoBinaries(registerFlags)
}
    63  
    64  func registerFlags(fs *pflag.FlagSet) {
    65  	fs.IntVar(&maxConcurrency, "topo_zk_max_concurrency", maxConcurrency, "maximum number of pending requests to send to a Zookeeper server.")
    66  	fs.DurationVar(&baseTimeout, "topo_zk_base_timeout", baseTimeout, "zk base timeout (see zk.Connect)")
    67  	fs.StringVar(&certPath, "topo_zk_tls_cert", certPath, "the cert to use to connect to the zk topo server, requires topo_zk_tls_key, enables TLS")
    68  	fs.StringVar(&keyPath, "topo_zk_tls_key", keyPath, "the key to use to connect to the zk topo server, enables TLS")
    69  	fs.StringVar(&caPath, "topo_zk_tls_ca", caPath, "the server ca to use to validate servers when connecting to the zk topo server")
    70  	fs.StringVar(&authFile, "topo_zk_auth_file", authFile, "auth to use when connecting to the zk topo server, file contents should be <scheme>:<auth>, e.g., digest:user:pass")
    71  
    72  }
    73  
    74  // Time returns a time.Time from a ZK int64 milliseconds since Epoch time.
    75  func Time(i int64) time.Time {
    76  	return time.Unix(i/1000, i%1000*1000000)
    77  }
    78  
    79  // ZkTime returns a ZK time (int64) from a time.Time
    80  func ZkTime(t time.Time) int64 {
    81  	return t.Unix()*1000 + int64(t.Nanosecond()/1000000)
    82  }
    83  
// ZkConn is a wrapper class on top of a zk.Conn.
// It will do a few things for us:
//   - add the context parameter. However, we do not enforce its deadlines
//     necessarily.
//   - enforce a max concurrency of access to Zookeeper. We just don't
//     want to make too many calls concurrently, to not take too many resources.
//   - retry some calls to Zookeeper. If we were disconnected from the
//     server, we want to try connecting again before failing.
//   - connect lazily: the underlying zk.Conn is dialed on first use and
//     dialed again after a broken connection has been dropped.
type ZkConn struct {
	// addr is set at construction time, and immutable. It is the
	// comma-separated server list given to Connect.
	addr string

	// sem protects concurrent calls to Zookeeper: each withRetry call
	// holds one of its maxConcurrency slots for its whole duration.
	sem *sync2.Semaphore

	// mu protects the following fields. conn is the current connection,
	// or nil before the first call and after a broken connection was
	// dropped; getConn dials a new one on demand.
	mu   sync.Mutex
	conn *zk.Conn
}
   103  
   104  // Connect to the Zookeeper servers specified in addr
   105  // addr can be a comma separated list of servers and each server can be a DNS entry with multiple values.
   106  // Connects to the endpoints in a randomized order to avoid hot spots.
   107  func Connect(addr string) *ZkConn {
   108  	return &ZkConn{
   109  		addr: addr,
   110  		sem:  sync2.NewSemaphore(maxConcurrency, 0),
   111  	}
   112  }
   113  
   114  // Get is part of the Conn interface.
   115  func (c *ZkConn) Get(ctx context.Context, path string) (data []byte, stat *zk.Stat, err error) {
   116  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   117  		data, stat, err = conn.Get(path)
   118  		return err
   119  	})
   120  	return
   121  }
   122  
   123  // GetW is part of the Conn interface.
   124  func (c *ZkConn) GetW(ctx context.Context, path string) (data []byte, stat *zk.Stat, watch <-chan zk.Event, err error) {
   125  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   126  		data, stat, watch, err = conn.GetW(path)
   127  		return err
   128  	})
   129  	return
   130  }
   131  
   132  // Children is part of the Conn interface.
   133  func (c *ZkConn) Children(ctx context.Context, path string) (children []string, stat *zk.Stat, err error) {
   134  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   135  		children, stat, err = conn.Children(path)
   136  		return err
   137  	})
   138  	return
   139  }
   140  
   141  // ChildrenW is part of the Conn interface.
   142  func (c *ZkConn) ChildrenW(ctx context.Context, path string) (children []string, stat *zk.Stat, watch <-chan zk.Event, err error) {
   143  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   144  		children, stat, watch, err = conn.ChildrenW(path)
   145  		return err
   146  	})
   147  	return
   148  }
   149  
   150  // Exists is part of the Conn interface.
   151  func (c *ZkConn) Exists(ctx context.Context, path string) (exists bool, stat *zk.Stat, err error) {
   152  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   153  		exists, stat, err = conn.Exists(path)
   154  		return err
   155  	})
   156  	return
   157  }
   158  
   159  // ExistsW is part of the Conn interface.
   160  func (c *ZkConn) ExistsW(ctx context.Context, path string) (exists bool, stat *zk.Stat, watch <-chan zk.Event, err error) {
   161  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   162  		exists, stat, watch, err = conn.ExistsW(path)
   163  		return err
   164  	})
   165  	return
   166  }
   167  
   168  // Create is part of the Conn interface.
   169  func (c *ZkConn) Create(ctx context.Context, path string, value []byte, flags int32, aclv []zk.ACL) (pathCreated string, err error) {
   170  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   171  		pathCreated, err = conn.Create(path, value, flags, aclv)
   172  		return err
   173  	})
   174  	return
   175  }
   176  
   177  // Set is part of the Conn interface.
   178  func (c *ZkConn) Set(ctx context.Context, path string, value []byte, version int32) (stat *zk.Stat, err error) {
   179  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   180  		stat, err = conn.Set(path, value, version)
   181  		return err
   182  	})
   183  	return
   184  }
   185  
   186  // Delete is part of the Conn interface.
   187  func (c *ZkConn) Delete(ctx context.Context, path string, version int32) error {
   188  	return c.withRetry(ctx, func(conn *zk.Conn) error {
   189  		return conn.Delete(path, version)
   190  	})
   191  }
   192  
   193  // GetACL is part of the Conn interface.
   194  func (c *ZkConn) GetACL(ctx context.Context, path string) (aclv []zk.ACL, stat *zk.Stat, err error) {
   195  	err = c.withRetry(ctx, func(conn *zk.Conn) error {
   196  		aclv, stat, err = conn.GetACL(path)
   197  		return err
   198  	})
   199  	return
   200  }
   201  
   202  // SetACL is part of the Conn interface.
   203  func (c *ZkConn) SetACL(ctx context.Context, path string, aclv []zk.ACL, version int32) error {
   204  	return c.withRetry(ctx, func(conn *zk.Conn) error {
   205  		_, err := conn.SetACL(path, aclv, version)
   206  		return err
   207  	})
   208  }
   209  
   210  // AddAuth is part of the Conn interface.
   211  func (c *ZkConn) AddAuth(ctx context.Context, scheme string, auth []byte) error {
   212  	return c.withRetry(ctx, func(conn *zk.Conn) error {
   213  		err := conn.AddAuth(scheme, auth)
   214  		return err
   215  	})
   216  }
   217  
   218  // Close is part of the Conn interface.
   219  func (c *ZkConn) Close() error {
   220  	c.mu.Lock()
   221  	defer c.mu.Unlock()
   222  	if c.conn != nil {
   223  		c.conn.Close()
   224  	}
   225  	return nil
   226  }
   227  
   228  // withRetry encapsulates the retry logic and concurrent access to
   229  // Zookeeper.
   230  //
   231  // Some errors are not handled gracefully by the Zookeeper client. This is
   232  // sort of odd, but in general it doesn't affect the kind of code you
   233  // need to have a truly reliable client.
   234  //
   235  // However, it can manifest itself as an annoying transient error that
   236  // is likely avoidable when trying simple operations like Get.
   237  // To that end, we retry when possible to minimize annoyance at
   238  // higher levels.
   239  //
   240  // https://issues.apache.org/jira/browse/ZOOKEEPER-22
   241  func (c *ZkConn) withRetry(ctx context.Context, action func(conn *zk.Conn) error) (err error) {
   242  
   243  	// Handle concurrent access to a Zookeeper server here.
   244  	c.sem.Acquire()
   245  	defer c.sem.Release()
   246  
   247  	for i := 0; i < maxAttempts; i++ {
   248  		if i > 0 {
   249  			// Add a bit of backoff time before retrying:
   250  			// 1 second base + up to 5 seconds.
   251  			time.Sleep(1*time.Second + time.Duration(rand.Int63n(5e9)))
   252  		}
   253  
   254  		// Get the current connection, or connect.
   255  		var conn *zk.Conn
   256  		conn, err = c.getConn(ctx)
   257  		if err != nil {
   258  			// We can't connect, try again.
   259  			continue
   260  		}
   261  
   262  		// Execute the action.
   263  		err = action(conn)
   264  		if err != zk.ErrConnectionClosed {
   265  			// It worked, or it failed for another reason
   266  			// than connection related.
   267  			return
   268  		}
   269  
   270  		// We got an error, because the connection was closed.
   271  		// Let's clear up our errored connection and try again.
   272  		c.mu.Lock()
   273  		if c.conn == conn {
   274  			c.conn = nil
   275  		}
   276  		c.mu.Unlock()
   277  	}
   278  	return
   279  }
   280  
   281  // getConn returns the connection in a thread safe way. It will try to connect
   282  // if not connected yet.
   283  func (c *ZkConn) getConn(ctx context.Context) (*zk.Conn, error) {
   284  	c.mu.Lock()
   285  	defer c.mu.Unlock()
   286  
   287  	if c.conn == nil {
   288  		conn, events, err := dialZk(ctx, c.addr)
   289  		if err != nil {
   290  			return nil, err
   291  		}
   292  		c.conn = conn
   293  		go c.handleSessionEvents(conn, events)
   294  		c.maybeAddAuth(ctx)
   295  	}
   296  	return c.conn, nil
   297  }
   298  
   299  // maybeAddAuth calls AddAuth if the `-topo_zk_auth_file` flag was specified
   300  func (c *ZkConn) maybeAddAuth(ctx context.Context) {
   301  	if authFile == "" {
   302  		return
   303  	}
   304  	authInfoBytes, err := os.ReadFile(authFile)
   305  	if err != nil {
   306  		log.Errorf("failed to read topo_zk_auth_file: %v", err)
   307  		return
   308  	}
   309  	authInfo := strings.TrimRight(string(authInfoBytes), "\n")
   310  	authInfoParts := strings.SplitN(authInfo, ":", 2)
   311  	if len(authInfoParts) != 2 {
   312  		log.Errorf("failed to parse topo_zk_auth_file contents, expected format <scheme>:<auth> but saw: %s", authInfo)
   313  		return
   314  	}
   315  	err = c.conn.AddAuth(authInfoParts[0], []byte(authInfoParts[1]))
   316  	if err != nil {
   317  		log.Errorf("failed to add auth from topo_zk_auth_file: %v", err)
   318  		return
   319  	}
   320  }
   321  
   322  // handleSessionEvents is processing events from the session channel.
   323  // When it detects that the connection is not working any more, it
   324  // clears out the connection record.
   325  func (c *ZkConn) handleSessionEvents(conn *zk.Conn, session <-chan zk.Event) {
   326  	for event := range session {
   327  		closeRequired := false
   328  
   329  		switch event.State {
   330  		case zk.StateExpired, zk.StateConnecting:
   331  			closeRequired = true
   332  			fallthrough
   333  		case zk.StateDisconnected:
   334  			c.mu.Lock()
   335  			if c.conn == conn {
   336  				// The ZkConn still references this
   337  				// connection, let's nil it.
   338  				c.conn = nil
   339  			}
   340  			c.mu.Unlock()
   341  			if closeRequired {
   342  				conn.Close()
   343  			}
   344  			log.Infof("zk conn: session for addr %v ended: %v", c.addr, event)
   345  			return
   346  		}
   347  		log.Infof("zk conn: session for addr %v event: %v", c.addr, event)
   348  	}
   349  }
   350  
   351  // dialZk dials the server, and waits until connection.
   352  func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error) {
   353  	servers := strings.Split(addr, ",")
   354  	dialer := zk.WithDialer(net.DialTimeout)
   355  	ctx, cancel := context.WithTimeout(ctx, baseTimeout)
   356  	defer cancel()
   357  	// If TLS is enabled use a TLS enabled dialer option
   358  	if certPath != "" && keyPath != "" {
   359  		if strings.Contains(addr, ",") {
   360  			log.Fatalf("This TLS zk code requires that the all the zk servers validate to a single server name.")
   361  		}
   362  
   363  		serverName := strings.Split(addr, ":")[0]
   364  
   365  		log.Infof("Using TLS ZK, connecting to %v server name %v", addr, serverName)
   366  		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
   367  		if err != nil {
   368  			log.Fatalf("Unable to load cert %v and key %v, err %v", certPath, keyPath, err)
   369  		}
   370  
   371  		clientCACert, err := os.ReadFile(caPath)
   372  		if err != nil {
   373  			log.Fatalf("Unable to open ca cert %v, err %v", caPath, err)
   374  		}
   375  
   376  		clientCertPool := x509.NewCertPool()
   377  		clientCertPool.AppendCertsFromPEM(clientCACert)
   378  
   379  		tlsConfig := &tls.Config{
   380  			Certificates: []tls.Certificate{cert},
   381  			RootCAs:      clientCertPool,
   382  			ServerName:   serverName,
   383  		}
   384  
   385  		dialer = zk.WithDialer(func(network, address string, timeout time.Duration) (net.Conn, error) {
   386  			d := net.Dialer{Timeout: timeout}
   387  
   388  			return tls.DialWithDialer(&d, network, address, tlsConfig)
   389  		})
   390  	}
   391  	// Make sure we re-resolve the DNS name every time we reconnect to a server
   392  	// In environments where DNS changes such as Kubernetes we can't cache the IP address
   393  	hostProvider := zk.WithHostProvider(&zk.SimpleDNSHostProvider{})
   394  
   395  	// zk.Connect automatically shuffles the servers
   396  	zconn, session, err := zk.Connect(servers, baseTimeout, dialer, hostProvider)
   397  	if err != nil {
   398  		return nil, nil, err
   399  	}
   400  
   401  	// Wait for connection, skipping transition states.
   402  	for {
   403  		select {
   404  		case <-ctx.Done():
   405  			zconn.Close()
   406  			return nil, nil, ctx.Err()
   407  		case event := <-session:
   408  			switch event.State {
   409  			case zk.StateConnected:
   410  				// success
   411  				return zconn, session, nil
   412  
   413  			case zk.StateAuthFailed:
   414  				// fast fail this one
   415  				zconn.Close()
   416  				return nil, nil, fmt.Errorf("zk connect failed: StateAuthFailed")
   417  			}
   418  		}
   419  	}
   420  }