github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/admin/admin.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package admin defines an administrative interface into the fleetspeak system.
    16  package admin
    17  
    18  import (
    19  	"bytes"
    20  	"errors"
    21  	"fmt"
    22  	"time"
    23  
    24  	"context"
    25  
    26  	log "github.com/golang/glog"
    27  	"github.com/google/fleetspeak/fleetspeak/src/common"
    28  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    29  	inotifications "github.com/google/fleetspeak/fleetspeak/src/server/internal/notifications"
    30  	"github.com/google/fleetspeak/fleetspeak/src/server/notifications"
    31  	"google.golang.org/protobuf/proto"
    32  
    33  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    34  	sgrpc "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    35  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    36  )
    37  
    38  const (
    39  	// Max size of messages accepted by Fleetspeak clients.
    40  	maxMsgSize = 2 << 20 // 2MiB
    41  )
    42  
    43  // NewServer returns an admin_grpc.AdminServer which performs operations using
    44  // the provided db.Store.
    45  func NewServer(s db.Store, n notifications.Notifier) sgrpc.AdminServer {
    46  	if n == nil {
    47  		n = inotifications.NoopNotifier{}
    48  	}
    49  	return adminServer{
    50  		store:    s,
    51  		notifier: n,
    52  	}
    53  }
    54  
    55  // adminServer implements admin_grpc.AdminServer.
    56  type adminServer struct {
    57  	sgrpc.UnimplementedAdminServer
    58  
    59  	store    db.Store
    60  	notifier notifications.Notifier
    61  }
    62  
    63  func (s adminServer) CreateBroadcast(ctx context.Context, req *spb.CreateBroadcastRequest) (*fspb.EmptyMessage, error) {
    64  	if err := s.store.CreateBroadcast(ctx, req.Broadcast, req.Limit); err != nil {
    65  		return nil, err
    66  	}
    67  	return &fspb.EmptyMessage{}, nil
    68  }
    69  
    70  func (s adminServer) ListActiveBroadcasts(ctx context.Context, req *spb.ListActiveBroadcastsRequest) (*spb.ListActiveBroadcastsResponse, error) {
    71  	var ret spb.ListActiveBroadcastsResponse
    72  	bis, err := s.store.ListActiveBroadcasts(ctx)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	for _, bi := range bis {
    77  		if req.ServiceName != "" && req.ServiceName != bi.Broadcast.Source.ServiceName {
    78  			continue
    79  		}
    80  		ret.Broadcasts = append(ret.Broadcasts, bi.Broadcast)
    81  	}
    82  	return &ret, nil
    83  }
    84  
    85  func (s adminServer) GetMessageStatus(ctx context.Context, req *spb.GetMessageStatusRequest) (*spb.GetMessageStatusResponse, error) {
    86  	mid, err := common.BytesToMessageID(req.MessageId)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	msgs, err := s.store.GetMessages(ctx, []common.MessageID{mid}, false)
    92  	if err != nil {
    93  		if s.store.IsNotFound(err) {
    94  			return &spb.GetMessageStatusResponse{}, nil
    95  		}
    96  		return nil, err
    97  	}
    98  	if len(msgs) != 1 {
    99  		return nil, fmt.Errorf("Internal error, expected 1 message, got %d", len(msgs))
   100  	}
   101  
   102  	return &spb.GetMessageStatusResponse{
   103  			CreationTime: msgs[0].CreationTime,
   104  			Result:       msgs[0].Result},
   105  		nil
   106  }
   107  
   108  func (s adminServer) ListClients(ctx context.Context, req *spb.ListClientsRequest) (*spb.ListClientsResponse, error) {
   109  	ids := make([]common.ClientID, 0, len(req.ClientIds))
   110  	for i, b := range req.ClientIds {
   111  		id, err := common.BytesToClientID(b)
   112  		if err != nil {
   113  			return nil, fmt.Errorf("unable to parse id [%d]: %v", i, err)
   114  		}
   115  		ids = append(ids, id)
   116  	}
   117  	clients, err := s.store.ListClients(ctx, ids)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	return &spb.ListClientsResponse{
   123  		Clients: clients,
   124  	}, nil
   125  }
   126  
   127  func (s adminServer) StreamClientIds(req *spb.StreamClientIdsRequest, srv sgrpc.Admin_StreamClientIdsServer) error {
   128  	callback := func(id common.ClientID) error {
   129  		return srv.Send(&spb.StreamClientIdsResponse{
   130  			ClientId: id.Bytes(),
   131  		})
   132  	}
   133  	lastContactAfter := req.LastContactAfter.AsTime()
   134  	return s.store.StreamClientIds(srv.Context(), req.IncludeBlacklisted, &lastContactAfter, callback)
   135  }
   136  
   137  func (s adminServer) ListClientContacts(ctx context.Context, req *spb.ListClientContactsRequest) (*spb.ListClientContactsResponse, error) {
   138  	id, err := common.BytesToClientID(req.ClientId)
   139  	if err != nil {
   140  		return nil, fmt.Errorf("unable to parse id [%d]: %v", req.ClientId, err)
   141  	}
   142  
   143  	contacts, err := s.store.ListClientContacts(ctx, id)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	return &spb.ListClientContactsResponse{
   148  		Contacts: contacts,
   149  	}, nil
   150  }
   151  
   152  func (s adminServer) StreamClientContacts(req *spb.StreamClientContactsRequest, srv sgrpc.Admin_StreamClientContactsServer) error {
   153  	callback := func(contact *spb.ClientContact) error {
   154  		return srv.Send(&spb.StreamClientContactsResponse{
   155  			Contact: contact,
   156  		})
   157  	}
   158  	id, err := common.BytesToClientID(req.ClientId)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	return s.store.StreamClientContacts(srv.Context(), id, callback)
   163  }
   164  
   165  func (s adminServer) InsertMessage(ctx context.Context, m *fspb.Message) (*fspb.EmptyMessage, error) {
   166  	// At this point, we mostly trust the message we get, but do some basic
   167  	// sanity checks and generate missing metadata.
   168  	if m.Destination == nil || m.Destination.ServiceName == "" {
   169  		return nil, errors.New("message must have Destination")
   170  	}
   171  	if m.Source == nil || m.Source.ServiceName == "" {
   172  		return nil, errors.New("message must have Source")
   173  	}
   174  	if len(m.MessageId) == 0 {
   175  		id, err := common.RandomMessageID()
   176  		if err != nil {
   177  			return nil, fmt.Errorf("unable to create random MessageID: %v", err)
   178  		}
   179  		m.MessageId = id.Bytes()
   180  	}
   181  	if m.CreationTime == nil {
   182  		m.CreationTime = db.NowProto()
   183  	}
   184  
   185  	// If the message is to a client, we'll want to notify any server that it is
   186  	// connected to. Gather the data for this now, doing the validation implicit
   187  	// in this before saving the message.
   188  	var cid common.ClientID
   189  	var st string
   190  	var lc time.Time
   191  	if m.Destination.ClientId != nil {
   192  		var err error
   193  		cid, err = common.BytesToClientID(m.Destination.ClientId)
   194  		if err != nil {
   195  			return nil, fmt.Errorf("error parsing destination.client_id (%x): %v", m.Destination.ClientId, err)
   196  		}
   197  		cls, err := s.store.ListClients(ctx, []common.ClientID{cid})
   198  		if err != nil {
   199  			return nil, fmt.Errorf("error listing destination client (%x): %v", m.Destination.ClientId, err)
   200  		}
   201  		if len(cls) != 1 {
   202  			return nil, fmt.Errorf("expected 1 destination client result, got %d", len(cls))
   203  		}
   204  		st = cls[0].LastContactStreamingTo
   205  		if cls[0].LastContactTime != nil {
   206  			if err := cls[0].LastContactTime.CheckValid(); err != nil {
   207  				log.Errorf("Failed to convert last contact time from database: %v", err)
   208  				lc = time.Time{}
   209  			} else {
   210  				lc = cls[0].LastContactTime.AsTime()
   211  			}
   212  		}
   213  		msgSize := proto.Size(m)
   214  		if msgSize > maxMsgSize {
   215  			return nil, fmt.Errorf("message intended for client %x is of size %d, which exceeds the %d-byte limit", m.Destination.ClientId, msgSize, maxMsgSize)
   216  		}
   217  	}
   218  
   219  	if err := s.store.StoreMessages(ctx, []*fspb.Message{m}, ""); err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	// Notify the most recent connection to the client. Don't fail the RPC if we
   224  	// have trouble though, as we do have the message and it should get there
   225  	// eventually.
   226  	if st != "" && time.Since(lc) < 10*time.Minute {
   227  		if err := s.notifier.NewMessageForClient(ctx, st, cid); err != nil {
   228  			log.Warningf("Failure trying to notify of new message for client (%x): %v", m.Destination.ClientId, err)
   229  		}
   230  	}
   231  
   232  	return &fspb.EmptyMessage{}, nil
   233  }
   234  
   235  func (s adminServer) bytesToClientIds(ids [][]byte) ([]common.ClientID, error) {
   236  	result := make([]common.ClientID, len(ids))
   237  	for i, b := range ids {
   238  		bid, err := common.BytesToClientID(b)
   239  		if err != nil {
   240  			return nil, fmt.Errorf("Can't convert bytes to ClientID: %v", err)
   241  		}
   242  
   243  		result[i] = bid
   244  	}
   245  	return result, nil
   246  }
   247  
   248  func (s adminServer) DeletePendingMessages(ctx context.Context, r *spb.DeletePendingMessagesRequest) (*fspb.EmptyMessage, error) {
   249  	ids, err := s.bytesToClientIds(r.ClientIds)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  
   254  	if err := s.store.DeletePendingMessages(ctx, ids); err != nil {
   255  		return nil, fmt.Errorf("Can't delete pending messages: %v", err)
   256  	}
   257  
   258  	return &fspb.EmptyMessage{}, nil
   259  }
   260  
   261  func (s adminServer) GetPendingMessages(ctx context.Context, r *spb.GetPendingMessagesRequest) (*spb.GetPendingMessagesResponse, error) {
   262  	ids, err := s.bytesToClientIds(r.ClientIds)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	res, err := s.store.GetPendingMessages(ctx, ids, r.Offset, r.Limit, r.WantData)
   268  	if err != nil {
   269  		return nil, fmt.Errorf("Can't read pending messages: %v", err)
   270  	}
   271  
   272  	return &spb.GetPendingMessagesResponse{Messages: res}, nil
   273  }
   274  
   275  func (s adminServer) GetPendingMessageCount(ctx context.Context, r *spb.GetPendingMessageCountRequest) (*spb.GetPendingMessageCountResponse, error) {
   276  	ids, err := s.bytesToClientIds(r.ClientIds)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  
   281  	res, err := s.store.GetPendingMessageCount(ctx, ids)
   282  	if err != nil {
   283  		return nil, fmt.Errorf("Can't read pending message count: %v", err)
   284  	}
   285  
   286  	return &spb.GetPendingMessageCountResponse{Count: res}, nil
   287  }
   288  
   289  func (s adminServer) StoreFile(ctx context.Context, req *spb.StoreFileRequest) (*fspb.EmptyMessage, error) {
   290  	if req.ServiceName == "" || req.FileName == "" {
   291  		return nil, errors.New("file must have service_name and file_name")
   292  	}
   293  	if err := s.store.StoreFile(ctx, req.ServiceName, req.FileName, bytes.NewReader(req.Data)); err != nil {
   294  		return nil, err
   295  	}
   296  	return &fspb.EmptyMessage{}, nil
   297  }
   298  
   299  func (s adminServer) KeepAlive(ctx context.Context, _ *fspb.EmptyMessage) (*fspb.EmptyMessage, error) {
   300  	return &fspb.EmptyMessage{}, nil
   301  }
   302  
   303  func (s adminServer) BlacklistClient(ctx context.Context, req *spb.BlacklistClientRequest) (*fspb.EmptyMessage, error) {
   304  	id, err := common.BytesToClientID(req.ClientId)
   305  	if err != nil {
   306  		return nil, fmt.Errorf("unable to parse id [%d]: %v", req.ClientId, err)
   307  	}
   308  	if err := s.store.BlacklistClient(ctx, id); err != nil {
   309  		return nil, err
   310  	}
   311  	return &fspb.EmptyMessage{}, nil
   312  }
   313  
   314  func (s adminServer) FetchClientResourceUsageRecords(ctx context.Context, req *spb.FetchClientResourceUsageRecordsRequest) (*spb.FetchClientResourceUsageRecordsResponse, error) {
   315  	clientID, idErr := common.BytesToClientID(req.ClientId)
   316  	if idErr != nil {
   317  		return nil, idErr
   318  	}
   319  	records, dbErr := s.store.FetchResourceUsageRecords(ctx, clientID, req.StartTimestamp, req.EndTimestamp)
   320  	if dbErr != nil {
   321  		return nil, dbErr
   322  	}
   323  	return &spb.FetchClientResourceUsageRecordsResponse{
   324  		Records: records,
   325  	}, nil
   326  }