go.etcd.io/etcd@v3.3.27+incompatible/etcdserver/api/v3rpc/watch.go

// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v3rpc

import (
	"context"
	"io"
	"sync"
	"time"

	"github.com/coreos/etcd/auth"
	"github.com/coreos/etcd/etcdserver"
	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
	"github.com/coreos/etcd/mvcc"
	"github.com/coreos/etcd/mvcc/mvccpb"
)

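// watchServer implements the pb.WatchServer gRPC service. Each client Watch
// RPC is served by a dedicated serverWatchStream built from this shared state.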
type watchServer struct {
	clusterID int64
	memberID  int64

	maxRequestBytes int

	raftTimer etcdserver.RaftTimer
	watchable mvcc.WatchableKV

	ag AuthGetter
}

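// NewWatchServer returns a new watch server backed by the given EtcdServer.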
func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
	return &watchServer{
		clusterID:       int64(s.Cluster().ID()),
		memberID:        int64(s.ID()),
		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),
		raftTimer:       s,
		watchable:       s.Watchable(),
		ag:              s,
	}
}

var (
	// External tests can read this with GetProgressReportInterval()
	// and change it to a small value via SetProgressReportInterval()
	// to finish fast.
	progressReportInterval   = 10 * time.Minute
	progressReportIntervalMu sync.RWMutex
)

func GetProgressReportInterval() time.Duration {
	progressReportIntervalMu.RLock()
	defer progressReportIntervalMu.RUnlock()
	return progressReportInterval
}

func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	defer progressReportIntervalMu.Unlock()
	progressReportInterval = newTimeout
}
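
// For example (illustrative only, not part of this file), a test that wants
// fast progress notifications could temporarily shrink the interval:
//
//	old := GetProgressReportInterval()
//	SetProgressReportInterval(100 * time.Millisecond)
//	defer SetProgressReportInterval(old)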

const (
	// We send ctrl responses inside the read loop. We do not want the
	// send to block the read, but we still want the ctrl responses we
	// send to be serialized. Thus we use a buffered chan to solve the
	// problem. A small buffer should be OK for most cases, since we
	// expect ctrl requests to be infrequent.
	ctrlStreamBufLen = 16
)

// serverWatchStream is an etcd server-side stream. It receives requests
// from the client-side gRPC stream, receives watch events from
// mvcc.WatchStream, and creates responses that are forwarded to the gRPC
// stream. It also forwards control messages such as watcher created and
// canceled.
type serverWatchStream struct {
	clusterID int64
	memberID  int64

	maxRequestBytes int

	raftTimer etcdserver.RaftTimer

	watchable mvcc.WatchableKV

	gRPCStream  pb.Watch_WatchServer
	watchStream mvcc.WatchStream
	ctrlStream  chan *pb.WatchResponse

	// mu protects progress, prevKV
	mu sync.RWMutex
	// progress tracks the watch IDs that the stream might need to send
	// progress reports to.
	// TODO: combine progress and prevKV into a single struct?
	progress map[mvcc.WatchID]bool
	prevKV   map[mvcc.WatchID]bool
	// fragment records the watch IDs that requested fragmented responses
	fragment map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup

	ag AuthGetter
}

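// Watch implements the Watch RPC. It starts sendLoop in a goroutine tracked
// by wg and recvLoop in a detached goroutine, waits until recvLoop fails or
// the stream context is done, and closes the serverWatchStream before
// returning.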
func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		raftTimer: ws.raftTimer,

		watchable: ws.watchable,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control responses like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
		progress:   make(map[mvcc.WatchID]bool),
		prevKV:     make(map[mvcc.WatchID]bool),
		fragment:   make(map[mvcc.WatchID]bool),
		closec:     make(chan struct{}),

		ag: ws.ag,
	}

	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				plog.Debugf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			} else {
				plog.Warningf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			}
			errc <- rerr
		}
	}()
	select {
	case err = <-errc:
		close(sws.ctrlStream)
	case <-stream.Context().Done():
		err = stream.Context().Err()
		// the only server-side cancellation is noleader for now.
		if err == context.Canceled {
			err = rpctypes.ErrGRPCNoLeader
		}
	}
	sws.close()
	return err
}

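// isWatchPermitted checks whether the authenticated user on this stream is
// allowed to watch the key range in the create request.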
func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
	if err != nil {
		return false
	}
	if authInfo == nil {
		// if auth is enabled, IsRangePermitted() can cause an error
		authInfo = &auth.AuthInfo{}
	}

	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
}

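// recvLoop reads WatchRequests from the gRPC stream and translates create,
// cancel, and progress requests into watch-stream operations and control
// responses. It returns when the client closes its side or on a receive error.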
func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support >= key queries
				creq.RangeEnd = []byte{}
			}

			if !sws.isWatchPermitted(creq) {
				wr := &pb.WatchResponse{
					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      -1,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
					continue
				case <-sws.closec:
					return nil
				}
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				rev = wsrev + 1
			}
			id := sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev, filters...)
			if id != -1 {
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				if creq.Fragment {
					sws.fragment[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: id == -1,
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}
		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					delete(sws.fragment, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		case *pb.WatchRequest_ProgressRequest:
			if uv.ProgressRequest != nil {
				sws.ctrlStream <- &pb.WatchResponse{
					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
				}
			}
		default:
			// we probably should not shut down the entire stream when we
			// receive an invalid command, so just do nothing instead.
			continue
		}
	}
}

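// sendLoop forwards watch events and control responses to the gRPC stream.
// Event responses for a watcher are buffered until the watcher's creation
// response has been sent, and progress notifications are issued on a ticker
// for watchers that requested them.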
func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.RLock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.RUnlock()
			for i := range evs {
				events[i] = &evs[i]

				if needPrevKV {
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, hasId := ids[wresp.WatchID]; !hasId {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))

			sws.mu.RLock()
			fragmented, ok := sws.fragment[wresp.WatchID]
			sws.mu.RUnlock()

			var serr error
			if !fragmented && !ok {
				serr = sws.gRPCStream.Send(wr)
			} else {
				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
			}

			if serr != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
					plog.Debugf("failed to send watch response to gRPC stream (%q)", serr.Error())
				} else {
					plog.Warningf("failed to send watch response to gRPC stream (%q)", serr.Error())
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide the next progress update if a key update was sent
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					plog.Debugf("failed to send watch control response to gRPC stream (%q)", err.Error())
				} else {
					plog.Warningf("failed to send watch control response to gRPC stream (%q)", err.Error())
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							plog.Debugf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						} else {
							plog.Warningf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						}
						return
					}
				}
				delete(pending, wid)
			}
		case <-progressTicker.C:
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				sws.progress[id] = true
			}
			sws.mu.Unlock()
		case <-sws.closec:
			return
		}
	}
}

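// sendFragments splits wr into multiple responses, each kept below
// maxRequestBytes where possible, and sends them via sendFunc; every
// response except the last one is marked as a fragment.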
func sendFragments(
	wr *pb.WatchResponse,
	maxRequestBytes int,
	sendFunc func(*pb.WatchResponse) error) error {
	// no need to fragment if the total response size is smaller than the
	// max request limit or the response contains only one event
	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
		return sendFunc(wr)
	}

	ow := *wr
	ow.Events = make([]*mvccpb.Event, 0)
	ow.Fragment = true

	var idx int
	for {
		cur := ow
		for _, ev := range wr.Events[idx:] {
			cur.Events = append(cur.Events, ev)
			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
				cur.Events = cur.Events[:len(cur.Events)-1]
				break
			}
			idx++
		}
		if idx == len(wr.Events) {
			// the last response has no more fragments
			cur.Fragment = false
		}
		if err := sendFunc(&cur); err != nil {
			return err
		}
		if !cur.Fragment {
			break
		}
	}
	return nil
}

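// close shuts down the mvcc watch stream, signals both loops via closec, and
// waits for sendLoop to finish.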
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}

func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
	return &pb.ResponseHeader{
		ClusterId: uint64(sws.clusterID),
		MemberId:  uint64(sws.memberID),
		Revision:  rev,
		RaftTerm:  sws.raftTimer.Term(),
	}
}

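// filterNoDelete reports whether e is a DELETE event. It backs the NODELETE
// filter: mvcc drops events for which a filter returns true.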
func filterNoDelete(e mvccpb.Event) bool {
	return e.Type == mvccpb.DELETE
}

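// filterNoPut reports whether e is a PUT event. It backs the NOPUT filter.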
func filterNoPut(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT
}

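// FiltersFromRequest converts the filter enums in a WatchCreateRequest into
// mvcc filter functions; unrecognized filter types are ignored.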
func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
	for _, ft := range creq.Filters {
		switch ft {
		case pb.WatchCreateRequest_NOPUT:
			filters = append(filters, filterNoPut)
		case pb.WatchCreateRequest_NODELETE:
			filters = append(filters, filterNoDelete)
		default:
		}
	}
	return filters
}