github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/https/streaming_message_server.go (about)

     1  // Copyright 2018 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package https
    16  
    17  import (
    18  	"bufio"
    19  	"context"
    20  	"crypto"
    21  	"encoding/binary"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"math"
    26  	"math/rand"
    27  	"net"
    28  	"net/http"
    29  	"sync"
    30  	"time"
    31  
    32  	log "github.com/golang/glog"
    33  	"github.com/google/fleetspeak/fleetspeak/src/common"
    34  	"github.com/google/fleetspeak/fleetspeak/src/server/comms"
    35  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    36  	"github.com/google/fleetspeak/fleetspeak/src/server/stats"
    37  	"google.golang.org/protobuf/proto"
    38  
    39  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    40  )
    41  
    42  const magic = uint32(0xf1ee1001)
    43  const baseErrorDelay = float64(100 * time.Millisecond)
    44  
// fullResponseWriter is the set of capabilities the streaming handler needs
// from its http.ResponseWriter: writing the response, being notified when the
// client connection closes, and flushing buffered data to the client.
// ServeHTTP rejects the request if the ResponseWriter it is given does not
// implement all three.
type fullResponseWriter interface {
	http.ResponseWriter
	http.CloseNotifier
	http.Flusher
}
    50  
    51  func readUint32(body *bufio.Reader) (uint32, error) {
    52  	b := make([]byte, 4)
    53  	if _, err := io.ReadAtLeast(body, b, 4); err != nil {
    54  		return 0, err
    55  	}
    56  	return binary.LittleEndian.Uint32(b), nil
    57  }
    58  
    59  func writeUint32(res fullResponseWriter, i uint32) error {
    60  	return binary.Write(res, binary.LittleEndian, i)
    61  }
    62  
    63  func newStreamingMessageServer(c *Communicator, maxPerClientBatchProcessors uint32) *streamingMessageServer {
    64  	return &streamingMessageServer{c, maxPerClientBatchProcessors}
    65  }
    66  
// streamingMessageServer wraps a Communicator in order to handle streaming
// connections from clients.
type streamingMessageServer struct {
	*Communicator
	// maxPerClientBatchProcessors sizes the per-connection batch channel
	// created by readLoop.
	maxPerClientBatchProcessors uint32
}
    72  
    73  func (s *streamingMessageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) {
    74  	earlyError := func(msg string, status int) {
    75  		log.ErrorDepth(1, fmt.Sprintf("%s: %s", http.StatusText(status), msg))
    76  		s.fs.StatsCollector().ClientPoll(stats.PollInfo{
    77  			CTX:    req.Context(),
    78  			Start:  db.Now(),
    79  			End:    db.Now(),
    80  			Status: status,
    81  			Type:   stats.StreamStart,
    82  		})
    83  	}
    84  
    85  	if !s.startProcessing() {
    86  		earlyError("server not ready", http.StatusInternalServerError)
    87  		return
    88  	}
    89  	defer s.stopProcessing()
    90  
    91  	fullRes, ok := res.(fullResponseWriter)
    92  	if !ok {
    93  		earlyError("/streaming-message requested, but not supported. ResponseWriter is not a fullResponseWriter", http.StatusNotFound)
    94  		return
    95  	}
    96  
    97  	if req.Method != http.MethodPost {
    98  		earlyError(fmt.Sprintf("%v not supported", req.Method), http.StatusBadRequest)
    99  		return
   100  	}
   101  
   102  	cert, err := GetClientCert(req, s.p.FrontendConfig)
   103  
   104  	if err != nil {
   105  		earlyError(err.Error(), http.StatusBadRequest)
   106  		return
   107  	}
   108  
   109  	if cert.PublicKey == nil {
   110  		earlyError("public key not present in client cert", http.StatusBadRequest)
   111  		return
   112  	}
   113  
   114  	body := bufio.NewReader(req.Body)
   115  
   116  	// Set a 9-11 minute overall maximum lifespan of the connection.
   117  	ctx, fin := context.WithTimeout(req.Context(), s.p.StreamingLifespan+time.Duration(float32(s.p.StreamingJitter)*rand.Float32()))
   118  	defer fin()
   119  
   120  	// Also create a way to terminate early in case of error.
   121  	ctx, cancel := context.WithCancel(ctx)
   122  	defer cancel()
   123  
   124  	addr := addrFromString(req.RemoteAddr)
   125  	info, moreMsgs, err := s.initialPoll(ctx, addr, cert.PublicKey, fullRes, body)
   126  	if err != nil || info == nil {
   127  		return
   128  	}
   129  
   130  	m := streamManager{
   131  		ctx:  ctx,
   132  		s:    s,
   133  		info: info,
   134  		res:  fullRes,
   135  		body: body,
   136  
   137  		localNotices: make(chan struct{}, 1),
   138  		out:          make(chan *fspb.ContactData, 5),
   139  
   140  		cancel: cancel,
   141  	}
   142  	defer func() {
   143  		// Shutdown is a bit subtle.
   144  		//
   145  		// We get here iff m.ctx is canceled, timed out, etc.
   146  		//
   147  		// Once ctx is canceled, writeLoop will notice, close the outgoing
   148  		// ResponseWriter, and begin blindly draining m.out.
   149  		//
   150  		// Closing ResponseWriter will cause any pending read to error out and
   151  		// readLoop to return.
   152  		//
   153  		// Once the readLoop returns, we can safely close m.out and wait for
   154  		// writeLoop to finish.
   155  		info.Fin()
   156  		m.reading.Wait()
   157  		close(m.out)
   158  		m.writing.Wait()
   159  	}()
   160  
   161  	m.reading.Add(2)
   162  	go m.readLoop()
   163  	go m.notifyLoop(s.p.StreamingCloseTime, moreMsgs)
   164  
   165  	m.writing.Add(1)
   166  	go m.writeLoop()
   167  
   168  	select {
   169  	case <-ctx.Done():
   170  	case <-fullRes.CloseNotify():
   171  	case <-s.stopping:
   172  	}
   173  	m.cancel()
   174  }
   175  
   176  func (s *streamingMessageServer) initialPoll(ctx context.Context, addr net.Addr, key crypto.PublicKey, res fullResponseWriter, body *bufio.Reader) (*comms.ConnectionInfo, bool, error) {
   177  	ctx, fin := context.WithTimeout(ctx, 3*time.Minute)
   178  
   179  	pi := stats.PollInfo{
   180  		CTX:    ctx,
   181  		Start:  db.Now(),
   182  		Status: http.StatusTeapot, // Should never actually be returned
   183  		Type:   stats.StreamStart,
   184  	}
   185  	defer func() {
   186  		fin()
   187  		if pi.Status == http.StatusTeapot {
   188  			log.Errorf("Forgot to set status, PollInfo: %v", pi)
   189  		}
   190  		pi.End = db.Now()
   191  		s.fs.StatsCollector().ClientPoll(pi)
   192  	}()
   193  
   194  	makeError := func(msg string, status int) error {
   195  		log.ErrorDepth(1, fmt.Sprintf("%s: [id:%v addr:%v] %s", http.StatusText(status), pi.ID, addr, msg))
   196  		pi.Status = status
   197  		return errors.New(msg)
   198  	}
   199  
   200  	id, err := common.MakeClientID(key)
   201  	if err != nil {
   202  		return nil, false, makeError(fmt.Sprintf("unable to create client id from public key: %v", err), http.StatusBadRequest)
   203  	}
   204  	pi.ID = id
   205  
   206  	m, err := readUint32(body)
   207  	if err != nil {
   208  		return nil, false, makeError(fmt.Sprintf("error reading magic number: %v", err), http.StatusBadRequest)
   209  	}
   210  	if m != magic {
   211  		return nil, false, makeError(fmt.Sprintf("unknown magic number: got %x, expected %x", m, magic), http.StatusBadRequest)
   212  	}
   213  
   214  	st := time.Now()
   215  	size, err := binary.ReadUvarint(body)
   216  	if err != nil {
   217  		return nil, false, makeError(fmt.Sprintf("error reading size: %v", err), http.StatusBadRequest)
   218  	}
   219  	if size > MaxContactSize {
   220  		return nil, false, makeError(fmt.Sprintf("initial contact size too large: got %d, expected at most %d", size, MaxContactSize), http.StatusBadRequest)
   221  	}
   222  
   223  	buf := make([]byte, size)
   224  	_, err = io.ReadFull(body, buf)
   225  	if err != nil {
   226  		return nil, false, makeError(fmt.Sprintf("error reading body for initial exchange: %v", err), http.StatusBadRequest)
   227  	}
   228  	pi.ReadTime = time.Since(st)
   229  	pi.ReadBytes = int(size)
   230  
   231  	var wcd fspb.WrappedContactData
   232  	if err := proto.Unmarshal(buf, &wcd); err != nil {
   233  		return nil, false, makeError(fmt.Sprintf("error parsing body: %v", err), http.StatusBadRequest)
   234  	}
   235  
   236  	info, toSend, more, err := s.fs.InitializeConnection(ctx, addr, key, &wcd, true)
   237  	if err == comms.ErrNotAuthorized {
   238  		return nil, false, makeError("not authorized", http.StatusServiceUnavailable)
   239  	}
   240  	if err != nil {
   241  		return nil, false, makeError(fmt.Sprintf("error processing contact: %v", err), http.StatusInternalServerError)
   242  	}
   243  	pi.CacheHit = info.Client.Cached
   244  
   245  	outBuf, err := proto.Marshal(toSend)
   246  	if err != nil {
   247  		info.Fin()
   248  		return nil, false, makeError(fmt.Sprintf("error preparing messages: %v", err), http.StatusInternalServerError)
   249  	}
   250  	sizeBuf := make([]byte, 0, 16)
   251  	sizeBuf = binary.AppendUvarint(sizeBuf, uint64(len(outBuf)))
   252  
   253  	st = time.Now()
   254  	sizeWritten, err := res.Write(sizeBuf)
   255  	if err != nil {
   256  		info.Fin()
   257  		return nil, false, makeError(fmt.Sprintf("error writing body: %v", err), http.StatusInternalServerError)
   258  	}
   259  	bufWritten, err := res.Write(outBuf)
   260  	if err != nil {
   261  		info.Fin()
   262  		return nil, false, makeError(fmt.Sprintf("error writing body: %v", err), http.StatusInternalServerError)
   263  	}
   264  	res.Flush()
   265  
   266  	pi.WriteTime = time.Since(st)
   267  	pi.End = time.Now()
   268  	pi.WriteBytes = sizeWritten + bufWritten
   269  	pi.Status = http.StatusOK
   270  	return info, more, nil
   271  }
   272  
// streamManager holds the state shared by the goroutines that serve a single
// streaming connection: readLoop, notifyLoop and writeLoop.
type streamManager struct {
	// ctx is done once the stream should shut down.
	ctx context.Context
	s   *streamingMessageServer

	// info describes the connection; res and body are the response writer and
	// the buffered request body of the underlying HTTP request.
	info *comms.ConnectionInfo
	res  fullResponseWriter
	body *bufio.Reader

	// Signals that we have more tokens and might retry sending.
	localNotices chan struct{}

	// The read- and writeLoop will wait for these. Separate because readloop
	// needs to finish before writeLoop.
	reading sync.WaitGroup
	writing sync.WaitGroup

	// out carries ContactData to writeLoop, which writes it to the client.
	out chan *fspb.ContactData

	cancel func() // Shuts down the stream when called.
}
   293  
   294  func (m *streamManager) readLoop() {
   295  	defer m.reading.Done()
   296  	defer m.cancel()
   297  
   298  	cnt := uint64(0)
   299  
   300  	// Number of batches from the same client that will be processed concurrently.
   301  	const maxBatchProcessors = 10
   302  	batchCh := make(chan *fspb.WrappedContactData, m.s.maxPerClientBatchProcessors)
   303  
   304  	for {
   305  		pi, wcd, err := m.readOne()
   306  		if err != nil {
   307  			// If the context has been canceled, it is probably a 'normal' termination
   308  			// - disconnect, max connection durating, etc. But if it is still active,
   309  			// we are going to tear down everything because of an unexpected read
   310  			// error and should log/record why.
   311  			if m.ctx.Err() == nil && pi != nil {
   312  				m.s.fs.StatsCollector().ClientPoll(*pi)
   313  				log.Errorf("Streaming Connection to %v terminated with error: %v", m.info.Client.ID, err)
   314  			}
   315  			return
   316  		}
   317  
   318  		// Increment the counter with every processed message.
   319  		cnt++
   320  
   321  		// This will block if number of concurrent processors is greater than maxBatchProcessors.
   322  		batchCh <- wcd
   323  		// Ensure the m.out stays open while the message processing is not done.
   324  		m.reading.Add(1)
   325  		// Given that the processing is done concurrently, capture the current counter value in
   326  		// the function argument.
   327  		go func(curCnt uint64) {
   328  			defer m.reading.Done()
   329  
   330  			wcd := <-batchCh
   331  			if err := m.processOne(wcd); err != nil {
   332  				log.Errorf("Error processing message from %v: %v", m.info.Client.ID, err)
   333  				return
   334  			}
   335  			m.out <- &fspb.ContactData{AckIndex: curCnt}
   336  		}(cnt)
   337  
   338  		m.s.fs.StatsCollector().ClientPoll(*pi)
   339  	}
   340  }
   341  
   342  func (m *streamManager) readOne() (*stats.PollInfo, *fspb.WrappedContactData, error) {
   343  	size, err := binary.ReadUvarint(m.body)
   344  	if err != nil {
   345  		return nil, nil, err
   346  	}
   347  	if size > MaxContactSize {
   348  		return nil, nil, fmt.Errorf("streaming contact size too large: got %d, expected at most %d", size, MaxContactSize)
   349  	}
   350  
   351  	pi := &stats.PollInfo{
   352  		CTX:      m.ctx,
   353  		ID:       m.info.Client.ID,
   354  		Start:    db.Now(),
   355  		Status:   http.StatusTeapot,
   356  		CacheHit: true,
   357  		Type:     stats.StreamFromClient,
   358  	}
   359  	defer func() {
   360  		if pi.Status == http.StatusTeapot {
   361  			log.Errorf("Forgot to set status.")
   362  		}
   363  		pi.End = db.Now()
   364  	}()
   365  	buf := make([]byte, size)
   366  	if _, err := io.ReadFull(m.body, buf); err != nil {
   367  		pi.Status = http.StatusBadRequest
   368  		return pi, nil, fmt.Errorf("error reading streamed data: %v", err)
   369  	}
   370  	pi.ReadTime = time.Since(pi.Start)
   371  	pi.ReadBytes = int(size)
   372  
   373  	wcd := &fspb.WrappedContactData{}
   374  	if err = proto.Unmarshal(buf, wcd); err != nil {
   375  		pi.Status = http.StatusBadRequest
   376  		return pi, nil, fmt.Errorf("error parsing streamed data: %v", err)
   377  	}
   378  
   379  	// Validate message early to provide feedback to the agent and fail with a
   380  	// descriptive HTTP code.
   381  	_, err = m.s.fs.ValidateMessagesFromClient(context.Background(), m.info, wcd)
   382  	if err != nil {
   383  		pi.Status = http.StatusServiceUnavailable
   384  		return pi, nil, fmt.Errorf("message validation failed: %v", err)
   385  	}
   386  
   387  	pi.Status = http.StatusOK
   388  	return pi, wcd, nil
   389  }
   390  
   391  func (m *streamManager) processOne(wcd *fspb.WrappedContactData) error {
   392  	var blockedServices []string
   393  	for k, v := range m.info.MessageTokens() {
   394  		if v == 0 {
   395  			blockedServices = append(blockedServices, k)
   396  		}
   397  	}
   398  	// We might be close to the connection's natural end. Accept up to 15
   399  	// seconds of overrun trying to process what we've been given. This
   400  	// should only happen when things are unexpectedly slow and likely
   401  	// causes duplicate messages.
   402  	ctx, fin := context.WithCancel(context.Background())
   403  	go func() {
   404  		defer fin()
   405  		select {
   406  		case <-ctx.Done():
   407  			return
   408  		case <-m.ctx.Done():
   409  			log.Warningf("Extra time required while processing message from %v.", m.info.Client.ID)
   410  			t := time.NewTimer(15 * time.Second)
   411  			defer t.Stop()
   412  			select {
   413  			case <-ctx.Done():
   414  				return
   415  			case <-t.C:
   416  				return
   417  			}
   418  		}
   419  	}()
   420  	err := m.s.fs.HandleMessagesFromClient(ctx, m.info, wcd)
   421  	fin()
   422  	if err != nil {
   423  		if err == comms.ErrNotAuthorized {
   424  			log.Infof("Message not authoried: %v", err)
   425  		} else {
   426  			err = fmt.Errorf("error processing streamed messages: %v", err)
   427  		}
   428  		return err
   429  	}
   430  	tokens := m.info.MessageTokens()
   431  	for _, s := range blockedServices {
   432  		if tokens[s] > 0 {
   433  			select {
   434  			case m.localNotices <- struct{}{}:
   435  			default:
   436  			}
   437  		}
   438  	}
   439  	return nil
   440  }
   441  
// notifyLoop fetches messages for the client and queues them on m.out until
// the stream ends.
//
// moreMsgs is the initial belief about whether messages are already waiting
// for the client. closeTime is how long before the hard deadline of m.ctx the
// loop stops sending: at that point it queues a ContactData with DoneSending
// set and returns. The loop also returns when m.ctx is done or when
// m.info.Notices is closed while it is waiting for a first notification.
func (m *streamManager) notifyLoop(closeTime time.Duration, moreMsgs bool) {
	defer m.reading.Done()

	// Stop sending messages to the client closeTime (e.g. 30 sec) before our hard deadline.
	d, ok := m.ctx.Deadline()
	if !ok {
		// Shouldn't happen, ctx is created with a deadline.
		log.Fatalf("m.ctx does not have a deadline set")
	}
	deadline := d.Add(-closeTime)
	stop := time.NewTimer(time.Until(deadline))
	defer stop.Stop()

	// Number of sequential errors getting messages for the client.
	var errCnt int

	for {
		// This switch decides how long we should wait before trying to
		// get more messages for the client, and returns when it is time
		// to stop.
		switch {
		case errCnt > 0:
			// Last attempt to get messages failed - try again with
			// a jittery exponential backoff in the hopes that the
			// database recovers.
			errDelay := time.Duration((baseErrorDelay + rand.Float64()*baseErrorDelay) * math.Pow(1.5, float64(errCnt)))
			t := time.NewTimer(errDelay)
			log.V(1).Infof("NotifyLoop(%v): waiting %v due to previous error.", m.info.Client.ID, errDelay)
			select {
			case <-m.ctx.Done():
				t.Stop()
				return
			case <-stop.C:
				t.Stop()
				m.out <- &fspb.ContactData{DoneSending: true}
				return
			case <-t.C:
			}
		case moreMsgs:
			// We believe that there are more messages already
			// available, just check if it is time to shutdown.
			log.V(1).Infof("NotifyLoop(%v): continuing, more messages possible.", m.info.Client.ID)
			if time.Now().After(deadline) {
				m.out <- &fspb.ContactData{DoneSending: true}
				return
			}
			if m.ctx.Err() != nil {
				return
			}
		default:
			// Wait for a notification, then wait 1 more second in
			// case more messages arrive.
			log.V(1).Infof("NotifyLoop(%v): waiting for notifications.", m.info.Client.ID)
			select {
			case <-m.ctx.Done():
				return
			case <-stop.C:
				m.out <- &fspb.ContactData{DoneSending: true}
				return
			case _, ok := <-m.info.Notices:
				if !ok {
					return
				}
			case <-m.localNotices:
			}
			// Absorb further notices until the one second timer fires. The
			// timer is not reset, so the wait is one second in total.
			t := time.NewTimer(time.Second)
		L:
			for {
				select {
				case <-m.ctx.Done():
					return
				case _, ok := <-m.info.Notices:
					if !ok {
						break L
					}
					continue L
				case <-t.C:
					break L
				}
			}
			t.Stop()
		}
		var cd *fspb.ContactData
		var err error
		// moreMsgs is updated here and selects the branch taken on the next
		// iteration.
		cd, moreMsgs, err = m.s.fs.GetMessagesForClient(m.ctx, m.info)
		if err != nil {
			// An error equal to the context's own error means the stream is
			// shutting down; anything else counts towards the backoff.
			if err == m.ctx.Err() {
				return
			}
			log.Errorf("Error getting messages for streaming client [%v]: %v", m.info.Client.ID, err)
			errCnt++
		} else {
			errCnt = 0
		}
		if cd != nil {
			m.out <- cd
		}
	}
}
   541  
   542  func (m *streamManager) writeLoop() {
   543  	defer m.writing.Done()
   544  	defer func() {
   545  		for range m.out {
   546  		}
   547  	}()
   548  
   549  	for {
   550  		select {
   551  		case cd, ok := <-m.out:
   552  			if !ok {
   553  				return
   554  			}
   555  			pi, err := m.writeOne(cd)
   556  			if err != nil {
   557  				if m.ctx.Err() != nil {
   558  					log.Errorf("Error sending ContactData to client [%v]: %v", m.info.Client.ID, err)
   559  					m.cancel()
   560  					m.s.fs.StatsCollector().ClientPoll(pi)
   561  				}
   562  				// ctx was already canceled - more or less normal shutdown, so don't log
   563  				// as a poll.
   564  				return
   565  			}
   566  			if len(cd.Messages) > 0 {
   567  				m.s.fs.StatsCollector().ClientPoll(pi)
   568  			}
   569  		case <-m.ctx.Done():
   570  			return
   571  		}
   572  	}
   573  }
   574  
   575  func (m *streamManager) writeOne(cd *fspb.ContactData) (stats.PollInfo, error) {
   576  	pi := stats.PollInfo{
   577  		CTX:      m.ctx,
   578  		ID:       m.info.Client.ID,
   579  		Start:    db.Now(),
   580  		Status:   http.StatusTeapot,
   581  		CacheHit: true,
   582  		Type:     stats.StreamToClient,
   583  	}
   584  	defer func() {
   585  		if pi.Status == http.StatusTeapot {
   586  			log.Errorf("Forgot to set status.")
   587  		}
   588  		pi.End = db.Now()
   589  	}()
   590  
   591  	buf, err := proto.Marshal(cd)
   592  	if err != nil {
   593  		return pi, err
   594  	}
   595  	sizeBuf := make([]byte, 0, 16)
   596  	sizeBuf = binary.AppendUvarint(sizeBuf, uint64(len(buf)))
   597  
   598  	sw := time.Now()
   599  	sizeWritten, err := m.res.Write(sizeBuf)
   600  	if err != nil {
   601  		return pi, err
   602  	}
   603  	bufWritten, err := m.res.Write(buf)
   604  	if err != nil {
   605  		return pi, err
   606  	}
   607  	m.res.Flush()
   608  	pi.WriteTime = time.Since(sw)
   609  	pi.WriteBytes = sizeWritten + bufWritten
   610  	pi.Status = http.StatusOK
   611  
   612  	return pi, nil
   613  }