github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/sqlite/broadcaststore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlite
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"fmt"
    22  	"time"
    23  
    24  	"github.com/google/fleetspeak/fleetspeak/src/common"
    25  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    26  	"github.com/google/fleetspeak/fleetspeak/src/server/ids"
    27  
    28  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    29  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    30  	anypb "google.golang.org/protobuf/types/known/anypb"
    31  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    32  )
    33  
    34  // dbBroadcast matches the schema of the broadcasts table.
    35  type dbBroadcast struct {
    36  	broadcastID           string
    37  	sourceServiceName     string
    38  	messageType           string
    39  	expirationTimeSeconds sql.NullInt64
    40  	expirationTimeNanos   sql.NullInt64
    41  	dataTypeURL           sql.NullString
    42  	dataValue             []byte
    43  	sent                  uint64
    44  	allocated             uint64
    45  	messageLimit          uint64
    46  }
    47  
    48  func fromBroadcastProto(b *spb.Broadcast) (*dbBroadcast, error) {
    49  	if b == nil {
    50  		return nil, errors.New("cannot convert nil Broadcast")
    51  	}
    52  	id, err := ids.BytesToBroadcastID(b.BroadcastId)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	if b.Source == nil {
    57  		return nil, fmt.Errorf("Broadcast must have Source. Get: %v", b)
    58  	}
    59  
    60  	res := dbBroadcast{
    61  		broadcastID:       id.String(),
    62  		sourceServiceName: b.Source.ServiceName,
    63  		messageType:       b.MessageType,
    64  	}
    65  	if b.ExpirationTime != nil {
    66  		res.expirationTimeSeconds = sql.NullInt64{Int64: b.ExpirationTime.Seconds, Valid: true}
    67  		res.expirationTimeNanos = sql.NullInt64{Int64: int64(b.ExpirationTime.Nanos), Valid: true}
    68  	}
    69  	if b.Data != nil {
    70  		res.dataTypeURL = sql.NullString{String: b.Data.TypeUrl, Valid: true}
    71  		res.dataValue = b.Data.Value
    72  	}
    73  	return &res, nil
    74  }
    75  
    76  func toBroadcastProto(b *dbBroadcast) (*spb.Broadcast, error) {
    77  	bid, err := ids.StringToBroadcastID(b.broadcastID)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	ret := &spb.Broadcast{
    82  		BroadcastId: bid.Bytes(),
    83  		Source:      &fspb.Address{ServiceName: b.sourceServiceName},
    84  		MessageType: b.messageType,
    85  	}
    86  	if b.expirationTimeSeconds.Valid && b.expirationTimeNanos.Valid {
    87  		ret.ExpirationTime = &tspb.Timestamp{
    88  			Seconds: b.expirationTimeSeconds.Int64,
    89  			Nanos:   int32(b.expirationTimeNanos.Int64),
    90  		}
    91  	}
    92  	if b.dataTypeURL.Valid {
    93  		ret.Data = &anypb.Any{
    94  			TypeUrl: b.dataTypeURL.String,
    95  			Value:   b.dataValue,
    96  		}
    97  	}
    98  	return ret, nil
    99  }
   100  
   101  func (d *Datastore) CreateBroadcast(ctx context.Context, b *spb.Broadcast, limit uint64) error {
   102  	d.l.Lock()
   103  	defer d.l.Unlock()
   104  	dbB, err := fromBroadcastProto(b)
   105  	if err != nil {
   106  		return err
   107  	}
   108  	dbB.messageLimit = limit
   109  	return d.runInTx(func(tx *sql.Tx) error {
   110  		if _, err := tx.ExecContext(ctx, "INSERT INTO broadcasts("+
   111  			"broadcast_id, "+
   112  			"source_service_name, "+
   113  			"message_type, "+
   114  			"expiration_time_seconds, "+
   115  			"expiration_time_nanos, "+
   116  			"data_type_url, "+
   117  			"data_value, "+
   118  			"sent, "+
   119  			"allocated, "+
   120  			"message_limit) "+
   121  			"VALUES(?, ?, ?, ?, ?, ?, ?, 0, 0, ?)",
   122  			dbB.broadcastID,
   123  			dbB.sourceServiceName,
   124  			dbB.messageType,
   125  			dbB.expirationTimeSeconds,
   126  			dbB.expirationTimeNanos,
   127  			dbB.dataTypeURL,
   128  			dbB.dataValue,
   129  			dbB.messageLimit,
   130  		); err != nil {
   131  			return err
   132  		}
   133  		for _, l := range b.RequiredLabels {
   134  			if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_labels(broadcast_id, service_name, label) VALUES(?,?,?)", dbB.broadcastID, l.ServiceName, l.Label); err != nil {
   135  				return err
   136  			}
   137  
   138  		}
   139  		return nil
   140  	})
   141  }
   142  
   143  func (d *Datastore) SetBroadcastLimit(ctx context.Context, id ids.BroadcastID, limit uint64) error {
   144  	d.l.Lock()
   145  	defer d.l.Unlock()
   146  	return d.runInTx(func(tx *sql.Tx) error {
   147  		_, err := tx.ExecContext(ctx, "UPDATE broadcasts(message_limit) VALUES(?) WHERE broadcast_id=?", limit, id.String())
   148  		return err
   149  	})
   150  }
   151  
   152  func (d *Datastore) SaveBroadcastMessage(ctx context.Context, msg *fspb.Message, bID ids.BroadcastID, cID common.ClientID, aID ids.AllocationID) error {
   153  	d.l.Lock()
   154  	defer d.l.Unlock()
   155  	dbm, err := fromMessageProto(msg)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	return d.runInTx(func(tx *sql.Tx) error {
   161  		var as, al uint64
   162  		exp := &tspb.Timestamp{}
   163  		r := tx.QueryRowContext(ctx, "SELECT sent, message_limit, expiration_time_seconds, expiration_time_nanos FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String())
   164  		if err := r.Scan(&as, &al, &exp.Seconds, &exp.Nanos); err != nil {
   165  			return err
   166  		}
   167  		if as >= al {
   168  			return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is full: Sent: %v Limit: %v", aID, bID, as, al)
   169  		}
   170  		if err := exp.CheckValid(); err != nil {
   171  			return fmt.Errorf("SaveBroadcastMessage: unable to convert expiry to time: %v", err)
   172  		}
   173  		et := exp.AsTime()
   174  		if db.Now().After(et) {
   175  			return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is expired: %v", aID, bID, et)
   176  		}
   177  
   178  		if err := d.tryStoreMessage(ctx, tx, dbm, true); err != nil {
   179  			return err
   180  		}
   181  
   182  		if _, err := tx.ExecContext(ctx, "UPDATE broadcast_allocations SET sent = ? WHERE broadcast_id = ? AND allocation_id = ?", as+1, bID.String(), aID.String()); err != nil {
   183  			return err
   184  		}
   185  		_, err = tx.ExecContext(ctx, "INSERT INTO broadcast_sent(broadcast_id, client_id) VALUES (?, ?)", bID.String(), cID.String())
   186  		return err
   187  	})
   188  }
   189  
   190  func (d *Datastore) ListActiveBroadcasts(ctx context.Context) ([]*db.BroadcastInfo, error) {
   191  	d.l.Lock()
   192  	defer d.l.Unlock()
   193  	var ret []*db.BroadcastInfo
   194  	err := d.runInTx(func(tx *sql.Tx) error {
   195  		now := db.NowProto()
   196  		rs, err := tx.QueryContext(ctx, "SELECT "+
   197  			"broadcast_id, "+
   198  			"source_service_name, "+
   199  			"message_type, "+
   200  			"expiration_time_seconds, "+
   201  			"expiration_time_nanos, "+
   202  			"data_type_url, "+
   203  			"data_value, "+
   204  			"sent, "+
   205  			"allocated, "+
   206  			"message_limit "+
   207  			"FROM broadcasts "+
   208  			"WHERE sent < message_limit "+
   209  			"AND (expiration_time_seconds IS NULL OR (expiration_time_seconds > ?) "+
   210  			"OR (expiration_time_seconds = ? "+
   211  			"AND expiration_time_nanos > ?))",
   212  			now.Seconds, now.Seconds, now.Nanos)
   213  		if err != nil {
   214  			return err
   215  		}
   216  		defer rs.Close()
   217  		for rs.Next() {
   218  			var b dbBroadcast
   219  			if err := rs.Scan(
   220  				&b.broadcastID,
   221  				&b.sourceServiceName,
   222  				&b.messageType,
   223  				&b.expirationTimeSeconds,
   224  				&b.expirationTimeNanos,
   225  				&b.dataTypeURL,
   226  				&b.dataValue,
   227  				&b.sent,
   228  				&b.allocated,
   229  				&b.messageLimit,
   230  			); err != nil {
   231  				return err
   232  			}
   233  			bp, err := toBroadcastProto(&b)
   234  			if err != nil {
   235  				return err
   236  			}
   237  			ret = append(ret, &db.BroadcastInfo{
   238  				Broadcast: bp,
   239  				Sent:      b.sent,
   240  				Limit:     b.messageLimit,
   241  			})
   242  		}
   243  		if err := rs.Err(); err != nil {
   244  			return err
   245  		}
   246  		rs.Close()
   247  		stmt, err := tx.Prepare("SELECT service_name, label FROM broadcast_labels WHERE broadcast_id = ?")
   248  		if err != nil {
   249  			return err
   250  		}
   251  		defer stmt.Close()
   252  		for _, i := range ret {
   253  			id, err := ids.BytesToBroadcastID(i.Broadcast.BroadcastId)
   254  			if err != nil {
   255  				return err
   256  			}
   257  			r, err := stmt.QueryContext(ctx, id.String())
   258  			if err != nil {
   259  				return err
   260  			}
   261  			for r.Next() {
   262  				l := &fspb.Label{}
   263  				if err := r.Scan(&l.ServiceName, &l.Label); err != nil {
   264  					return err
   265  				}
   266  				i.Broadcast.RequiredLabels = append(i.Broadcast.RequiredLabels, l)
   267  			}
   268  			if err := r.Err(); err != nil {
   269  				return err
   270  			}
   271  		}
   272  		return nil
   273  	})
   274  	return ret, err
   275  }
   276  
   277  func (d *Datastore) ListSentBroadcasts(ctx context.Context, id common.ClientID) ([]ids.BroadcastID, error) {
   278  	d.l.Lock()
   279  	defer d.l.Unlock()
   280  	rs, err := d.db.QueryContext(ctx, "SELECT broadcast_id FROM broadcast_sent WHERE client_id = ?", id.String())
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	defer rs.Close()
   285  	var res []ids.BroadcastID
   286  	for rs.Next() {
   287  		var b string
   288  		err = rs.Scan(&b)
   289  		if err != nil {
   290  			return nil, err
   291  		}
   292  		bID, err := ids.StringToBroadcastID(b)
   293  		if err != nil {
   294  			return nil, fmt.Errorf("ListSentBroadcasts: bad broadcast id for client %v: %v", id, b)
   295  		}
   296  		res = append(res, bID)
   297  	}
   298  	if err := rs.Err(); err != nil {
   299  		return nil, err
   300  	}
   301  	return res, nil
   302  }
   303  
   304  func (d *Datastore) CreateAllocation(ctx context.Context, id ids.BroadcastID, frac float32, expiry time.Time) (*db.AllocationInfo, error) {
   305  	d.l.Lock()
   306  	defer d.l.Unlock()
   307  	var ret *db.AllocationInfo
   308  	err := d.runInTx(func(tx *sql.Tx) error {
   309  		ep := tspb.New(expiry)
   310  		if err := ep.CheckValid(); err != nil {
   311  			return err
   312  		}
   313  		aid, err := ids.RandomAllocationID()
   314  		if err != nil {
   315  			return err
   316  		}
   317  
   318  		var b dbBroadcast
   319  		r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", id.String())
   320  		if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil {
   321  			return err
   322  		}
   323  		toAllocate, newAllocated := db.ComputeBroadcastAllocation(b.messageLimit, b.allocated, b.sent, frac)
   324  		if toAllocate == 0 {
   325  			return nil
   326  		}
   327  
   328  		if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET allocated = ? WHERE broadcast_id = ?", newAllocated, id.String()); err != nil {
   329  			return err
   330  		}
   331  		if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_allocations("+
   332  			"broadcast_id, "+
   333  			"allocation_id, "+
   334  			"sent, "+
   335  			"message_limit, "+
   336  			"expiration_time_seconds, "+
   337  			"expiration_time_nanos) "+
   338  			"VALUES (?, ?, 0, ?, ?, ?) ",
   339  			id.String(), aid.String(), toAllocate, ep.Seconds, ep.Nanos); err != nil {
   340  			return err
   341  		}
   342  
   343  		ret = &db.AllocationInfo{
   344  			ID:     aid,
   345  			Limit:  toAllocate,
   346  			Expiry: expiry,
   347  		}
   348  		return nil
   349  	})
   350  	return ret, err
   351  }
   352  
   353  func (d *Datastore) CleanupAllocation(ctx context.Context, bID ids.BroadcastID, aID ids.AllocationID) error {
   354  	d.l.Lock()
   355  	defer d.l.Unlock()
   356  	return d.runInTx(func(tx *sql.Tx) error {
   357  		var b dbBroadcast
   358  		r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", bID.String())
   359  		if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil {
   360  			return err
   361  		}
   362  
   363  		var as, al uint64
   364  		r = tx.QueryRowContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String())
   365  		if err := r.Scan(&as, &al); err != nil {
   366  			return err
   367  		}
   368  		newAllocated, err := db.ComputeBroadcastAllocationCleanup(al, b.allocated)
   369  		if err != nil {
   370  			return fmt.Errorf("unable to clear allocation [%v,%v]: %v", bID.String(), aID.String(), err)
   371  		}
   372  		if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET sent = ?, allocated = ? WHERE broadcast_id = ?", b.sent+as, newAllocated, bID.String()); err != nil {
   373  			return err
   374  		}
   375  		if _, err := tx.ExecContext(ctx, "DELETE from broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String()); err != nil {
   376  			return err
   377  		}
   378  		return nil
   379  	})
   380  }