github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/internal/broadcasts/broadcasts.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package broadcasts contains code for a Fleetspeak server to manage
    16  // broadcasts. See in particular the Manager.
    17  package broadcasts
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"math/rand"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  	"time"
    28  
    29  	log "github.com/golang/glog"
    30  	"github.com/google/fleetspeak/fleetspeak/src/common"
    31  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    32  	"github.com/google/fleetspeak/fleetspeak/src/server/ids"
    33  	"github.com/google/fleetspeak/fleetspeak/src/server/internal/cache"
    34  	"github.com/google/fleetspeak/fleetspeak/src/server/internal/ftime"
    35  	"github.com/google/fleetspeak/fleetspeak/src/server/internal/notifications"
    36  
    37  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    38  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    39  )
    40  
    41  const (
    42  	allocFrac     = 0.2
    43  	allocDuration = 10 * time.Minute
    44  )
    45  
// A Manager keeps track of the active broadcasts in a system. It allows a
// fleetspeak system to deliver broadcasts to clients without all of the
// fleetspeak servers trying to modify the same datastore records at the same
// time.
type Manager struct {
	// Datastore used to list broadcasts, manage allocations and save the
	// broadcast messages created for clients.
	bs           db.BroadcastStore
	// Per-broadcast state (including the current allocation), keyed by
	// broadcast id.
	infos        map[ids.BroadcastID]*bInfo
	l            sync.RWMutex   // Protects the structure of infos.
	done         chan bool      // Closes to indicate it is time to shut down.
	working      sync.WaitGroup // Indicates that work is ongoing.
	// Base delay between refreshes; the actual wait is chosen between
	// basePollWait and 2*basePollWait.
	basePollWait time.Duration

	// These members are used to inform other parts of the system that they might
	// need to check for broadcast messages.
	clientCache *cache.Clients
	dispatcher  *notifications.Dispatcher
}
    63  
    64  // MakeManager creates a Manager, populates it with the
    65  // current set of broadcasts, and begins updating the broadcasts in the
    66  // background, the time between updates is always between pw and 2*pw.
    67  func MakeManager(ctx context.Context, bs db.BroadcastStore, pw time.Duration, clientCache *cache.Clients, dispatcher *notifications.Dispatcher) (*Manager, error) {
    68  	r := &Manager{
    69  		bs:           bs,
    70  		infos:        make(map[ids.BroadcastID]*bInfo),
    71  		done:         make(chan bool),
    72  		basePollWait: pw,
    73  		clientCache:  clientCache,
    74  		dispatcher:   dispatcher,
    75  	}
    76  	if err := r.refreshInfo(ctx); err != nil {
    77  		return nil, err
    78  	}
    79  	r.working.Add(1)
    80  	go r.refreshLoop()
    81  	return r, nil
    82  }
    83  
    84  // bInfo contains what a broadcast manager needs to know about a broadcast.
    85  type bInfo struct {
    86  	bID      ids.BroadcastID
    87  	b        *spb.Broadcast
    88  	useCount sync.WaitGroup // How many goroutines are using this bInfo and associated allocation.
    89  
    90  	// The remaining fields describe the allocation that we have for the broadcast. If limit is
    91  	// set to BroadcastUnlimited, then we don't actually have (or need) an allocation.
    92  	aID    ids.AllocationID
    93  	limit  uint64
    94  	sent   uint64 // How many messages have we sent under the current allocation, only accessed through atomic.
    95  	expiry time.Time
    96  	lock   sync.Mutex // Used to sychronize updates to the allocation record in the database.
    97  }
    98  
    99  // limitedAtomicIncrement atomically adds one to *addr, unless the result would
   100  // exceed limit. Return true if successful.
   101  func limitedAtomicIncrement(addr *uint64, limit uint64) bool {
   102  	for {
   103  		c := atomic.LoadUint64(addr)
   104  		if c >= limit {
   105  			return false
   106  		}
   107  		if atomic.CompareAndSwapUint64(addr, c, c+1) {
   108  			return true
   109  		}
   110  	}
   111  }
   112  
   113  // pollWait picks a time to wait before the next refresh.
   114  func (m *Manager) pollWait() time.Duration {
   115  	return m.basePollWait + time.Duration(float64(m.basePollWait)*rand.Float64())
   116  }
   117  
   118  // label shadows fspb.Label, but is safe to use as a map key
   119  type label struct {
   120  	serviceName string
   121  	label       string
   122  }
   123  
   124  func labelFromProto(l *fspb.Label) label {
   125  	return label{serviceName: l.ServiceName, label: l.Label}
   126  }
   127  
   128  // MakeBroadcastMessagesForClient finds broadcasts that the client is eligible
   129  // to receive.
   130  func (m *Manager) MakeBroadcastMessagesForClient(ctx context.Context, id common.ClientID, labels []*fspb.Label) ([]*fspb.Message, error) {
   131  	labelSet := make(map[label]bool)
   132  	for _, l := range labels {
   133  		labelSet[labelFromProto(l)] = true
   134  	}
   135  
   136  	sent, err := m.bs.ListSentBroadcasts(ctx, id)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	sentSet := make(map[ids.BroadcastID]bool)
   141  	for _, s := range sent {
   142  		sentSet[s] = true
   143  	}
   144  
   145  	var is []*bInfo
   146  	m.l.RLock()
   147  Infos:
   148  	for _, info := range m.infos {
   149  		if sentSet[info.bID] {
   150  			continue
   151  		}
   152  		if !info.expiry.IsZero() && info.expiry.Before(ftime.Now()) {
   153  			continue
   154  		}
   155  		for _, kw := range info.b.RequiredLabels {
   156  			if !labelSet[labelFromProto(kw)] {
   157  				continue Infos
   158  			}
   159  		}
   160  		if info.limit == db.BroadcastUnlimited {
   161  			atomic.AddUint64(&info.sent, 1)
   162  		} else {
   163  			if !limitedAtomicIncrement(&info.sent, info.limit) {
   164  				continue
   165  			}
   166  		}
   167  		info.useCount.Add(1)
   168  		is = append(is, info)
   169  	}
   170  	m.l.RUnlock()
   171  
   172  	// NOTE: we must call useCount.Done() on everything in "is", or the
   173  	// allocation update goroutine will get stuck. From this point on we log
   174  	// errors but don't stop.
   175  	msgs := make([]*fspb.Message, 0, len(is))
   176  	for _, i := range is {
   177  		mid, err := common.RandomMessageID()
   178  		if err != nil {
   179  			log.Errorf("unable to create message id: %v", err)
   180  			if i.limit != db.BroadcastUnlimited {
   181  				// Incantation to decrement a uint64, recommend AddUint64 docs:
   182  				atomic.AddUint64(&i.sent, ^uint64(0))
   183  			}
   184  			i.useCount.Done()
   185  			continue
   186  		}
   187  		msg := &fspb.Message{
   188  			MessageId: mid.Bytes(),
   189  			Source:    i.b.Source,
   190  			Destination: &fspb.Address{
   191  				ClientId:    id.Bytes(),
   192  				ServiceName: i.b.Source.ServiceName,
   193  			},
   194  			MessageType:  i.b.MessageType,
   195  			Data:         i.b.Data,
   196  			CreationTime: db.NowProto(),
   197  		}
   198  		i.lock.Lock()
   199  		err = m.bs.SaveBroadcastMessage(ctx, msg, i.bID, id, i.aID)
   200  		i.lock.Unlock()
   201  		if err != nil {
   202  			log.Errorf("SaveBroadcastMessage of instance of broadcast %v failed. Not sending. [%v]", i.bID, err)
   203  			if i.limit != db.BroadcastUnlimited {
   204  				// Incantation to decrement a uint64, recommend by AddUint64 docs:
   205  				atomic.AddUint64(&i.sent, ^uint64(0))
   206  			}
   207  			i.useCount.Done()
   208  			continue
   209  		}
   210  		msgs = append(msgs, msg)
   211  		i.useCount.Done()
   212  	}
   213  	return msgs, nil
   214  }
   215  
   216  func (m *Manager) refreshLoop() {
   217  	defer m.working.Done()
   218  	ctx := context.Background()
   219  	for {
   220  		select {
   221  		case _, ok := <-m.done:
   222  			if !ok {
   223  				return
   224  			}
   225  		case <-time.After(m.pollWait()):
   226  		}
   227  
   228  		if err := m.refreshInfo(ctx); err != nil {
   229  			log.Errorf("Error refreshing broadcast infos: %v", err)
   230  
   231  		}
   232  	}
   233  }
   234  
   235  // refreshInfo refreshes the bInfo map using the data from the database.
   236  func (m *Manager) refreshInfo(ctx context.Context) error {
   237  	// Find the allocations that we don't need (or want) to change.
   238  	curr := m.findCurrentAllocs()
   239  
   240  	// Find the active broadcasts.
   241  	bs, err := m.bs.ListActiveBroadcasts(ctx)
   242  	if err != nil {
   243  		return fmt.Errorf("unable to list active broadcasts: %v", err)
   244  	}
   245  
   246  	// Create any new allocations that we'll need: everything in bs but not curr.
   247  	newAllocs := make(map[ids.BroadcastID]*bInfo)
   248  	activeIds := make(map[ids.BroadcastID]bool)
   249  	for _, b := range bs {
   250  		id, err := ids.BytesToBroadcastID(b.Broadcast.BroadcastId)
   251  		if err != nil {
   252  			log.Errorf("Broadcast [%v] has bad id, skipping: %v", b.Broadcast, err)
   253  			continue
   254  		}
   255  		activeIds[id] = true
   256  		if !curr[id] {
   257  			if b.Sent == b.Limit {
   258  				continue
   259  			}
   260  			a, err := m.bs.CreateAllocation(ctx, id, allocFrac, ftime.Now().Add(allocDuration))
   261  			if err != nil {
   262  				log.Errorf("Unable to create alloc for broadcast %v, skipping: %v", id, err)
   263  				continue
   264  			}
   265  			if a != nil {
   266  				newAllocs[id] = &bInfo{
   267  					bID: id,
   268  					b:   b.Broadcast,
   269  
   270  					aID:    a.ID,
   271  					limit:  a.Limit,
   272  					sent:   0,
   273  					expiry: a.Expiry,
   274  				}
   275  			}
   276  		}
   277  	}
   278  
   279  	// Some things in curr might no longer be active, e.g. if the broadcast
   280  	// was canceled. Remove them from curr so that updateAllocs knows to
   281  	// clear them.
   282  	for id := range curr {
   283  		if !activeIds[id] {
   284  			delete(curr, id)
   285  		}
   286  	}
   287  
   288  	// Swap/insert the new allocations.
   289  	c := m.updateAllocs(curr, newAllocs)
   290  
   291  	// If we added any new allocations, then we should recompute broadcasts for
   292  	// any cached clients. We also start a process taking up to basePollWait time
   293  	// to notify already connected clients.
   294  	if len(newAllocs) > 0 {
   295  		m.clientCache.Clear()
   296  		go func() {
   297  			ctx, fin := context.WithTimeout(context.Background(), m.basePollWait)
   298  			m.dispatcher.NotifyAll(ctx)
   299  			fin()
   300  		}()
   301  	}
   302  
   303  	var errMsgs []string
   304  	// Cleanup the dead allocations. They've been removed from m.infos, so
   305  	// once useCount reaches 0, no new actions with them will start.  We
   306  	// cleanup everything we can, even if there are errors.
   307  	for _, a := range c {
   308  		a.useCount.Wait()
   309  		if err := m.bs.CleanupAllocation(ctx, a.bID, a.aID); err != nil {
   310  			errMsgs = append(errMsgs, fmt.Sprintf("[%v,%v]:\"%v\"", a.bID, a.aID, err))
   311  		}
   312  	}
   313  
   314  	if len(errMsgs) > 0 {
   315  		return errors.New("errors cleaning up allocations - " + strings.Join(errMsgs, " "))
   316  	}
   317  	return nil
   318  }
   319  
   320  func (m *Manager) findCurrentAllocs() map[ids.BroadcastID]bool {
   321  	r := make(map[ids.BroadcastID]bool)
   322  	// We should be the only goroutine that modifies m.infos, so we don't need to
   323  	// touch m.l.
   324  
   325  	for _, info := range m.infos {
   326  		// If the current allocation has room to send at least one
   327  		// message, and it should last until the next tick, we keep it.
   328  		if (info.limit == db.BroadcastUnlimited || atomic.LoadUint64(&info.sent) < info.limit) &&
   329  			(info.expiry.IsZero() || info.expiry.After(ftime.Now().Add(2*m.basePollWait))) {
   330  			r[info.bID] = true
   331  		}
   332  	}
   333  	return r
   334  }
   335  
   336  // updateAllocs updates the map m.info.
   337  //
   338  // "keep" identifies records which should be preserved while "new" identifies
   339  // records which should be updated. Any other record will be removed.  The
   340  // return value lists those structs which were removed or replaced.
   341  func (m *Manager) updateAllocs(keep map[ids.BroadcastID]bool, new map[ids.BroadcastID]*bInfo) []*bInfo {
   342  	m.l.Lock()
   343  	defer m.l.Unlock()
   344  
   345  	var ret []*bInfo
   346  
   347  	// Make a pass through m.infos, deleting anything not in keep or new.
   348  	for _, info := range m.infos {
   349  		if keep[info.bID] {
   350  			continue
   351  		}
   352  		if new[info.bID] == nil {
   353  			ret = append(ret, info)
   354  			delete(m.infos, info.bID)
   355  			continue
   356  		}
   357  	}
   358  
   359  	// Make a pass through new, updating m.infos.
   360  	for id, info := range new {
   361  		if m.infos[id] != nil {
   362  			ret = append(ret, m.infos[id])
   363  		}
   364  		m.infos[id] = info
   365  	}
   366  
   367  	return ret
   368  }
   369  
   370  // Close attempts to shut down the Manager gracefully.
   371  func (m *Manager) Close(ctx context.Context) error {
   372  	close(m.done)
   373  	m.working.Wait()
   374  
   375  	var errMsgs []string
   376  	for _, i := range m.infos {
   377  		i.useCount.Wait()
   378  		if err := m.bs.CleanupAllocation(ctx, i.bID, i.aID); err != nil {
   379  			errMsgs = append(errMsgs, fmt.Sprintf("[%v,%v]:\"%v\"", i.bID, i.aID, err))
   380  		}
   381  	}
   382  	m.infos = nil
   383  
   384  	if len(errMsgs) > 0 {
   385  		return errors.New("errors cleaning up allocations - " + strings.Join(errMsgs, " "))
   386  	}
   387  	return nil
   388  }