github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/https/streaming_message_server.go (about) 1 // Copyright 2018 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package https 16 17 import ( 18 "bufio" 19 "context" 20 "crypto" 21 "encoding/binary" 22 "errors" 23 "fmt" 24 "io" 25 "math" 26 "math/rand" 27 "net" 28 "net/http" 29 "sync" 30 "time" 31 32 log "github.com/golang/glog" 33 "github.com/google/fleetspeak/fleetspeak/src/common" 34 "github.com/google/fleetspeak/fleetspeak/src/server/comms" 35 "github.com/google/fleetspeak/fleetspeak/src/server/db" 36 "github.com/google/fleetspeak/fleetspeak/src/server/stats" 37 "google.golang.org/protobuf/proto" 38 39 fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak" 40 ) 41 42 const magic = uint32(0xf1ee1001) 43 const baseErrorDelay = float64(100 * time.Millisecond) 44 45 type fullResponseWriter interface { 46 http.ResponseWriter 47 http.CloseNotifier 48 http.Flusher 49 } 50 51 func readUint32(body *bufio.Reader) (uint32, error) { 52 b := make([]byte, 4) 53 if _, err := io.ReadAtLeast(body, b, 4); err != nil { 54 return 0, err 55 } 56 return binary.LittleEndian.Uint32(b), nil 57 } 58 59 func writeUint32(res fullResponseWriter, i uint32) error { 60 return binary.Write(res, binary.LittleEndian, i) 61 } 62 63 func newStreamingMessageServer(c *Communicator, maxPerClientBatchProcessors uint32) *streamingMessageServer { 64 return &streamingMessageServer{c, maxPerClientBatchProcessors} 65 } 66 67 // messageServer wraps a Communicator in order to handle clients polls. 68 type streamingMessageServer struct { 69 *Communicator 70 maxPerClientBatchProcessors uint32 71 } 72 73 func (s *streamingMessageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) { 74 earlyError := func(msg string, status int) { 75 log.ErrorDepth(1, fmt.Sprintf("%s: %s", http.StatusText(status), msg)) 76 s.fs.StatsCollector().ClientPoll(stats.PollInfo{ 77 CTX: req.Context(), 78 Start: db.Now(), 79 End: db.Now(), 80 Status: status, 81 Type: stats.StreamStart, 82 }) 83 } 84 85 if !s.startProcessing() { 86 earlyError("server not ready", http.StatusInternalServerError) 87 return 88 } 89 defer s.stopProcessing() 90 91 fullRes, ok := res.(fullResponseWriter) 92 if !ok { 93 earlyError("/streaming-message requested, but not supported. ResponseWriter is not a fullResponseWriter", http.StatusNotFound) 94 return 95 } 96 97 if req.Method != http.MethodPost { 98 earlyError(fmt.Sprintf("%v not supported", req.Method), http.StatusBadRequest) 99 return 100 } 101 102 cert, err := GetClientCert(req, s.p.FrontendConfig) 103 104 if err != nil { 105 earlyError(err.Error(), http.StatusBadRequest) 106 return 107 } 108 109 if cert.PublicKey == nil { 110 earlyError("public key not present in client cert", http.StatusBadRequest) 111 return 112 } 113 114 body := bufio.NewReader(req.Body) 115 116 // Set a 9-11 minute overall maximum lifespan of the connection. 117 ctx, fin := context.WithTimeout(req.Context(), s.p.StreamingLifespan+time.Duration(float32(s.p.StreamingJitter)*rand.Float32())) 118 defer fin() 119 120 // Also create a way to terminate early in case of error. 121 ctx, cancel := context.WithCancel(ctx) 122 defer cancel() 123 124 addr := addrFromString(req.RemoteAddr) 125 info, moreMsgs, err := s.initialPoll(ctx, addr, cert.PublicKey, fullRes, body) 126 if err != nil || info == nil { 127 return 128 } 129 130 m := streamManager{ 131 ctx: ctx, 132 s: s, 133 info: info, 134 res: fullRes, 135 body: body, 136 137 localNotices: make(chan struct{}, 1), 138 out: make(chan *fspb.ContactData, 5), 139 140 cancel: cancel, 141 } 142 defer func() { 143 // Shutdown is a bit subtle. 144 // 145 // We get here iff m.ctx is canceled, timed out, etc. 146 // 147 // Once ctx is canceled, writeLoop will notice, close the outgoing 148 // ResponseWriter, and begin blindly draining m.out. 149 // 150 // Closing ResponseWriter will cause any pending read to error out and 151 // readLoop to return. 152 // 153 // Once the readLoop returns, we can safely close m.out and wait for 154 // writeLoop to finish. 155 info.Fin() 156 m.reading.Wait() 157 close(m.out) 158 m.writing.Wait() 159 }() 160 161 m.reading.Add(2) 162 go m.readLoop() 163 go m.notifyLoop(s.p.StreamingCloseTime, moreMsgs) 164 165 m.writing.Add(1) 166 go m.writeLoop() 167 168 select { 169 case <-ctx.Done(): 170 case <-fullRes.CloseNotify(): 171 case <-s.stopping: 172 } 173 m.cancel() 174 } 175 176 func (s *streamingMessageServer) initialPoll(ctx context.Context, addr net.Addr, key crypto.PublicKey, res fullResponseWriter, body *bufio.Reader) (*comms.ConnectionInfo, bool, error) { 177 ctx, fin := context.WithTimeout(ctx, 3*time.Minute) 178 179 pi := stats.PollInfo{ 180 CTX: ctx, 181 Start: db.Now(), 182 Status: http.StatusTeapot, // Should never actually be returned 183 Type: stats.StreamStart, 184 } 185 defer func() { 186 fin() 187 if pi.Status == http.StatusTeapot { 188 log.Errorf("Forgot to set status, PollInfo: %v", pi) 189 } 190 pi.End = db.Now() 191 s.fs.StatsCollector().ClientPoll(pi) 192 }() 193 194 makeError := func(msg string, status int) error { 195 log.ErrorDepth(1, fmt.Sprintf("%s: [id:%v addr:%v] %s", http.StatusText(status), pi.ID, addr, msg)) 196 pi.Status = status 197 return errors.New(msg) 198 } 199 200 id, err := common.MakeClientID(key) 201 if err != nil { 202 return nil, false, makeError(fmt.Sprintf("unable to create client id from public key: %v", err), http.StatusBadRequest) 203 } 204 pi.ID = id 205 206 m, err := readUint32(body) 207 if err != nil { 208 return nil, false, makeError(fmt.Sprintf("error reading magic number: %v", err), http.StatusBadRequest) 209 } 210 if m != magic { 211 return nil, false, makeError(fmt.Sprintf("unknown magic number: got %x, expected %x", m, magic), http.StatusBadRequest) 212 } 213 214 st := time.Now() 215 size, err := binary.ReadUvarint(body) 216 if err != nil { 217 return nil, false, makeError(fmt.Sprintf("error reading size: %v", err), http.StatusBadRequest) 218 } 219 if size > MaxContactSize { 220 return nil, false, makeError(fmt.Sprintf("initial contact size too large: got %d, expected at most %d", size, MaxContactSize), http.StatusBadRequest) 221 } 222 223 buf := make([]byte, size) 224 _, err = io.ReadFull(body, buf) 225 if err != nil { 226 return nil, false, makeError(fmt.Sprintf("error reading body for initial exchange: %v", err), http.StatusBadRequest) 227 } 228 pi.ReadTime = time.Since(st) 229 pi.ReadBytes = int(size) 230 231 var wcd fspb.WrappedContactData 232 if err := proto.Unmarshal(buf, &wcd); err != nil { 233 return nil, false, makeError(fmt.Sprintf("error parsing body: %v", err), http.StatusBadRequest) 234 } 235 236 info, toSend, more, err := s.fs.InitializeConnection(ctx, addr, key, &wcd, true) 237 if err == comms.ErrNotAuthorized { 238 return nil, false, makeError("not authorized", http.StatusServiceUnavailable) 239 } 240 if err != nil { 241 return nil, false, makeError(fmt.Sprintf("error processing contact: %v", err), http.StatusInternalServerError) 242 } 243 pi.CacheHit = info.Client.Cached 244 245 outBuf, err := proto.Marshal(toSend) 246 if err != nil { 247 info.Fin() 248 return nil, false, makeError(fmt.Sprintf("error preparing messages: %v", err), http.StatusInternalServerError) 249 } 250 sizeBuf := make([]byte, 0, 16) 251 sizeBuf = binary.AppendUvarint(sizeBuf, uint64(len(outBuf))) 252 253 st = time.Now() 254 sizeWritten, err := res.Write(sizeBuf) 255 if err != nil { 256 info.Fin() 257 return nil, false, makeError(fmt.Sprintf("error writing body: %v", err), http.StatusInternalServerError) 258 } 259 bufWritten, err := res.Write(outBuf) 260 if err != nil { 261 info.Fin() 262 return nil, false, makeError(fmt.Sprintf("error writing body: %v", err), http.StatusInternalServerError) 263 } 264 res.Flush() 265 266 pi.WriteTime = time.Since(st) 267 pi.End = time.Now() 268 pi.WriteBytes = sizeWritten + bufWritten 269 pi.Status = http.StatusOK 270 return info, more, nil 271 } 272 273 type streamManager struct { 274 ctx context.Context 275 s *streamingMessageServer 276 277 info *comms.ConnectionInfo 278 res fullResponseWriter 279 body *bufio.Reader 280 281 // Signals that a we have more tokens and might retry sending. 282 localNotices chan struct{} 283 284 // The read- and writeLoop will wait for these. Separate because readloop 285 // needs to finish before writeLoop. 286 reading sync.WaitGroup 287 writing sync.WaitGroup 288 289 out chan *fspb.ContactData 290 291 cancel func() // Shuts down the stream when called. 292 } 293 294 func (m *streamManager) readLoop() { 295 defer m.reading.Done() 296 defer m.cancel() 297 298 cnt := uint64(0) 299 300 // Number of batches from the same client that will be processed concurrently. 301 const maxBatchProcessors = 10 302 batchCh := make(chan *fspb.WrappedContactData, m.s.maxPerClientBatchProcessors) 303 304 for { 305 pi, wcd, err := m.readOne() 306 if err != nil { 307 // If the context has been canceled, it is probably a 'normal' termination 308 // - disconnect, max connection durating, etc. But if it is still active, 309 // we are going to tear down everything because of an unexpected read 310 // error and should log/record why. 311 if m.ctx.Err() == nil && pi != nil { 312 m.s.fs.StatsCollector().ClientPoll(*pi) 313 log.Errorf("Streaming Connection to %v terminated with error: %v", m.info.Client.ID, err) 314 } 315 return 316 } 317 318 // Increment the counter with every processed message. 319 cnt++ 320 321 // This will block if number of concurrent processors is greater than maxBatchProcessors. 322 batchCh <- wcd 323 // Ensure the m.out stays open while the message processing is not done. 324 m.reading.Add(1) 325 // Given that the processing is done concurrently, capture the current counter value in 326 // the function argument. 327 go func(curCnt uint64) { 328 defer m.reading.Done() 329 330 wcd := <-batchCh 331 if err := m.processOne(wcd); err != nil { 332 log.Errorf("Error processing message from %v: %v", m.info.Client.ID, err) 333 return 334 } 335 m.out <- &fspb.ContactData{AckIndex: curCnt} 336 }(cnt) 337 338 m.s.fs.StatsCollector().ClientPoll(*pi) 339 } 340 } 341 342 func (m *streamManager) readOne() (*stats.PollInfo, *fspb.WrappedContactData, error) { 343 size, err := binary.ReadUvarint(m.body) 344 if err != nil { 345 return nil, nil, err 346 } 347 if size > MaxContactSize { 348 return nil, nil, fmt.Errorf("streaming contact size too large: got %d, expected at most %d", size, MaxContactSize) 349 } 350 351 pi := &stats.PollInfo{ 352 CTX: m.ctx, 353 ID: m.info.Client.ID, 354 Start: db.Now(), 355 Status: http.StatusTeapot, 356 CacheHit: true, 357 Type: stats.StreamFromClient, 358 } 359 defer func() { 360 if pi.Status == http.StatusTeapot { 361 log.Errorf("Forgot to set status.") 362 } 363 pi.End = db.Now() 364 }() 365 buf := make([]byte, size) 366 if _, err := io.ReadFull(m.body, buf); err != nil { 367 pi.Status = http.StatusBadRequest 368 return pi, nil, fmt.Errorf("error reading streamed data: %v", err) 369 } 370 pi.ReadTime = time.Since(pi.Start) 371 pi.ReadBytes = int(size) 372 373 wcd := &fspb.WrappedContactData{} 374 if err = proto.Unmarshal(buf, wcd); err != nil { 375 pi.Status = http.StatusBadRequest 376 return pi, nil, fmt.Errorf("error parsing streamed data: %v", err) 377 } 378 379 // Validate message early to provide feedback to the agent and fail with a 380 // descriptive HTTP code. 381 _, err = m.s.fs.ValidateMessagesFromClient(context.Background(), m.info, wcd) 382 if err != nil { 383 pi.Status = http.StatusServiceUnavailable 384 return pi, nil, fmt.Errorf("message validation failed: %v", err) 385 } 386 387 pi.Status = http.StatusOK 388 return pi, wcd, nil 389 } 390 391 func (m *streamManager) processOne(wcd *fspb.WrappedContactData) error { 392 var blockedServices []string 393 for k, v := range m.info.MessageTokens() { 394 if v == 0 { 395 blockedServices = append(blockedServices, k) 396 } 397 } 398 // We might be close to the connection's natural end. Accept up to 15 399 // seconds of overrun trying to process what we've been given. This 400 // should only happen when things are unexpectedly slow and likely 401 // causes duplicate messages. 402 ctx, fin := context.WithCancel(context.Background()) 403 go func() { 404 defer fin() 405 select { 406 case <-ctx.Done(): 407 return 408 case <-m.ctx.Done(): 409 log.Warningf("Extra time required while processing message from %v.", m.info.Client.ID) 410 t := time.NewTimer(15 * time.Second) 411 defer t.Stop() 412 select { 413 case <-ctx.Done(): 414 return 415 case <-t.C: 416 return 417 } 418 } 419 }() 420 err := m.s.fs.HandleMessagesFromClient(ctx, m.info, wcd) 421 fin() 422 if err != nil { 423 if err == comms.ErrNotAuthorized { 424 log.Infof("Message not authoried: %v", err) 425 } else { 426 err = fmt.Errorf("error processing streamed messages: %v", err) 427 } 428 return err 429 } 430 tokens := m.info.MessageTokens() 431 for _, s := range blockedServices { 432 if tokens[s] > 0 { 433 select { 434 case m.localNotices <- struct{}{}: 435 default: 436 } 437 } 438 } 439 return nil 440 } 441 442 func (m *streamManager) notifyLoop(closeTime time.Duration, moreMsgs bool) { 443 defer m.reading.Done() 444 445 // Stop sending messages to the client closeTime (e.g. 30 sec) before our hard deadline. 446 d, ok := m.ctx.Deadline() 447 if !ok { 448 // Shouldn't happen, ctx is created with a deadline. 449 log.Fatalf("m.ctx does not have a deadline set") 450 } 451 deadline := d.Add(-closeTime) 452 stop := time.NewTimer(time.Until(deadline)) 453 defer stop.Stop() 454 455 // Number of sequential errors getting messages for the client. 456 var errCnt int 457 458 for { 459 // This switch decides how long we should wait before trying to 460 // get more messages for the client, and returns when it is time 461 // to stop. 462 switch { 463 case errCnt > 0: 464 // Last attempt to get messages failed - try again with 465 // a jittery exponential backoff in the hopes that the 466 // database recovers. 467 errDelay := time.Duration((baseErrorDelay + rand.Float64()*baseErrorDelay) * math.Pow(1.5, float64(errCnt))) 468 t := time.NewTimer(errDelay) 469 log.V(1).Infof("NotifyLoop(%v): waiting %v due to previous error.", m.info.Client.ID, errDelay) 470 select { 471 case <-m.ctx.Done(): 472 t.Stop() 473 return 474 case <-stop.C: 475 t.Stop() 476 m.out <- &fspb.ContactData{DoneSending: true} 477 return 478 case <-t.C: 479 } 480 case moreMsgs: 481 // We believe that there are more messages already 482 // available, just check if it is time to shutdown. 483 log.V(1).Infof("NotifyLoop(%v): continuing, more messages possible.", m.info.Client.ID) 484 if time.Now().After(deadline) { 485 m.out <- &fspb.ContactData{DoneSending: true} 486 return 487 } 488 if m.ctx.Err() != nil { 489 return 490 } 491 default: 492 // Wait for a notification, then wait 1 more second in 493 // case more messages arrive. 494 log.V(1).Infof("NotifyLoop(%v): waiting for notifications.", m.info.Client.ID) 495 select { 496 case <-m.ctx.Done(): 497 return 498 case <-stop.C: 499 m.out <- &fspb.ContactData{DoneSending: true} 500 return 501 case _, ok := <-m.info.Notices: 502 if !ok { 503 return 504 } 505 case <-m.localNotices: 506 } 507 t := time.NewTimer(time.Second) 508 L: 509 for { 510 select { 511 case <-m.ctx.Done(): 512 return 513 case _, ok := <-m.info.Notices: 514 if !ok { 515 break L 516 } 517 continue L 518 case <-t.C: 519 break L 520 } 521 } 522 t.Stop() 523 } 524 var cd *fspb.ContactData 525 var err error 526 cd, moreMsgs, err = m.s.fs.GetMessagesForClient(m.ctx, m.info) 527 if err != nil { 528 if err == m.ctx.Err() { 529 return 530 } 531 log.Errorf("Error getting messages for streaming client [%v]: %v", m.info.Client.ID, err) 532 errCnt++ 533 } else { 534 errCnt = 0 535 } 536 if cd != nil { 537 m.out <- cd 538 } 539 } 540 } 541 542 func (m *streamManager) writeLoop() { 543 defer m.writing.Done() 544 defer func() { 545 for range m.out { 546 } 547 }() 548 549 for { 550 select { 551 case cd, ok := <-m.out: 552 if !ok { 553 return 554 } 555 pi, err := m.writeOne(cd) 556 if err != nil { 557 if m.ctx.Err() != nil { 558 log.Errorf("Error sending ContactData to client [%v]: %v", m.info.Client.ID, err) 559 m.cancel() 560 m.s.fs.StatsCollector().ClientPoll(pi) 561 } 562 // ctx was already canceled - more or less normal shutdown, so don't log 563 // as a poll. 564 return 565 } 566 if len(cd.Messages) > 0 { 567 m.s.fs.StatsCollector().ClientPoll(pi) 568 } 569 case <-m.ctx.Done(): 570 return 571 } 572 } 573 } 574 575 func (m *streamManager) writeOne(cd *fspb.ContactData) (stats.PollInfo, error) { 576 pi := stats.PollInfo{ 577 CTX: m.ctx, 578 ID: m.info.Client.ID, 579 Start: db.Now(), 580 Status: http.StatusTeapot, 581 CacheHit: true, 582 Type: stats.StreamToClient, 583 } 584 defer func() { 585 if pi.Status == http.StatusTeapot { 586 log.Errorf("Forgot to set status.") 587 } 588 pi.End = db.Now() 589 }() 590 591 buf, err := proto.Marshal(cd) 592 if err != nil { 593 return pi, err 594 } 595 sizeBuf := make([]byte, 0, 16) 596 sizeBuf = binary.AppendUvarint(sizeBuf, uint64(len(buf))) 597 598 sw := time.Now() 599 sizeWritten, err := m.res.Write(sizeBuf) 600 if err != nil { 601 return pi, err 602 } 603 bufWritten, err := m.res.Write(buf) 604 if err != nil { 605 return pi, err 606 } 607 m.res.Flush() 608 pi.WriteTime = time.Since(sw) 609 pi.WriteBytes = sizeWritten + bufWritten 610 pi.Status = http.StatusOK 611 612 return pi, nil 613 }