vitess.io/vitess@v0.16.2/go/vt/srvtopo/watch.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package srvtopo
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"sync"
    23  	"time"
    24  
    25  	"vitess.io/vitess/go/stats"
    26  	"vitess.io/vitess/go/vt/log"
    27  	"vitess.io/vitess/go/vt/topo"
    28  )
    29  
    30  type watchState int
    31  
    32  const (
    33  	watchStateIdle watchState = iota
    34  	watchStateStarting
    35  	watchStateRunning
    36  )
    37  
    38  type watchEntry struct {
    39  	// immutable values
    40  	rw  *resilientWatcher
    41  	key fmt.Stringer
    42  
    43  	mutex      sync.Mutex
    44  	watchState watchState
    45  
    46  	watchStartingChan chan struct{}
    47  
    48  	value     any
    49  	lastError error
    50  
    51  	lastValueTime time.Time
    52  	lastErrorTime time.Time
    53  
    54  	listeners []func(any, error) bool
    55  }
    56  
    57  type resilientWatcher struct {
    58  	watcher func(entry *watchEntry)
    59  
    60  	counts               *stats.CountersWithSingleLabel
    61  	cacheRefreshInterval time.Duration
    62  	cacheTTL             time.Duration
    63  
    64  	mutex   sync.Mutex
    65  	entries map[string]*watchEntry
    66  }
    67  
    68  func (w *resilientWatcher) getEntry(wkey fmt.Stringer) *watchEntry {
    69  	w.mutex.Lock()
    70  	defer w.mutex.Unlock()
    71  
    72  	key := wkey.String()
    73  	entry, ok := w.entries[key]
    74  	if ok {
    75  		return entry
    76  	}
    77  
    78  	entry = &watchEntry{
    79  		rw:  w,
    80  		key: wkey,
    81  	}
    82  	w.entries[key] = entry
    83  	return entry
    84  }
    85  
    86  func (w *resilientWatcher) getValue(ctx context.Context, wkey fmt.Stringer) (any, error) {
    87  	entry := w.getEntry(wkey)
    88  
    89  	entry.mutex.Lock()
    90  	defer entry.mutex.Unlock()
    91  	return entry.currentValueLocked(ctx)
    92  }
    93  
    94  func (entry *watchEntry) addListener(ctx context.Context, callback func(any, error) bool) {
    95  	entry.mutex.Lock()
    96  	defer entry.mutex.Unlock()
    97  
    98  	entry.listeners = append(entry.listeners, callback)
    99  	v, err := entry.currentValueLocked(ctx)
   100  	callback(v, err)
   101  }
   102  
   103  func (entry *watchEntry) ensureWatchingLocked() {
   104  	switch entry.watchState {
   105  	case watchStateRunning, watchStateStarting:
   106  	case watchStateIdle:
   107  		shouldRefresh := time.Since(entry.lastErrorTime) > entry.rw.cacheRefreshInterval || len(entry.listeners) > 0
   108  
   109  		if shouldRefresh {
   110  			entry.watchState = watchStateStarting
   111  			entry.watchStartingChan = make(chan struct{})
   112  			go entry.rw.watcher(entry)
   113  		}
   114  	}
   115  }
   116  
   117  func (entry *watchEntry) currentValueLocked(ctx context.Context) (any, error) {
   118  	entry.rw.counts.Add(queryCategory, 1)
   119  
   120  	if entry.watchState == watchStateRunning {
   121  		return entry.value, entry.lastError
   122  	}
   123  
   124  	entry.ensureWatchingLocked()
   125  
   126  	cacheValid := entry.value != nil && time.Since(entry.lastValueTime) < entry.rw.cacheTTL
   127  	if cacheValid {
   128  		entry.rw.counts.Add(cachedCategory, 1)
   129  		return entry.value, nil
   130  	}
   131  
   132  	if entry.watchState == watchStateStarting {
   133  		watchStartingChan := entry.watchStartingChan
   134  		entry.mutex.Unlock()
   135  		select {
   136  		case <-watchStartingChan:
   137  		case <-ctx.Done():
   138  			entry.mutex.Lock()
   139  			return nil, ctx.Err()
   140  		}
   141  		entry.mutex.Lock()
   142  	}
   143  	if entry.value != nil {
   144  		return entry.value, nil
   145  	}
   146  	return nil, entry.lastError
   147  }
   148  
   149  func (entry *watchEntry) update(value any, err error, init bool) {
   150  	entry.mutex.Lock()
   151  	defer entry.mutex.Unlock()
   152  
   153  	if err != nil {
   154  		entry.onErrorLocked(err, init)
   155  	} else {
   156  		entry.onValueLocked(value)
   157  	}
   158  
   159  	listeners := entry.listeners
   160  	entry.listeners = entry.listeners[:0]
   161  
   162  	for _, callback := range listeners {
   163  		if callback(entry.value, entry.lastError) {
   164  			entry.listeners = append(entry.listeners, callback)
   165  		}
   166  	}
   167  }
   168  
   169  func (entry *watchEntry) onValueLocked(value any) {
   170  	entry.watchState = watchStateRunning
   171  	if entry.watchStartingChan != nil {
   172  		close(entry.watchStartingChan)
   173  		entry.watchStartingChan = nil
   174  	}
   175  	entry.value = value
   176  	entry.lastValueTime = time.Now()
   177  
   178  	entry.lastError = nil
   179  	entry.lastErrorTime = time.Time{}
   180  }
   181  
   182  func (entry *watchEntry) onErrorLocked(err error, init bool) {
   183  	entry.rw.counts.Add(errorCategory, 1)
   184  
   185  	entry.lastErrorTime = time.Now()
   186  
   187  	// if the node disappears, delete the cached value
   188  	if topo.IsErrType(err, topo.NoNode) {
   189  		entry.value = nil
   190  	}
   191  
   192  	if init {
   193  		entry.lastError = err
   194  
   195  		// This watcher will able to continue to return the last value till it is not able to connect to the topo server even if the cache TTL is reached.
   196  		// TTL cache is only checked if the error is a known error i.e topo.Error.
   197  		_, isTopoErr := err.(topo.Error)
   198  		if isTopoErr && time.Since(entry.lastValueTime) > entry.rw.cacheTTL {
   199  			log.Errorf("WatchSrvKeyspace clearing cached entry for %v", entry.key)
   200  			entry.value = nil
   201  		}
   202  	} else {
   203  		entry.lastError = fmt.Errorf("ResilientWatch stream failed for %v: %w", entry.key, err)
   204  		log.Errorf("%v", entry.lastError)
   205  
   206  		// Even though we didn't get a new value, update the lastValueTime
   207  		// here since the watch was successfully running before and we want
   208  		// the value to be cached for the full TTL from here onwards.
   209  		entry.lastValueTime = time.Now()
   210  	}
   211  
   212  	if entry.watchStartingChan != nil {
   213  		close(entry.watchStartingChan)
   214  		entry.watchStartingChan = nil
   215  	}
   216  
   217  	entry.watchState = watchStateIdle
   218  
   219  	// only retry the watch if we haven't been explicitly interrupted
   220  	if len(entry.listeners) > 0 && !topo.IsErrType(err, topo.Interrupted) {
   221  		go func() {
   222  			time.Sleep(entry.rw.cacheRefreshInterval)
   223  
   224  			entry.mutex.Lock()
   225  			entry.ensureWatchingLocked()
   226  			entry.mutex.Unlock()
   227  		}()
   228  	}
   229  }