github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/mysql/clientstore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mysql
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"encoding/binary"
    21  	"fmt"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    25  
    26  	log "github.com/golang/glog"
    27  
    28  	"github.com/google/fleetspeak/fleetspeak/src/common"
    29  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    30  
    31  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    32  	mpb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak_monitoring"
    33  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    34  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    35  )
    36  
    37  const (
    38  	bytesToMIB = 1.0 / float64(1<<20)
    39  )
    40  
    41  func bytesToUint64(b []byte) (uint64, error) {
    42  	if len(b) != 8 {
    43  		return 0, fmt.Errorf("error converting to uint64, expected 8 bytes, got %d", len(b))
    44  	}
    45  	return binary.LittleEndian.Uint64(b), nil
    46  }
    47  
    48  func uint64ToBytes(i uint64) []byte {
    49  	b := make([]byte, 8)
    50  	binary.LittleEndian.PutUint64(b, i)
    51  	return b
    52  }
    53  
    54  func (d *Datastore) StreamClientIds(ctx context.Context, includeBlacklisted bool, lastContactAfter *time.Time, callback func(common.ClientID) error) error {
    55  	return d.runOnce(ctx, true, func(tx *sql.Tx) error {
    56  		args := []any{}
    57  		query := "SELECT client_id FROM clients"
    58  
    59  		conditions := []string{}
    60  		if !includeBlacklisted {
    61  			conditions = append(conditions, "NOT blacklisted")
    62  		}
    63  
    64  		if lastContactAfter != nil {
    65  			conditions = append(conditions, "last_contact_time > ?")
    66  			args = append(args, lastContactAfter.UnixNano())
    67  		}
    68  
    69  		if len(conditions) > 0 {
    70  			query = fmt.Sprintf("%s WHERE %s", query, strings.Join(conditions, " AND "))
    71  		}
    72  
    73  		rs, err := tx.QueryContext(ctx, query, args...)
    74  		if err != nil {
    75  			return err
    76  		}
    77  		defer rs.Close()
    78  		for rs.Next() {
    79  			var bid []byte
    80  			err := rs.Scan(&bid)
    81  			if err != nil {
    82  				return err
    83  			}
    84  			id, err := common.BytesToClientID(bid)
    85  			if err != nil {
    86  				return err
    87  			}
    88  			err = callback(id)
    89  			if err != nil {
    90  				return err
    91  			}
    92  		}
    93  		return nil
    94  	})
    95  }
    96  
    97  func (d *Datastore) ListClients(ctx context.Context, ids []common.ClientID) ([]*spb.Client, error) {
    98  	// Return value map, maps string client ids to the return values.
    99  	var retm map[string]*spb.Client
   100  
   101  	// Step one - read data from client rows. Applied to the result of one or more
   102  	// queries, depending on the ids parameter.
   103  	h := func(rows *sql.Rows, err error) error {
   104  		if err != nil {
   105  			return err
   106  		}
   107  		defer rows.Close()
   108  		for rows.Next() {
   109  			var id []byte
   110  			var timeNS int64
   111  			var addr sql.NullString
   112  			var clockSecs, clockNanos sql.NullInt64
   113  			var streamTo sql.NullString
   114  			var blacklisted bool
   115  			if err := rows.Scan(&id, &timeNS, &addr, &streamTo, &clockSecs, &clockNanos, &blacklisted); err != nil {
   116  				return err
   117  			}
   118  
   119  			ts := tspb.New(time.Unix(0, timeNS))
   120  			if err := ts.CheckValid(); err != nil {
   121  				return err
   122  			}
   123  
   124  			if !addr.Valid {
   125  				addr.String = ""
   126  			}
   127  
   128  			if !streamTo.Valid {
   129  				streamTo.String = ""
   130  			}
   131  
   132  			var lastClock *tspb.Timestamp
   133  			if clockSecs.Valid && clockNanos.Valid {
   134  				lastClock = &tspb.Timestamp{
   135  					Seconds: clockSecs.Int64,
   136  					Nanos:   int32(clockNanos.Int64),
   137  				}
   138  			}
   139  			retm[string(id)] = &spb.Client{
   140  				ClientId:               id,
   141  				LastContactTime:        ts,
   142  				LastContactAddress:     addr.String,
   143  				LastContactStreamingTo: streamTo.String,
   144  				LastClock:              lastClock,
   145  				Blacklisted:            blacklisted,
   146  			}
   147  		}
   148  		return rows.Err()
   149  	}
   150  
   151  	// Step two - applied to result of one or more queries reading client labels.
   152  	j := func(rows *sql.Rows, err error) error {
   153  		if err != nil {
   154  			return err
   155  		}
   156  
   157  		defer rows.Close()
   158  
   159  		for rows.Next() {
   160  			var id []byte
   161  			l := &fspb.Label{}
   162  			if err := rows.Scan(&id, &l.ServiceName, &l.Label); err != nil {
   163  				return err
   164  			}
   165  
   166  			retm[string(id)].Labels = append(retm[string(id)].Labels, l)
   167  		}
   168  		return nil
   169  	}
   170  
   171  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   172  		retm = make(map[string]*spb.Client)
   173  		if len(ids) == 0 {
   174  			if err := h(tx.QueryContext(ctx, "SELECT client_id, last_contact_time, last_contact_address, last_contact_streaming_to, last_clock_seconds, last_clock_nanos, blacklisted FROM clients")); err != nil {
   175  				return err
   176  			}
   177  		} else {
   178  			for _, id := range ids {
   179  				if err := h(tx.QueryContext(ctx, "SELECT client_id, last_contact_time, last_contact_address, last_contact_streaming_to, last_clock_seconds, last_clock_nanos, blacklisted FROM clients WHERE client_id = ?", id.Bytes())); err != nil {
   180  					return err
   181  				}
   182  			}
   183  		}
   184  
   185  		// Match all the labels in the database with the client ids noted in the
   186  		// previous step. Note that clients.client_id is a foreign key of
   187  		// client_labels.
   188  		if len(ids) == 0 {
   189  			if err := j(tx.QueryContext(ctx, "SELECT client_id, service_name, label FROM client_labels")); err != nil {
   190  				return err
   191  			}
   192  		} else {
   193  			for _, id := range ids {
   194  				if err := j(tx.QueryContext(ctx, "SELECT client_id, service_name, label FROM client_labels WHERE client_id = ?", id.String())); err != nil {
   195  					return err
   196  				}
   197  			}
   198  		}
   199  		return nil
   200  	})
   201  
   202  	var ret []*spb.Client
   203  	for _, v := range retm {
   204  		ret = append(ret, v)
   205  	}
   206  
   207  	return ret, err
   208  }
   209  
   210  func (d *Datastore) GetClientData(ctx context.Context, id common.ClientID) (*db.ClientData, error) {
   211  	var cd *db.ClientData
   212  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   213  		iid := id.Bytes()
   214  
   215  		r := tx.QueryRowContext(ctx, "SELECT client_key, blacklisted FROM clients WHERE client_id=?", iid)
   216  		var c db.ClientData
   217  
   218  		err := r.Scan(&c.Key, &c.Blacklisted)
   219  		if err != nil {
   220  			return err
   221  		}
   222  
   223  		rs, err := tx.QueryContext(ctx, "SELECT service_name, label FROM client_labels WHERE client_id=?", iid)
   224  		if err != nil {
   225  			return err
   226  		}
   227  		defer rs.Close()
   228  		for rs.Next() {
   229  			l := &fspb.Label{}
   230  			err = rs.Scan(&l.ServiceName, &l.Label)
   231  			if err != nil {
   232  				return err
   233  			}
   234  			c.Labels = append(c.Labels, l)
   235  		}
   236  		if err := rs.Err(); err != nil {
   237  			return err
   238  		}
   239  		cd = &c
   240  		return nil
   241  	})
   242  	return cd, err
   243  }
   244  
   245  func (d *Datastore) AddClient(ctx context.Context, id common.ClientID, data *db.ClientData) error {
   246  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   247  		if _, err := tx.ExecContext(ctx, "INSERT INTO clients(client_id, client_key, blacklisted, last_contact_time) VALUES(?, ?, ?, ?)", id.Bytes(), data.Key, data.Blacklisted, db.Now().UnixNano()); err != nil {
   248  			return err
   249  		}
   250  		for _, l := range data.Labels {
   251  			if _, err := tx.ExecContext(ctx, "INSERT INTO client_labels(client_id, service_name, label) VALUES(?, ?, ?)", id.Bytes(), l.ServiceName, l.Label); err != nil {
   252  				return err
   253  			}
   254  		}
   255  		return nil
   256  	})
   257  }
   258  
   259  func (d *Datastore) AddClientLabel(ctx context.Context, id common.ClientID, l *fspb.Label) error {
   260  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   261  		_, err := d.db.ExecContext(ctx, "INSERT INTO client_labels(client_id, service_name, label) VALUES(?, ?, ?)", id.Bytes(), l.ServiceName, l.Label)
   262  		return err
   263  	})
   264  }
   265  
   266  func (d *Datastore) RemoveClientLabel(ctx context.Context, id common.ClientID, l *fspb.Label) error {
   267  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   268  		_, err := d.db.ExecContext(ctx, "DELETE FROM client_labels WHERE client_id=? AND service_name=? AND label=?", id.Bytes(), l.ServiceName, l.Label)
   269  		return err
   270  	})
   271  }
   272  
   273  func (d *Datastore) BlacklistClient(ctx context.Context, id common.ClientID) error {
   274  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   275  		_, err := d.db.ExecContext(ctx, "UPDATE clients SET blacklisted=TRUE WHERE client_id=?", id.Bytes())
   276  		return err
   277  	})
   278  }
   279  
   280  func (d *Datastore) RecordClientContact(ctx context.Context, data db.ContactData) (db.ContactID, error) {
   281  	var res db.ContactID
   282  	err := d.runInTx(ctx, false, func(tx *sql.Tx) error {
   283  		n := db.Now().UnixNano()
   284  		r, err := tx.ExecContext(ctx, "INSERT INTO client_contacts(client_id, time, sent_nonce, received_nonce, address) VALUES(?, ?, ?, ?, ?)",
   285  			data.ClientID.Bytes(), n, uint64ToBytes(data.NonceSent), uint64ToBytes(data.NonceReceived), data.Addr)
   286  		if err != nil {
   287  			return err
   288  		}
   289  		id, err := r.LastInsertId()
   290  		if err != nil {
   291  			return err
   292  		}
   293  		var lcs, lcn sql.NullInt64
   294  		if data.ClientClock != nil {
   295  			lcs.Int64, lcs.Valid = data.ClientClock.Seconds, true
   296  			lcn.Int64, lcn.Valid = int64(data.ClientClock.Nanos), true
   297  		}
   298  		var lcst sql.NullString
   299  		if data.StreamingTo != "" {
   300  			lcst.String, lcst.Valid = data.StreamingTo, true
   301  		}
   302  		if _, err := tx.ExecContext(ctx, "UPDATE clients SET last_contact_time = ?, last_contact_streaming_to = ?, last_contact_address = ?, last_clock_seconds = ?, last_clock_nanos = ? WHERE client_id = ?", n, lcst, data.Addr, lcs, lcn, data.ClientID.Bytes()); err != nil {
   303  			return err
   304  		}
   305  		res = db.ContactID(strconv.FormatUint(uint64(id), 16))
   306  		return nil
   307  	})
   308  	return res, err
   309  }
   310  
   311  func (d *Datastore) streamClientContacts(ctx context.Context, tx *sql.Tx, id common.ClientID, callback func(*spb.ClientContact) error) error {
   312  	rows, err := tx.QueryContext(
   313  		ctx,
   314  		"SELECT time, sent_nonce, received_nonce, address FROM client_contacts WHERE client_id = ?",
   315  		id.Bytes())
   316  	if err != nil {
   317  		return err
   318  	}
   319  	defer rows.Close()
   320  	for rows.Next() {
   321  		var addr sql.NullString
   322  		var timeNS int64
   323  		var sn, rn []byte
   324  		c := &spb.ClientContact{}
   325  		if err := rows.Scan(&timeNS, &sn, &rn, &addr); err != nil {
   326  			return err
   327  		}
   328  		c.SentNonce, err = bytesToUint64(sn)
   329  		if err != nil {
   330  			return err
   331  		}
   332  		c.ReceivedNonce, err = bytesToUint64(rn)
   333  		if err != nil {
   334  			return err
   335  		}
   336  
   337  		if addr.Valid {
   338  			c.ObservedAddress = addr.String
   339  		}
   340  
   341  		ts := tspb.New(time.Unix(0, timeNS))
   342  		if err := ts.CheckValid(); err != nil {
   343  			return err
   344  		}
   345  		c.Timestamp = ts
   346  
   347  		err = callback(c)
   348  		if err != nil {
   349  			return err
   350  		}
   351  	}
   352  	return nil
   353  }
   354  
   355  func (d *Datastore) StreamClientContacts(ctx context.Context, id common.ClientID, callback func(*spb.ClientContact) error) error {
   356  	return d.runOnce(ctx, true, func(tx *sql.Tx) error {
   357  		return d.streamClientContacts(ctx, tx, id, callback)
   358  	})
   359  }
   360  
   361  func (d *Datastore) ListClientContacts(ctx context.Context, id common.ClientID) ([]*spb.ClientContact, error) {
   362  	var res []*spb.ClientContact
   363  	callback := func(c *spb.ClientContact) error {
   364  		res = append(res, c)
   365  		return nil
   366  	}
   367  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   368  		res = res[:0]
   369  		return d.streamClientContacts(ctx, tx, id, callback)
   370  	})
   371  	return res, err
   372  }
   373  
   374  func (d *Datastore) LinkMessagesToContact(ctx context.Context, contact db.ContactID, ids []common.MessageID) error {
   375  	c, err := strconv.ParseUint(string(contact), 16, 64)
   376  	if err != nil {
   377  		e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err)
   378  		log.Error(e)
   379  		return e
   380  	}
   381  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   382  		for _, id := range ids {
   383  			if _, err := tx.ExecContext(ctx, "INSERT INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id.Bytes()); err != nil {
   384  				return err
   385  			}
   386  		}
   387  		return nil
   388  	})
   389  }
   390  
   391  func (d *Datastore) RecordResourceUsageData(ctx context.Context, id common.ClientID, rud *mpb.ResourceUsageData) error {
   392  	if err := rud.ProcessStartTime.CheckValid(); err != nil {
   393  		return fmt.Errorf("failed to parse process start time: %v", err)
   394  	}
   395  	processStartTime := rud.ProcessStartTime.AsTime()
   396  	if err := rud.DataTimestamp.CheckValid(); err != nil {
   397  		return fmt.Errorf("failed to parse data timestamp: %v", err)
   398  	}
   399  	clientTimestamp := rud.DataTimestamp.AsTime()
   400  	return d.runInTx(ctx, false, func(tx *sql.Tx) error {
   401  		_, err := tx.ExecContext(
   402  			ctx,
   403  			"INSERT INTO client_resource_usage_records VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
   404  			id.Bytes(),
   405  			rud.Scope,
   406  			rud.Pid,
   407  			processStartTime.UnixNano(),
   408  			clientTimestamp.UnixNano(),
   409  			db.Now().UnixNano(),
   410  			rud.ProcessTerminated,
   411  			rud.ResourceUsage.MeanUserCpuRate,
   412  			rud.ResourceUsage.MaxUserCpuRate,
   413  			rud.ResourceUsage.MeanSystemCpuRate,
   414  			rud.ResourceUsage.MaxSystemCpuRate,
   415  			int32(rud.ResourceUsage.MeanResidentMemory*bytesToMIB),
   416  			int32(float64(rud.ResourceUsage.MaxResidentMemory)*bytesToMIB),
   417  			int32(rud.ResourceUsage.MeanNumFds),
   418  			rud.ResourceUsage.MaxNumFds)
   419  		return err
   420  	})
   421  }
   422  
   423  func (d *Datastore) FetchResourceUsageRecords(ctx context.Context, id common.ClientID, startTimestamp, endTimestamp *tspb.Timestamp) ([]*spb.ClientResourceUsageRecord, error) {
   424  	if err := startTimestamp.CheckValid(); err != nil {
   425  		return nil, err
   426  	}
   427  	startTimeRange := startTimestamp.AsTime()
   428  	if err := endTimestamp.CheckValid(); err != nil {
   429  		return nil, err
   430  	}
   431  	endTimeRange := endTimestamp.AsTime()
   432  	if startTimeRange.After(endTimeRange) {
   433  		return nil, fmt.Errorf("timerange is invalid: start timestamp is after end timestamp")
   434  	}
   435  	var records []*spb.ClientResourceUsageRecord
   436  	err := d.runInTx(ctx, true, func(tx *sql.Tx) error {
   437  		records = nil
   438  		rows, err := tx.QueryContext(
   439  			ctx,
   440  			"SELECT "+
   441  				"scope, pid, process_start_time, client_timestamp, server_timestamp, "+
   442  				"process_terminated, mean_user_cpu_rate, max_user_cpu_rate, mean_system_cpu_rate, "+
   443  				"max_system_cpu_rate, mean_resident_memory_mib, max_resident_memory_mib, "+
   444  				"mean_num_fds, max_num_fds "+
   445  				"FROM client_resource_usage_records WHERE client_id=? "+
   446  				"AND server_timestamp >= ? AND server_timestamp < ?",
   447  			id.Bytes(),
   448  			startTimeRange.UnixNano(),
   449  			endTimeRange.UnixNano())
   450  
   451  		if err != nil {
   452  			return err
   453  		}
   454  
   455  		defer rows.Close()
   456  
   457  		for rows.Next() {
   458  			record := &spb.ClientResourceUsageRecord{}
   459  			var processStartTime, clientTimestamp, serverTimestamp int64
   460  			err := rows.Scan(
   461  				&record.Scope, &record.Pid, &processStartTime, &clientTimestamp, &serverTimestamp,
   462  				&record.ProcessTerminated, &record.MeanUserCpuRate, &record.MaxUserCpuRate, &record.MeanSystemCpuRate,
   463  				&record.MaxSystemCpuRate, &record.MeanResidentMemoryMib, &record.MaxResidentMemoryMib,
   464  				&record.MeanNumFds, &record.MaxNumFds)
   465  
   466  			if err != nil {
   467  				return err
   468  			}
   469  
   470  			record.ProcessStartTime = timestampProto(processStartTime)
   471  			record.ClientTimestamp = timestampProto(clientTimestamp)
   472  			record.ServerTimestamp = timestampProto(serverTimestamp)
   473  			records = append(records, record)
   474  		}
   475  
   476  		return nil
   477  	})
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  	return records, nil
   482  }
   483  
   484  func timestampProto(nanos int64) *tspb.Timestamp {
   485  	return &tspb.Timestamp{
   486  		Seconds: nanos / time.Second.Nanoseconds(),
   487  		Nanos:   int32(nanos % time.Second.Nanoseconds()),
   488  	}
   489  }