github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/broadcaststore.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysql 16 17 import ( 18 "context" 19 "database/sql" 20 "encoding/hex" 21 "errors" 22 "fmt" 23 "time" 24 25 log "github.com/golang/glog" 26 "github.com/google/fleetspeak/fleetspeak/src/common" 27 "github.com/google/fleetspeak/fleetspeak/src/server/db" 28 "github.com/google/fleetspeak/fleetspeak/src/server/ids" 29 30 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 31 spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server" 32 anypb "google.golang.org/protobuf/types/known/anypb" 33 tspb "google.golang.org/protobuf/types/known/timestamppb" 34 ) 35 36 // dbBroadcast matches the schema of the broadcasts table. 37 type dbBroadcast struct { 38 broadcastID []byte 39 sourceServiceName string 40 messageType string 41 expirationTimeSeconds sql.NullInt64 42 expirationTimeNanos sql.NullInt64 43 dataTypeURL sql.NullString 44 dataValue []byte 45 sent uint64 46 allocated uint64 47 messageLimit uint64 48 } 49 50 func fromBroadcastProto(b *spb.Broadcast) (*dbBroadcast, error) { 51 if b == nil { 52 return nil, errors.New("cannot convert nil Broadcast") 53 } 54 id, err := ids.BytesToBroadcastID(b.BroadcastId) 55 if err != nil { 56 return nil, err 57 } 58 if b.Source == nil { 59 return nil, fmt.Errorf("Broadcast must have Source. Get: %v", b) 60 } 61 62 res := dbBroadcast{ 63 broadcastID: id.Bytes(), 64 sourceServiceName: b.Source.ServiceName, 65 messageType: b.MessageType, 66 } 67 if b.ExpirationTime != nil { 68 res.expirationTimeSeconds = sql.NullInt64{Int64: b.ExpirationTime.Seconds, Valid: true} 69 res.expirationTimeNanos = sql.NullInt64{Int64: int64(b.ExpirationTime.Nanos), Valid: true} 70 } 71 if b.Data != nil { 72 res.dataTypeURL = sql.NullString{String: b.Data.TypeUrl, Valid: true} 73 res.dataValue = b.Data.Value 74 } 75 return &res, nil 76 } 77 78 func toBroadcastProto(b *dbBroadcast) (*spb.Broadcast, error) { 79 bid, err := ids.BytesToBroadcastID(b.broadcastID) 80 if err != nil { 81 return nil, err 82 } 83 ret := &spb.Broadcast{ 84 BroadcastId: bid.Bytes(), 85 Source: &fspb.Address{ServiceName: b.sourceServiceName}, 86 MessageType: b.messageType, 87 } 88 if b.expirationTimeSeconds.Valid && b.expirationTimeNanos.Valid { 89 ret.ExpirationTime = &tspb.Timestamp{ 90 Seconds: b.expirationTimeSeconds.Int64, 91 Nanos: int32(b.expirationTimeNanos.Int64), 92 } 93 } 94 if b.dataTypeURL.Valid { 95 ret.Data = &anypb.Any{ 96 TypeUrl: b.dataTypeURL.String, 97 Value: b.dataValue, 98 } 99 } 100 return ret, nil 101 } 102 103 func (d *Datastore) CreateBroadcast(ctx context.Context, b *spb.Broadcast, limit uint64) error { 104 dbB, err := fromBroadcastProto(b) 105 if err != nil { 106 return err 107 } 108 dbB.messageLimit = limit 109 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 110 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcasts("+ 111 "broadcast_id, "+ 112 "source_service_name, "+ 113 "message_type, "+ 114 "expiration_time_seconds, "+ 115 "expiration_time_nanos, "+ 116 "data_type_url, "+ 117 "data_value, "+ 118 "sent, "+ 119 "allocated, "+ 120 "message_limit) "+ 121 "VALUES(?, ?, ?, ?, ?, ?, ?, 0, 0, ?)", 122 dbB.broadcastID, 123 dbB.sourceServiceName, 124 dbB.messageType, 125 dbB.expirationTimeSeconds, 126 dbB.expirationTimeNanos, 127 dbB.dataTypeURL, 128 dbB.dataValue, 129 dbB.messageLimit, 130 ); err != nil { 131 return err 132 } 133 for _, l := range b.RequiredLabels { 134 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_labels(broadcast_id, service_name, label) VALUES(?,?,?)", dbB.broadcastID, l.ServiceName, l.Label); err != nil { 135 return err 136 } 137 138 } 139 return nil 140 }) 141 } 142 143 func (d *Datastore) SetBroadcastLimit(ctx context.Context, id ids.BroadcastID, limit uint64) error { 144 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 145 _, err := tx.ExecContext(ctx, "UPDATE broadcasts(message_limit) VALUES(?) WHERE broadcast_id=?", limit, id.Bytes()) 146 return err 147 }) 148 } 149 150 func (d *Datastore) SaveBroadcastMessage(ctx context.Context, msg *fspb.Message, bID ids.BroadcastID, cID common.ClientID, aID ids.AllocationID) error { 151 dbm, err := fromMessageProto(msg) 152 if err != nil { 153 return err 154 } 155 156 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 157 var as, al uint64 158 exp := &tspb.Timestamp{} 159 r := tx.QueryRowContext(ctx, "SELECT sent, message_limit, expiration_time_seconds, expiration_time_nanos FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes()) 160 if err := r.Scan(&as, &al, &exp.Seconds, &exp.Nanos); err != nil { 161 return err 162 } 163 if as >= al { 164 return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is full: Sent: %v Limit: %v", aID, bID, as, al) 165 } 166 if err := exp.CheckValid(); err != nil { 167 return fmt.Errorf("SaveBroadcastMessage: unable to convert expiry to time: %v", err) 168 } 169 et := exp.AsTime() 170 if db.Now().After(et) { 171 return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is expired: %v", aID, bID, et) 172 } 173 174 if err := d.tryStoreMessage(ctx, tx, dbm, true); err != nil { 175 return err 176 } 177 178 if _, err := tx.ExecContext(ctx, "UPDATE broadcast_allocations SET sent = ? WHERE broadcast_id = ? AND allocation_id = ?", as+1, bID.Bytes(), aID.Bytes()); err != nil { 179 return err 180 } 181 _, err = tx.ExecContext(ctx, "INSERT INTO broadcast_sent(broadcast_id, client_id) VALUES (?, ?)", bID.Bytes(), cID.Bytes()) 182 return err 183 }) 184 } 185 186 func (d *Datastore) ListActiveBroadcasts(ctx context.Context) ([]*db.BroadcastInfo, error) { 187 var ret []*db.BroadcastInfo 188 err := d.runInTx(ctx, true, func(tx *sql.Tx) error { 189 ret = nil 190 now := db.NowProto() 191 rs, err := tx.QueryContext(ctx, "SELECT "+ 192 "broadcast_id, "+ 193 "source_service_name, "+ 194 "message_type, "+ 195 "expiration_time_seconds, "+ 196 "expiration_time_nanos, "+ 197 "data_type_url, "+ 198 "data_value, "+ 199 "sent, "+ 200 "allocated, "+ 201 "message_limit "+ 202 "FROM broadcasts "+ 203 "WHERE sent < message_limit "+ 204 "AND (expiration_time_seconds IS NULL OR (expiration_time_seconds > ?) "+ 205 "OR (expiration_time_seconds = ? "+ 206 "AND expiration_time_nanos > ?))", 207 now.Seconds, now.Seconds, now.Nanos) 208 if err != nil { 209 return err 210 } 211 defer rs.Close() 212 for rs.Next() { 213 var b dbBroadcast 214 if err := rs.Scan( 215 &b.broadcastID, 216 &b.sourceServiceName, 217 &b.messageType, 218 &b.expirationTimeSeconds, 219 &b.expirationTimeNanos, 220 &b.dataTypeURL, 221 &b.dataValue, 222 &b.sent, 223 &b.allocated, 224 &b.messageLimit, 225 ); err != nil { 226 return err 227 } 228 bp, err := toBroadcastProto(&b) 229 if err != nil { 230 log.Errorf("Failed to convert read broadcast %+v: %v", b, err) 231 return err 232 } 233 ret = append(ret, &db.BroadcastInfo{ 234 Broadcast: bp, 235 Sent: b.sent, 236 Limit: b.messageLimit, 237 }) 238 } 239 if err := rs.Err(); err != nil { 240 return err 241 } 242 rs.Close() 243 stmt, err := tx.Prepare("SELECT service_name, label FROM broadcast_labels WHERE broadcast_id = ?") 244 if err != nil { 245 return err 246 } 247 defer stmt.Close() 248 for _, i := range ret { 249 id, err := ids.BytesToBroadcastID(i.Broadcast.BroadcastId) 250 if err != nil { 251 return err 252 } 253 r, err := stmt.QueryContext(ctx, id.Bytes()) 254 if err != nil { 255 return err 256 } 257 for r.Next() { 258 l := &fspb.Label{} 259 if err := r.Scan(&l.ServiceName, &l.Label); err != nil { 260 return err 261 } 262 i.Broadcast.RequiredLabels = append(i.Broadcast.RequiredLabels, l) 263 } 264 if err := r.Err(); err != nil { 265 return err 266 } 267 } 268 return nil 269 }) 270 return ret, err 271 } 272 273 func (d *Datastore) ListSentBroadcasts(ctx context.Context, id common.ClientID) ([]ids.BroadcastID, error) { 274 rs, err := d.db.QueryContext(ctx, "SELECT broadcast_id FROM broadcast_sent WHERE client_id = ?", id.Bytes()) 275 if err != nil { 276 return nil, err 277 } 278 defer rs.Close() 279 var res []ids.BroadcastID 280 for rs.Next() { 281 var b []byte 282 err = rs.Scan(&b) 283 if err != nil { 284 return nil, err 285 } 286 bID, err := ids.BytesToBroadcastID(b) 287 if err != nil { 288 return nil, fmt.Errorf("ListSentBroadcasts: bad broadcast id [%s] for client %v: %v", hex.EncodeToString(b), id, err) 289 } 290 res = append(res, bID) 291 } 292 if err := rs.Err(); err != nil { 293 return nil, err 294 } 295 return res, nil 296 } 297 298 func (d *Datastore) CreateAllocation(ctx context.Context, id ids.BroadcastID, frac float32, expiry time.Time) (*db.AllocationInfo, error) { 299 var ret *db.AllocationInfo 300 err := d.runInTx(ctx, false, func(tx *sql.Tx) error { 301 ep := tspb.New(expiry) 302 if err := ep.CheckValid(); err != nil { 303 return err 304 } 305 aid, err := ids.RandomAllocationID() 306 if err != nil { 307 return err 308 } 309 310 var b dbBroadcast 311 r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", id.Bytes()) 312 if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil { 313 return err 314 } 315 toAllocate, newAllocated := db.ComputeBroadcastAllocation(b.messageLimit, b.allocated, b.sent, frac) 316 if toAllocate == 0 { 317 return nil 318 } 319 320 if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET allocated = ? WHERE broadcast_id = ?", newAllocated, id.Bytes()); err != nil { 321 return err 322 } 323 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_allocations("+ 324 "broadcast_id, "+ 325 "allocation_id, "+ 326 "sent, "+ 327 "message_limit, "+ 328 "expiration_time_seconds, "+ 329 "expiration_time_nanos) "+ 330 "VALUES (?, ?, 0, ?, ?, ?) ", 331 id.Bytes(), aid.Bytes(), toAllocate, ep.Seconds, ep.Nanos); err != nil { 332 return err 333 } 334 335 ret = &db.AllocationInfo{ 336 ID: aid, 337 Limit: toAllocate, 338 Expiry: expiry, 339 } 340 return nil 341 }) 342 return ret, err 343 } 344 345 func (d *Datastore) CleanupAllocation(ctx context.Context, bID ids.BroadcastID, aID ids.AllocationID) error { 346 return d.runInTx(ctx, false, func(tx *sql.Tx) error { 347 var b dbBroadcast 348 r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", bID.Bytes()) 349 if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil { 350 return err 351 } 352 353 var as, al uint64 354 r = tx.QueryRowContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes()) 355 if err := r.Scan(&as, &al); err != nil { 356 return err 357 } 358 newAllocated, err := db.ComputeBroadcastAllocationCleanup(al, b.allocated) 359 if err != nil { 360 return fmt.Errorf("unable to clear allocation [%v,%v]: %v", bID.String(), aID.String(), err) 361 } 362 if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET sent = ?, allocated = ? WHERE broadcast_id = ?", b.sent+as, newAllocated, bID.Bytes()); err != nil { 363 return err 364 } 365 if _, err := tx.ExecContext(ctx, "DELETE from broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.Bytes(), aID.Bytes()); err != nil { 366 return err 367 } 368 return nil 369 }) 370 }