github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/sqlite/broadcaststore.go (about) 1 // Copyright 2017 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqlite 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "fmt" 22 "time" 23 24 "github.com/google/fleetspeak/fleetspeak/src/common" 25 "github.com/google/fleetspeak/fleetspeak/src/server/db" 26 "github.com/google/fleetspeak/fleetspeak/src/server/ids" 27 28 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 29 spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server" 30 anypb "google.golang.org/protobuf/types/known/anypb" 31 tspb "google.golang.org/protobuf/types/known/timestamppb" 32 ) 33 34 // dbBroadcast matches the schema of the broadcasts table. 35 type dbBroadcast struct { 36 broadcastID string 37 sourceServiceName string 38 messageType string 39 expirationTimeSeconds sql.NullInt64 40 expirationTimeNanos sql.NullInt64 41 dataTypeURL sql.NullString 42 dataValue []byte 43 sent uint64 44 allocated uint64 45 messageLimit uint64 46 } 47 48 func fromBroadcastProto(b *spb.Broadcast) (*dbBroadcast, error) { 49 if b == nil { 50 return nil, errors.New("cannot convert nil Broadcast") 51 } 52 id, err := ids.BytesToBroadcastID(b.BroadcastId) 53 if err != nil { 54 return nil, err 55 } 56 if b.Source == nil { 57 return nil, fmt.Errorf("Broadcast must have Source. Get: %v", b) 58 } 59 60 res := dbBroadcast{ 61 broadcastID: id.String(), 62 sourceServiceName: b.Source.ServiceName, 63 messageType: b.MessageType, 64 } 65 if b.ExpirationTime != nil { 66 res.expirationTimeSeconds = sql.NullInt64{Int64: b.ExpirationTime.Seconds, Valid: true} 67 res.expirationTimeNanos = sql.NullInt64{Int64: int64(b.ExpirationTime.Nanos), Valid: true} 68 } 69 if b.Data != nil { 70 res.dataTypeURL = sql.NullString{String: b.Data.TypeUrl, Valid: true} 71 res.dataValue = b.Data.Value 72 } 73 return &res, nil 74 } 75 76 func toBroadcastProto(b *dbBroadcast) (*spb.Broadcast, error) { 77 bid, err := ids.StringToBroadcastID(b.broadcastID) 78 if err != nil { 79 return nil, err 80 } 81 ret := &spb.Broadcast{ 82 BroadcastId: bid.Bytes(), 83 Source: &fspb.Address{ServiceName: b.sourceServiceName}, 84 MessageType: b.messageType, 85 } 86 if b.expirationTimeSeconds.Valid && b.expirationTimeNanos.Valid { 87 ret.ExpirationTime = &tspb.Timestamp{ 88 Seconds: b.expirationTimeSeconds.Int64, 89 Nanos: int32(b.expirationTimeNanos.Int64), 90 } 91 } 92 if b.dataTypeURL.Valid { 93 ret.Data = &anypb.Any{ 94 TypeUrl: b.dataTypeURL.String, 95 Value: b.dataValue, 96 } 97 } 98 return ret, nil 99 } 100 101 func (d *Datastore) CreateBroadcast(ctx context.Context, b *spb.Broadcast, limit uint64) error { 102 d.l.Lock() 103 defer d.l.Unlock() 104 dbB, err := fromBroadcastProto(b) 105 if err != nil { 106 return err 107 } 108 dbB.messageLimit = limit 109 return d.runInTx(func(tx *sql.Tx) error { 110 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcasts("+ 111 "broadcast_id, "+ 112 "source_service_name, "+ 113 "message_type, "+ 114 "expiration_time_seconds, "+ 115 "expiration_time_nanos, "+ 116 "data_type_url, "+ 117 "data_value, "+ 118 "sent, "+ 119 "allocated, "+ 120 "message_limit) "+ 121 "VALUES(?, ?, ?, ?, ?, ?, ?, 0, 0, ?)", 122 dbB.broadcastID, 123 dbB.sourceServiceName, 124 dbB.messageType, 125 dbB.expirationTimeSeconds, 126 dbB.expirationTimeNanos, 127 dbB.dataTypeURL, 128 dbB.dataValue, 129 dbB.messageLimit, 130 ); err != nil { 131 return err 132 } 133 for _, l := range b.RequiredLabels { 134 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_labels(broadcast_id, service_name, label) VALUES(?,?,?)", dbB.broadcastID, l.ServiceName, l.Label); err != nil { 135 return err 136 } 137 138 } 139 return nil 140 }) 141 } 142 143 func (d *Datastore) SetBroadcastLimit(ctx context.Context, id ids.BroadcastID, limit uint64) error { 144 d.l.Lock() 145 defer d.l.Unlock() 146 return d.runInTx(func(tx *sql.Tx) error { 147 _, err := tx.ExecContext(ctx, "UPDATE broadcasts(message_limit) VALUES(?) WHERE broadcast_id=?", limit, id.String()) 148 return err 149 }) 150 } 151 152 func (d *Datastore) SaveBroadcastMessage(ctx context.Context, msg *fspb.Message, bID ids.BroadcastID, cID common.ClientID, aID ids.AllocationID) error { 153 d.l.Lock() 154 defer d.l.Unlock() 155 dbm, err := fromMessageProto(msg) 156 if err != nil { 157 return err 158 } 159 160 return d.runInTx(func(tx *sql.Tx) error { 161 var as, al uint64 162 exp := &tspb.Timestamp{} 163 r := tx.QueryRowContext(ctx, "SELECT sent, message_limit, expiration_time_seconds, expiration_time_nanos FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String()) 164 if err := r.Scan(&as, &al, &exp.Seconds, &exp.Nanos); err != nil { 165 return err 166 } 167 if as >= al { 168 return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is full: Sent: %v Limit: %v", aID, bID, as, al) 169 } 170 if err := exp.CheckValid(); err != nil { 171 return fmt.Errorf("SaveBroadcastMessage: unable to convert expiry to time: %v", err) 172 } 173 et := exp.AsTime() 174 if db.Now().After(et) { 175 return fmt.Errorf("SaveBroadcastMessage: broadcast allocation [%v, %v] is expired: %v", aID, bID, et) 176 } 177 178 if err := d.tryStoreMessage(ctx, tx, dbm, true); err != nil { 179 return err 180 } 181 182 if _, err := tx.ExecContext(ctx, "UPDATE broadcast_allocations SET sent = ? WHERE broadcast_id = ? AND allocation_id = ?", as+1, bID.String(), aID.String()); err != nil { 183 return err 184 } 185 _, err = tx.ExecContext(ctx, "INSERT INTO broadcast_sent(broadcast_id, client_id) VALUES (?, ?)", bID.String(), cID.String()) 186 return err 187 }) 188 } 189 190 func (d *Datastore) ListActiveBroadcasts(ctx context.Context) ([]*db.BroadcastInfo, error) { 191 d.l.Lock() 192 defer d.l.Unlock() 193 var ret []*db.BroadcastInfo 194 err := d.runInTx(func(tx *sql.Tx) error { 195 now := db.NowProto() 196 rs, err := tx.QueryContext(ctx, "SELECT "+ 197 "broadcast_id, "+ 198 "source_service_name, "+ 199 "message_type, "+ 200 "expiration_time_seconds, "+ 201 "expiration_time_nanos, "+ 202 "data_type_url, "+ 203 "data_value, "+ 204 "sent, "+ 205 "allocated, "+ 206 "message_limit "+ 207 "FROM broadcasts "+ 208 "WHERE sent < message_limit "+ 209 "AND (expiration_time_seconds IS NULL OR (expiration_time_seconds > ?) "+ 210 "OR (expiration_time_seconds = ? "+ 211 "AND expiration_time_nanos > ?))", 212 now.Seconds, now.Seconds, now.Nanos) 213 if err != nil { 214 return err 215 } 216 defer rs.Close() 217 for rs.Next() { 218 var b dbBroadcast 219 if err := rs.Scan( 220 &b.broadcastID, 221 &b.sourceServiceName, 222 &b.messageType, 223 &b.expirationTimeSeconds, 224 &b.expirationTimeNanos, 225 &b.dataTypeURL, 226 &b.dataValue, 227 &b.sent, 228 &b.allocated, 229 &b.messageLimit, 230 ); err != nil { 231 return err 232 } 233 bp, err := toBroadcastProto(&b) 234 if err != nil { 235 return err 236 } 237 ret = append(ret, &db.BroadcastInfo{ 238 Broadcast: bp, 239 Sent: b.sent, 240 Limit: b.messageLimit, 241 }) 242 } 243 if err := rs.Err(); err != nil { 244 return err 245 } 246 rs.Close() 247 stmt, err := tx.Prepare("SELECT service_name, label FROM broadcast_labels WHERE broadcast_id = ?") 248 if err != nil { 249 return err 250 } 251 defer stmt.Close() 252 for _, i := range ret { 253 id, err := ids.BytesToBroadcastID(i.Broadcast.BroadcastId) 254 if err != nil { 255 return err 256 } 257 r, err := stmt.QueryContext(ctx, id.String()) 258 if err != nil { 259 return err 260 } 261 for r.Next() { 262 l := &fspb.Label{} 263 if err := r.Scan(&l.ServiceName, &l.Label); err != nil { 264 return err 265 } 266 i.Broadcast.RequiredLabels = append(i.Broadcast.RequiredLabels, l) 267 } 268 if err := r.Err(); err != nil { 269 return err 270 } 271 } 272 return nil 273 }) 274 return ret, err 275 } 276 277 func (d *Datastore) ListSentBroadcasts(ctx context.Context, id common.ClientID) ([]ids.BroadcastID, error) { 278 d.l.Lock() 279 defer d.l.Unlock() 280 rs, err := d.db.QueryContext(ctx, "SELECT broadcast_id FROM broadcast_sent WHERE client_id = ?", id.String()) 281 if err != nil { 282 return nil, err 283 } 284 defer rs.Close() 285 var res []ids.BroadcastID 286 for rs.Next() { 287 var b string 288 err = rs.Scan(&b) 289 if err != nil { 290 return nil, err 291 } 292 bID, err := ids.StringToBroadcastID(b) 293 if err != nil { 294 return nil, fmt.Errorf("ListSentBroadcasts: bad broadcast id for client %v: %v", id, b) 295 } 296 res = append(res, bID) 297 } 298 if err := rs.Err(); err != nil { 299 return nil, err 300 } 301 return res, nil 302 } 303 304 func (d *Datastore) CreateAllocation(ctx context.Context, id ids.BroadcastID, frac float32, expiry time.Time) (*db.AllocationInfo, error) { 305 d.l.Lock() 306 defer d.l.Unlock() 307 var ret *db.AllocationInfo 308 err := d.runInTx(func(tx *sql.Tx) error { 309 ep := tspb.New(expiry) 310 if err := ep.CheckValid(); err != nil { 311 return err 312 } 313 aid, err := ids.RandomAllocationID() 314 if err != nil { 315 return err 316 } 317 318 var b dbBroadcast 319 r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", id.String()) 320 if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil { 321 return err 322 } 323 toAllocate, newAllocated := db.ComputeBroadcastAllocation(b.messageLimit, b.allocated, b.sent, frac) 324 if toAllocate == 0 { 325 return nil 326 } 327 328 if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET allocated = ? WHERE broadcast_id = ?", newAllocated, id.String()); err != nil { 329 return err 330 } 331 if _, err := tx.ExecContext(ctx, "INSERT INTO broadcast_allocations("+ 332 "broadcast_id, "+ 333 "allocation_id, "+ 334 "sent, "+ 335 "message_limit, "+ 336 "expiration_time_seconds, "+ 337 "expiration_time_nanos) "+ 338 "VALUES (?, ?, 0, ?, ?, ?) ", 339 id.String(), aid.String(), toAllocate, ep.Seconds, ep.Nanos); err != nil { 340 return err 341 } 342 343 ret = &db.AllocationInfo{ 344 ID: aid, 345 Limit: toAllocate, 346 Expiry: expiry, 347 } 348 return nil 349 }) 350 return ret, err 351 } 352 353 func (d *Datastore) CleanupAllocation(ctx context.Context, bID ids.BroadcastID, aID ids.AllocationID) error { 354 d.l.Lock() 355 defer d.l.Unlock() 356 return d.runInTx(func(tx *sql.Tx) error { 357 var b dbBroadcast 358 r := tx.QueryRowContext(ctx, "SELECT sent, allocated, message_limit FROM broadcasts WHERE broadcast_id = ?", bID.String()) 359 if err := r.Scan(&b.sent, &b.allocated, &b.messageLimit); err != nil { 360 return err 361 } 362 363 var as, al uint64 364 r = tx.QueryRowContext(ctx, "SELECT sent, message_limit FROM broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String()) 365 if err := r.Scan(&as, &al); err != nil { 366 return err 367 } 368 newAllocated, err := db.ComputeBroadcastAllocationCleanup(al, b.allocated) 369 if err != nil { 370 return fmt.Errorf("unable to clear allocation [%v,%v]: %v", bID.String(), aID.String(), err) 371 } 372 if _, err := tx.ExecContext(ctx, "UPDATE broadcasts SET sent = ?, allocated = ? WHERE broadcast_id = ?", b.sent+as, newAllocated, bID.String()); err != nil { 373 return err 374 } 375 if _, err := tx.ExecContext(ctx, "DELETE from broadcast_allocations WHERE broadcast_id = ? AND allocation_id = ?", bID.String(), aID.String()); err != nil { 376 return err 377 } 378 return nil 379 }) 380 }