github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/auditstreamer.go (about)

     1  /*
     2  Copyright 2020 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"sync"
    22  
    23  	"github.com/gravitational/trace"
    24  	"google.golang.org/grpc"
    25  	ggzip "google.golang.org/grpc/encoding/gzip"
    26  
    27  	"github.com/gravitational/teleport/api/client/proto"
    28  	"github.com/gravitational/teleport/api/types/events"
    29  )
    30  
    31  // createOrResumeAuditStream creates or resumes audit stream described in the request.
    32  func (c *Client) createOrResumeAuditStream(ctx context.Context, request proto.AuditStreamRequest) (events.Stream, error) {
    33  	closeCtx, cancel := context.WithCancel(ctx)
    34  	stream, err := c.grpc.CreateAuditStream(closeCtx, grpc.UseCompressor(ggzip.Name))
    35  	if err != nil {
    36  		cancel()
    37  		return nil, trace.Wrap(err)
    38  	}
    39  	s := &auditStreamer{
    40  		stream:   stream,
    41  		statusCh: make(chan events.StreamStatus, 1),
    42  		closeCtx: closeCtx,
    43  		cancel:   cancel,
    44  	}
    45  	go s.recv()
    46  	err = s.stream.Send(&request)
    47  	if err != nil {
    48  		return nil, trace.NewAggregate(s.Close(ctx), trace.Wrap(err))
    49  	}
    50  	return s, nil
    51  }
    52  
    53  // ResumeAuditStream resumes existing audit stream.
    54  func (c *Client) ResumeAuditStream(ctx context.Context, sessionID, uploadID string) (events.Stream, error) {
    55  	return c.createOrResumeAuditStream(ctx, proto.AuditStreamRequest{
    56  		Request: &proto.AuditStreamRequest_ResumeStream{
    57  			ResumeStream: &proto.ResumeStream{
    58  				SessionID: sessionID,
    59  				UploadID:  uploadID,
    60  			},
    61  		},
    62  	})
    63  }
    64  
    65  // CreateAuditStream creates new audit stream.
    66  func (c *Client) CreateAuditStream(ctx context.Context, sessionID string) (events.Stream, error) {
    67  	return c.createOrResumeAuditStream(ctx, proto.AuditStreamRequest{
    68  		Request: &proto.AuditStreamRequest_CreateStream{
    69  			CreateStream: &proto.CreateStream{SessionID: sessionID},
    70  		},
    71  	})
    72  }
    73  
    74  type auditStreamer struct {
    75  	statusCh chan events.StreamStatus
    76  	mu       sync.RWMutex
    77  	stream   proto.AuthService_CreateAuditStreamClient
    78  	err      error
    79  	closeCtx context.Context
    80  	cancel   context.CancelFunc
    81  }
    82  
    83  // Close flushes non-uploaded flight stream data without marking
    84  // the stream completed and closes the stream instance.
    85  func (s *auditStreamer) Close(ctx context.Context) error {
    86  	defer s.closeWithError(nil)
    87  	return trace.Wrap(s.stream.Send(&proto.AuditStreamRequest{
    88  		Request: &proto.AuditStreamRequest_FlushAndCloseStream{
    89  			FlushAndCloseStream: &proto.FlushAndCloseStream{},
    90  		},
    91  	}))
    92  }
    93  
    94  // Complete completes stream.
    95  func (s *auditStreamer) Complete(ctx context.Context) error {
    96  	return trace.Wrap(s.stream.Send(&proto.AuditStreamRequest{
    97  		Request: &proto.AuditStreamRequest_CompleteStream{
    98  			CompleteStream: &proto.CompleteStream{},
    99  		},
   100  	}))
   101  }
   102  
   103  // Status returns a StreamStatus channel for the auditStreamer,
   104  // which can be received from to interact with new updates.
   105  func (s *auditStreamer) Status() <-chan events.StreamStatus {
   106  	return s.statusCh
   107  }
   108  
   109  // RecordEvent records adds an event to a session recording.
   110  func (s *auditStreamer) RecordEvent(ctx context.Context, event events.PreparedSessionEvent) error {
   111  	oneof, err := events.ToOneOf(event.GetAuditEvent())
   112  	if err != nil {
   113  		return trace.Wrap(err)
   114  	}
   115  	err = trace.Wrap(s.stream.Send(&proto.AuditStreamRequest{
   116  		Request: &proto.AuditStreamRequest_Event{Event: oneof},
   117  	}))
   118  	if err != nil {
   119  		s.closeWithError(err)
   120  		return trace.Wrap(err)
   121  	}
   122  	return nil
   123  }
   124  
   125  // Done returns channel closed when streamer is closed.
   126  func (s *auditStreamer) Done() <-chan struct{} {
   127  	return s.closeCtx.Done()
   128  }
   129  
   130  // Error returns last error of the stream.
   131  func (s *auditStreamer) Error() error {
   132  	s.mu.RLock()
   133  	defer s.mu.RUnlock()
   134  	return s.err
   135  }
   136  
   137  // recv is necessary to receive errors from the
   138  // server, otherwise no errors will be propagated.
   139  func (s *auditStreamer) recv() {
   140  	for {
   141  		status, err := s.stream.Recv()
   142  		if err != nil {
   143  			s.closeWithError(trace.Wrap(err))
   144  			return
   145  		}
   146  		select {
   147  		case <-s.closeCtx.Done():
   148  			return
   149  		case s.statusCh <- *status:
   150  		default:
   151  		}
   152  	}
   153  }
   154  
   155  func (s *auditStreamer) closeWithError(err error) {
   156  	s.cancel()
   157  	s.mu.Lock()
   158  	defer s.mu.Unlock()
   159  	s.err = err
   160  }