github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/broadcaststore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"encoding/hex"
    21  	"errors"
    22  	"fmt"
    23  	"time"
    24  
    25  	log "github.com/golang/glog"
    26  	"github.com/google/fleetspeak/fleetspeak/src/common"
    27  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    28  	"github.com/google/fleetspeak/fleetspeak/src/server/ids"
    29  
    30  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    31  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    32  	anypb "google.golang.org/protobuf/types/known/anypb"
    33  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    34  )
    35  
    36  // dbBroadcast matches the schema of the broadcasts table.
    37  type dbBroadcast struct {
    38  	broadcastID           []byte
    39  	sourceServiceName     string
    40  	messageType           string
    41  	expirationTimeSeconds sql.NullInt64
    42  	expirationTimeNanos   sql.NullInt64
    43  	dataTypeURL           sql.NullString
    44  	dataValue             []byte
    45  	sent                  uint64
    46  	allocated             uint64
    47  	messageLimit          uint64
    48  }
    49  
    50  func fromBroadcastProto(b *spb.Broadcast) (*dbBroadcast, error) {
    51  	if b == nil {
    52  		return nil, errors.New("cannot convert nil Broadcast")
    53  	}
    54  	id, err := ids.BytesToBroadcastID(b.BroadcastId)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	if b.Source == nil {
    59  		return nil, fmt.Errorf("Broadcast must have Source. Get: %v", b)
    60  	}
    61  
    62  	res := dbBroadcast{
    63  		broadcastID:       id.Bytes(),
    64  		sourceServiceName: b.Source.ServiceName,
    65  		messageType:       b.MessageType,
    66  	}
    67  	if b.ExpirationTime != nil {
    68  		res.expirationTimeSeconds = sql.NullInt64{Int64: b.ExpirationTime.Seconds, Valid: true}
    69  		res.expirationTimeNanos = sql.NullInt64{Int64: int64(b.ExpirationTime.Nanos), Valid: true}
    70  	}
    71  	if b.Data != nil {
    72  		res.dataTypeURL = sql.NullString{String: b.Data.TypeUrl, Valid: true}
    73  		res.dataValue = b.Data.Value
    74  	}
    75  	return &res, nil
    76  }
    77  
    78  func toBroadcastProto(b *dbBroadcast) (*spb.Broadcast, error) {
    79  	bid, err := ids.BytesToBroadcastID(b.broadcastID)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	ret := &spb.Broadcast{
    84  		BroadcastId: bid.Bytes(),
    85  		Source:      &fspb.Address{ServiceName: b.sourceServiceName},
    86  		MessageType: b.messageType,
    87  	}
    88  	if b.expirationTimeSeconds.Valid && b.expirationTimeNanos.Valid {
    89  		ret.ExpirationTime = &tspb.Timestamp{
    90  			Seconds: b.expirationTimeSeconds.Int64,
    91  			Nanos:   int32(b.expirationTimeNanos.Int64),
    92  		}
    93  	}
    94  	if b.dataTypeURL.Valid {
    95  		ret.Data = &anypb.Any{
    96  			TypeUrl: b.dataTypeURL.String,
    97  			Value:   b.dataValue,
    98  		}
    99  	}
   100  	return ret, nil
   101  }
   102  
   103  func (d *Datastore) CreateBroadcast(ctx context.Context, b *spb.Broadcast, limit uint64) error {
   104  	dbB, err := fromBroadcastProto(b)
   105  	if err != nil {
   106  		return err
   107  	}
   108  	dbB.messageLimit = limit
   109  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   110  		if _, err := tx.ExecContext(ctx, "INSERT INTO broadcasts("+
   111  			"broadcast_id, "+
   112  			"source_service_name, "+
   113  			"message_type, "+
   114  			"expiration_time_seconds, "+
   115  			"expiration_time_nanos, "+
   116  			"data_type_url, "+
   117  			"data_value, "+
   118  			"sent, "+
   119  			"allocated, "+
   120  			"message_limit) "+
   121  			"VALUES(?, ?, ?, ?, ?, ?, ?, 0, 0, ?)",
   122  			dbB.broadcastID,
   123  			dbB.sourceServiceName,
   124  			dbB.messageType,
   125  			dbB.expirationTimeSeconds,
   126  			dbB.expirationTimeNanos,
   127  			dbB.dataTypeURL,
   128  			dbB.dataValue,
   129  			dbB.messageLimit,
   130  		); err != nil {
   131  			return err
   132  		}
   133  		for _, l := range b.RequiredLabels {
   134  			if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_labels(broadcast_id, service_name, label) VALUES(?,?,?)", dbB.broadcastID, l.ServiceName, l.Label); err != nil {
   135  				return err
   136  			}
   137  
   138  		}
   139  		return nil
   140  	})
   141  }
   142  
   143  func (d *Datastore) SetBroadcastLimit(ctx context.Context, id ids.BroadcastID, limit uint64) error {
   144  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   145  		_, err := tx.ExecContext(ctx, "UPDATE broadcasts(message_limit) VALUES(?) WHERE broadcast_id=?", limit, id.Bytes())
   146  		return err
   147  	})
   148  }
   149  
   150  func (d *Datastore) SaveBroadcastMessage(ctx context.Context, msg *fspb.Message, bID ids.BroadcastID, cID common.ClientID, aID ids.AllocationID) error {
   151  	dbm, err := fromMessageProto(msg)
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   157  		var as, al uint64
   158  		exp := &tspb.Timestamp{}
   159  		r := tx.QueryRowContext(ctx, "SELECT sent, message_limit, expiration_time_seconds, expiration_time_nanos FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes())
   160  		if err := r.Scan(&as, &al, &exp.Seconds, &exp.Nanos); err != nil {
   161  			return err
   162  		}
   163  		if as >= al {
   164  			return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is full: Sent: %v Limit: %v", aID, bID, as, al)
   165  		}
   166  		if err := exp.CheckValid(); err != nil {
   167  			return fmt.Errorf("SaveBroadcastMessage: unable to convert expiry to time: %v", err)
   168  		}
   169  		et := exp.AsTime()
   170  		if db.Now().After(et) {
   171  			return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is expired: %v", aID, bID, et)
   172  		}
   173  
   174  		if err := d.tryStoreMessage(ctx, tx, dbm, true); err != nil {
   175  			return err
   176  		}
   177  
   178  		if _, err := tx.ExecContext(ctx, "UPDATE broadcast_allocations SET sent = ? WHERE broadcast_id = ? AND allocation_id = ?", as+1, bID.Bytes(), aID.Bytes()); err != nil {
   179  			return err
   180  		}
   181  		_, err = tx.ExecContext(ctx, "INSERT INTO broadcast_sent(broadcast_id, client_id) VALUES (?, ?)", bID.Bytes(), cID.Bytes())
   182  		return err
   183  	})
   184  }
   185  
   186  func (d *Datastore) ListActiveBroadcasts(ctx context.Context) ([]*db.BroadcastInfo, error) {
   187  	var ret []*db.BroadcastInfo
   188  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   189  		ret = nil
   190  		now := db.NowProto()
   191  		rs, err := tx.QueryContext(ctx, "SELECT "+
   192  			"broadcast_id, "+
   193  			"source_service_name, "+
   194  			"message_type, "+
   195  			"expiration_time_seconds, "+
   196  			"expiration_time_nanos, "+
   197  			"data_type_url, "+
   198  			"data_value, "+
   199  			"sent, "+
   200  			"allocated, "+
   201  			"message_limit "+
   202  			"FROM broadcasts "+
   203  			"WHERE sent < message_limit "+
   204  			"AND (expiration_time_seconds IS NULL OR (expiration_time_seconds > ?) "+
   205  			"OR (expiration_time_seconds = ? "+
   206  			"AND expiration_time_nanos > ?))",
   207  			now.Seconds, now.Seconds, now.Nanos)
   208  		if err != nil {
   209  			return err
   210  		}
   211  		defer rs.Close()
   212  		for rs.Next() {
   213  			var b dbBroadcast
   214  			if err := rs.Scan(
   215  				&b.broadcastID,
   216  				&b.sourceServiceName,
   217  				&b.messageType,
   218  				&b.expirationTimeSeconds,
   219  				&b.expirationTimeNanos,
   220  				&b.dataTypeURL,
   221  				&b.dataValue,
   222  				&b.sent,
   223  				&b.allocated,
   224  				&b.messageLimit,
   225  			); err != nil {
   226  				return err
   227  			}
   228  			bp, err := toBroadcastProto(&b)
   229  			if err != nil {
   230  				log.Errorf("Failed to convert read broadcast %+v: %v", b, err)
   231  				return err
   232  			}
   233  			ret = append(ret, &db.BroadcastInfo{
   234  				Broadcast: bp,
   235  				Sent:      b.sent,
   236  				Limit:     b.messageLimit,
   237  			})
   238  		}
   239  		if err := rs.Err(); err != nil {
   240  			return err
   241  		}
   242  		rs.Close()
   243  		stmt, err := tx.Prepare("SELECT service_name, label FROM broadcast_labels WHERE broadcast_id = ?")
   244  		if err != nil {
   245  			return err
   246  		}
   247  		defer stmt.Close()
   248  		for _, i := range ret {
   249  			id, err := ids.BytesToBroadcastID(i.Broadcast.BroadcastId)
   250  			if err != nil {
   251  				return err
   252  			}
   253  			r, err := stmt.QueryContext(ctx, id.Bytes())
   254  			if err != nil {
   255  				return err
   256  			}
   257  			for r.Next() {
   258  				l := &fspb.Label{}
   259  				if err := r.Scan(&l.ServiceName, &l.Label); err != nil {
   260  					return err
   261  				}
   262  				i.Broadcast.RequiredLabels = append(i.Broadcast.RequiredLabels, l)
   263  			}
   264  			if err := r.Err(); err != nil {
   265  				return err
   266  			}
   267  		}
   268  		return nil
   269  	})
   270  	return ret, err
   271  }
   272  
   273  func (d *Datastore) ListSentBroadcasts(ctx context.Context, id common.ClientID) ([]ids.BroadcastID, error) {
   274  	rs, err := d.db.QueryContext(ctx, "SELECT broadcast_id FROM broadcast_sent WHERE client_id = ?", id.Bytes())
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  	defer rs.Close()
   279  	var res []ids.BroadcastID
   280  	for rs.Next() {
   281  		var b []byte
   282  		err = rs.Scan(&b)
   283  		if err != nil {
   284  			return nil, err
   285  		}
   286  		bID, err := ids.BytesToBroadcastID(b)
   287  		if err != nil {
   288  			return nil, fmt.Errorf("ListSentBroadcasts: bad broadcast id [%s] for client %v: %v", hex.EncodeToString(b), id, err)
   289  		}
   290  		res = append(res, bID)
   291  	}
   292  	if err := rs.Err(); err != nil {
   293  		return nil, err
   294  	}
   295  	return res, nil
   296  }
   297  
   298  func (d *Datastore) CreateAllocation(ctx context.Context, id ids.BroadcastID, frac float32, expiry time.Time) (*db.AllocationInfo, error) {
   299  	var ret *db.AllocationInfo
   300  	err := d.runInTx(ctx, false, func(tx *sql.Tx) error {
   301  		ep := tspb.New(expiry)
   302  		if err := ep.CheckValid(); err != nil {
   303  			return err
   304  		}
   305  		aid, err := ids.RandomAllocationID()
   306  		if err != nil {
   307  			return err
   308  		}
   309  
   310  		var b dbBroadcast
   311  		r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", id.Bytes())
   312  		if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil {
   313  			return err
   314  		}
   315  		toAllocate, newAllocated := db.ComputeBroadcastAllocation(b.messageLimit, b.allocated, b.sent, frac)
   316  		if toAllocate == 0 {
   317  			return nil
   318  		}
   319  
   320  		if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET allocated = ? WHERE broadcast_id = ?", newAllocated, id.Bytes()); err != nil {
   321  			return err
   322  		}
   323  		if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_allocations("+
   324  			"broadcast_id, "+
   325  			"allocation_id, "+
   326  			"sent, "+
   327  			"message_limit, "+
   328  			"expiration_time_seconds, "+
   329  			"expiration_time_nanos) "+
   330  			"VALUES (?, ?, 0, ?, ?, ?) ",
   331  			id.Bytes(), aid.Bytes(), toAllocate, ep.Seconds, ep.Nanos); err != nil {
   332  			return err
   333  		}
   334  
   335  		ret = &db.AllocationInfo{
   336  			ID:     aid,
   337  			Limit:  toAllocate,
   338  			Expiry: expiry,
   339  		}
   340  		return nil
   341  	})
   342  	return ret, err
   343  }
   344  
   345  func (d *Datastore) CleanupAllocation(ctx context.Context, bID ids.BroadcastID, aID ids.AllocationID) error {
   346  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   347  		var b dbBroadcast
   348  		r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", bID.Bytes())
   349  		if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil {
   350  			return err
   351  		}
   352  
   353  		var as, al uint64
   354  		r = tx.QueryRowContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes())
   355  		if err := r.Scan(&as, &al); err != nil {
   356  			return err
   357  		}
   358  		newAllocated, err := db.ComputeBroadcastAllocationCleanup(al, b.allocated)
   359  		if err != nil {
   360  			return fmt.Errorf("unable to clear allocation [%v,%v]: %v", bID.String(), aID.String(), err)
   361  		}
   362  		if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET sent = ?, allocated = ? WHERE broadcast_id = ?", b.sent+as, newAllocated, bID.Bytes()); err != nil {
   363  			return err
   364  		}
   365  		if _, err := tx.ExecContext(ctx, "DELETE from broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes()); err != nil {
   366  			return err
   367  		}
   368  		return nil
   369  	})
   370  }