vitess.io/vitess@v0.16.2/go/vt/topo/zk2topo/zk_conn.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package zk2topo 18 19 import ( 20 "context" 21 "crypto/tls" 22 "crypto/x509" 23 "fmt" 24 "math/rand" 25 "net" 26 "os" 27 "strings" 28 "sync" 29 "time" 30 31 "github.com/spf13/pflag" 32 "github.com/z-division/go-zookeeper/zk" 33 34 "vitess.io/vitess/go/sync2" 35 "vitess.io/vitess/go/vt/log" 36 "vitess.io/vitess/go/vt/servenv" 37 ) 38 39 const ( 40 // maxAttempts is how many times we retry queries. At 2 for 41 // now, so if a query fails because the session expired, we 42 // just try to reconnect once and go on. 43 maxAttempts = 2 44 45 // PermDirectory are default permissions for a node. 46 PermDirectory = zk.PermAdmin | zk.PermCreate | zk.PermDelete | zk.PermRead | zk.PermWrite 47 48 // PermFile allows a zk node to emulate file behavior by 49 // disallowing child nodes. 50 PermFile = zk.PermAdmin | zk.PermRead | zk.PermWrite 51 ) 52 53 var ( 54 maxConcurrency = 64 55 baseTimeout = 30 * time.Second 56 57 certPath, keyPath, caPath, authFile string 58 ) 59 60 func init() { 61 servenv.RegisterFlagsForTopoBinaries(registerFlags) 62 } 63 64 func registerFlags(fs *pflag.FlagSet) { 65 fs.IntVar(&maxConcurrency, "topo_zk_max_concurrency", maxConcurrency, "maximum number of pending requests to send to a Zookeeper server.") 66 fs.DurationVar(&baseTimeout, "topo_zk_base_timeout", baseTimeout, "zk base timeout (see zk.Connect)") 67 fs.StringVar(&certPath, "topo_zk_tls_cert", certPath, "the cert to use to connect to the zk topo server, requires topo_zk_tls_key, enables TLS") 68 fs.StringVar(&keyPath, "topo_zk_tls_key", keyPath, "the key to use to connect to the zk topo server, enables TLS") 69 fs.StringVar(&caPath, "topo_zk_tls_ca", caPath, "the server ca to use to validate servers when connecting to the zk topo server") 70 fs.StringVar(&authFile, "topo_zk_auth_file", authFile, "auth to use when connecting to the zk topo server, file contents should be <scheme>:<auth>, e.g., digest:user:pass") 71 72 } 73 74 // Time returns a time.Time from a ZK int64 milliseconds since Epoch time. 75 func Time(i int64) time.Time { 76 return time.Unix(i/1000, i%1000*1000000) 77 } 78 79 // ZkTime returns a ZK time (int64) from a time.Time 80 func ZkTime(t time.Time) int64 { 81 return t.Unix()*1000 + int64(t.Nanosecond()/1000000) 82 } 83 84 // ZkConn is a wrapper class on top of a zk.Conn. 85 // It will do a few things for us: 86 // - add the context parameter. However, we do not enforce its deadlines 87 // necessarily. 88 // - enforce a max concurrency of access to Zookeeper. We just don't 89 // want to make too many calls concurrently, to not take too many resources. 90 // - retry some calls to Zookeeper. If we were disconnected from the 91 // server, we want to try connecting again before failing. 92 type ZkConn struct { 93 // addr is set at construction time, and immutable. 94 addr string 95 96 // sem protects concurrent calls to Zookeeper. 97 sem *sync2.Semaphore 98 99 // mu protects the following fields. 100 mu sync.Mutex 101 conn *zk.Conn 102 } 103 104 // Connect to the Zookeeper servers specified in addr 105 // addr can be a comma separated list of servers and each server can be a DNS entry with multiple values. 106 // Connects to the endpoints in a randomized order to avoid hot spots. 107 func Connect(addr string) *ZkConn { 108 return &ZkConn{ 109 addr: addr, 110 sem: sync2.NewSemaphore(maxConcurrency, 0), 111 } 112 } 113 114 // Get is part of the Conn interface. 115 func (c *ZkConn) Get(ctx context.Context, path string) (data []byte, stat *zk.Stat, err error) { 116 err = c.withRetry(ctx, func(conn *zk.Conn) error { 117 data, stat, err = conn.Get(path) 118 return err 119 }) 120 return 121 } 122 123 // GetW is part of the Conn interface. 124 func (c *ZkConn) GetW(ctx context.Context, path string) (data []byte, stat *zk.Stat, watch <-chan zk.Event, err error) { 125 err = c.withRetry(ctx, func(conn *zk.Conn) error { 126 data, stat, watch, err = conn.GetW(path) 127 return err 128 }) 129 return 130 } 131 132 // Children is part of the Conn interface. 133 func (c *ZkConn) Children(ctx context.Context, path string) (children []string, stat *zk.Stat, err error) { 134 err = c.withRetry(ctx, func(conn *zk.Conn) error { 135 children, stat, err = conn.Children(path) 136 return err 137 }) 138 return 139 } 140 141 // ChildrenW is part of the Conn interface. 142 func (c *ZkConn) ChildrenW(ctx context.Context, path string) (children []string, stat *zk.Stat, watch <-chan zk.Event, err error) { 143 err = c.withRetry(ctx, func(conn *zk.Conn) error { 144 children, stat, watch, err = conn.ChildrenW(path) 145 return err 146 }) 147 return 148 } 149 150 // Exists is part of the Conn interface. 151 func (c *ZkConn) Exists(ctx context.Context, path string) (exists bool, stat *zk.Stat, err error) { 152 err = c.withRetry(ctx, func(conn *zk.Conn) error { 153 exists, stat, err = conn.Exists(path) 154 return err 155 }) 156 return 157 } 158 159 // ExistsW is part of the Conn interface. 160 func (c *ZkConn) ExistsW(ctx context.Context, path string) (exists bool, stat *zk.Stat, watch <-chan zk.Event, err error) { 161 err = c.withRetry(ctx, func(conn *zk.Conn) error { 162 exists, stat, watch, err = conn.ExistsW(path) 163 return err 164 }) 165 return 166 } 167 168 // Create is part of the Conn interface. 169 func (c *ZkConn) Create(ctx context.Context, path string, value []byte, flags int32, aclv []zk.ACL) (pathCreated string, err error) { 170 err = c.withRetry(ctx, func(conn *zk.Conn) error { 171 pathCreated, err = conn.Create(path, value, flags, aclv) 172 return err 173 }) 174 return 175 } 176 177 // Set is part of the Conn interface. 178 func (c *ZkConn) Set(ctx context.Context, path string, value []byte, version int32) (stat *zk.Stat, err error) { 179 err = c.withRetry(ctx, func(conn *zk.Conn) error { 180 stat, err = conn.Set(path, value, version) 181 return err 182 }) 183 return 184 } 185 186 // Delete is part of the Conn interface. 187 func (c *ZkConn) Delete(ctx context.Context, path string, version int32) error { 188 return c.withRetry(ctx, func(conn *zk.Conn) error { 189 return conn.Delete(path, version) 190 }) 191 } 192 193 // GetACL is part of the Conn interface. 194 func (c *ZkConn) GetACL(ctx context.Context, path string) (aclv []zk.ACL, stat *zk.Stat, err error) { 195 err = c.withRetry(ctx, func(conn *zk.Conn) error { 196 aclv, stat, err = conn.GetACL(path) 197 return err 198 }) 199 return 200 } 201 202 // SetACL is part of the Conn interface. 203 func (c *ZkConn) SetACL(ctx context.Context, path string, aclv []zk.ACL, version int32) error { 204 return c.withRetry(ctx, func(conn *zk.Conn) error { 205 _, err := conn.SetACL(path, aclv, version) 206 return err 207 }) 208 } 209 210 // AddAuth is part of the Conn interface. 211 func (c *ZkConn) AddAuth(ctx context.Context, scheme string, auth []byte) error { 212 return c.withRetry(ctx, func(conn *zk.Conn) error { 213 err := conn.AddAuth(scheme, auth) 214 return err 215 }) 216 } 217 218 // Close is part of the Conn interface. 219 func (c *ZkConn) Close() error { 220 c.mu.Lock() 221 defer c.mu.Unlock() 222 if c.conn != nil { 223 c.conn.Close() 224 } 225 return nil 226 } 227 228 // withRetry encapsulates the retry logic and concurrent access to 229 // Zookeeper. 230 // 231 // Some errors are not handled gracefully by the Zookeeper client. This is 232 // sort of odd, but in general it doesn't affect the kind of code you 233 // need to have a truly reliable client. 234 // 235 // However, it can manifest itself as an annoying transient error that 236 // is likely avoidable when trying simple operations like Get. 237 // To that end, we retry when possible to minimize annoyance at 238 // higher levels. 239 // 240 // https://issues.apache.org/jira/browse/ZOOKEEPER-22 241 func (c *ZkConn) withRetry(ctx context.Context, action func(conn *zk.Conn) error) (err error) { 242 243 // Handle concurrent access to a Zookeeper server here. 244 c.sem.Acquire() 245 defer c.sem.Release() 246 247 for i := 0; i < maxAttempts; i++ { 248 if i > 0 { 249 // Add a bit of backoff time before retrying: 250 // 1 second base + up to 5 seconds. 251 time.Sleep(1*time.Second + time.Duration(rand.Int63n(5e9))) 252 } 253 254 // Get the current connection, or connect. 255 var conn *zk.Conn 256 conn, err = c.getConn(ctx) 257 if err != nil { 258 // We can't connect, try again. 259 continue 260 } 261 262 // Execute the action. 263 err = action(conn) 264 if err != zk.ErrConnectionClosed { 265 // It worked, or it failed for another reason 266 // than connection related. 267 return 268 } 269 270 // We got an error, because the connection was closed. 271 // Let's clear up our errored connection and try again. 272 c.mu.Lock() 273 if c.conn == conn { 274 c.conn = nil 275 } 276 c.mu.Unlock() 277 } 278 return 279 } 280 281 // getConn returns the connection in a thread safe way. It will try to connect 282 // if not connected yet. 283 func (c *ZkConn) getConn(ctx context.Context) (*zk.Conn, error) { 284 c.mu.Lock() 285 defer c.mu.Unlock() 286 287 if c.conn == nil { 288 conn, events, err := dialZk(ctx, c.addr) 289 if err != nil { 290 return nil, err 291 } 292 c.conn = conn 293 go c.handleSessionEvents(conn, events) 294 c.maybeAddAuth(ctx) 295 } 296 return c.conn, nil 297 } 298 299 // maybeAddAuth calls AddAuth if the `-topo_zk_auth_file` flag was specified 300 func (c *ZkConn) maybeAddAuth(ctx context.Context) { 301 if authFile == "" { 302 return 303 } 304 authInfoBytes, err := os.ReadFile(authFile) 305 if err != nil { 306 log.Errorf("failed to read topo_zk_auth_file: %v", err) 307 return 308 } 309 authInfo := strings.TrimRight(string(authInfoBytes), "\n") 310 authInfoParts := strings.SplitN(authInfo, ":", 2) 311 if len(authInfoParts) != 2 { 312 log.Errorf("failed to parse topo_zk_auth_file contents, expected format <scheme>:<auth> but saw: %s", authInfo) 313 return 314 } 315 err = c.conn.AddAuth(authInfoParts[0], []byte(authInfoParts[1])) 316 if err != nil { 317 log.Errorf("failed to add auth from topo_zk_auth_file: %v", err) 318 return 319 } 320 } 321 322 // handleSessionEvents is processing events from the session channel. 323 // When it detects that the connection is not working any more, it 324 // clears out the connection record. 325 func (c *ZkConn) handleSessionEvents(conn *zk.Conn, session <-chan zk.Event) { 326 for event := range session { 327 closeRequired := false 328 329 switch event.State { 330 case zk.StateExpired, zk.StateConnecting: 331 closeRequired = true 332 fallthrough 333 case zk.StateDisconnected: 334 c.mu.Lock() 335 if c.conn == conn { 336 // The ZkConn still references this 337 // connection, let's nil it. 338 c.conn = nil 339 } 340 c.mu.Unlock() 341 if closeRequired { 342 conn.Close() 343 } 344 log.Infof("zk conn: session for addr %v ended: %v", c.addr, event) 345 return 346 } 347 log.Infof("zk conn: session for addr %v event: %v", c.addr, event) 348 } 349 } 350 351 // dialZk dials the server, and waits until connection. 352 func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error) { 353 servers := strings.Split(addr, ",") 354 dialer := zk.WithDialer(net.DialTimeout) 355 ctx, cancel := context.WithTimeout(ctx, baseTimeout) 356 defer cancel() 357 // If TLS is enabled use a TLS enabled dialer option 358 if certPath != "" && keyPath != "" { 359 if strings.Contains(addr, ",") { 360 log.Fatalf("This TLS zk code requires that the all the zk servers validate to a single server name.") 361 } 362 363 serverName := strings.Split(addr, ":")[0] 364 365 log.Infof("Using TLS ZK, connecting to %v server name %v", addr, serverName) 366 cert, err := tls.LoadX509KeyPair(certPath, keyPath) 367 if err != nil { 368 log.Fatalf("Unable to load cert %v and key %v, err %v", certPath, keyPath, err) 369 } 370 371 clientCACert, err := os.ReadFile(caPath) 372 if err != nil { 373 log.Fatalf("Unable to open ca cert %v, err %v", caPath, err) 374 } 375 376 clientCertPool := x509.NewCertPool() 377 clientCertPool.AppendCertsFromPEM(clientCACert) 378 379 tlsConfig := &tls.Config{ 380 Certificates: []tls.Certificate{cert}, 381 RootCAs: clientCertPool, 382 ServerName: serverName, 383 } 384 385 dialer = zk.WithDialer(func(network, address string, timeout time.Duration) (net.Conn, error) { 386 d := net.Dialer{Timeout: timeout} 387 388 return tls.DialWithDialer(&d, network, address, tlsConfig) 389 }) 390 } 391 // Make sure we re-resolve the DNS name every time we reconnect to a server 392 // In environments where DNS changes such as Kubernetes we can't cache the IP address 393 hostProvider := zk.WithHostProvider(&zk.SimpleDNSHostProvider{}) 394 395 // zk.Connect automatically shuffles the servers 396 zconn, session, err := zk.Connect(servers, baseTimeout, dialer, hostProvider) 397 if err != nil { 398 return nil, nil, err 399 } 400 401 // Wait for connection, skipping transition states. 402 for { 403 select { 404 case <-ctx.Done(): 405 zconn.Close() 406 return nil, nil, ctx.Err() 407 case event := <-session: 408 switch event.State { 409 case zk.StateConnected: 410 // success 411 return zconn, session, nil 412 413 case zk.StateAuthFailed: 414 // fast fail this one 415 zconn.Close() 416 return nil, nil, fmt.Errorf("zk connect failed: StateAuthFailed") 417 } 418 } 419 } 420 }