github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/sqlite/clientstore.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlite
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	log "github.com/golang/glog"
    26  
    27  	"github.com/google/fleetspeak/fleetspeak/src/common"
    28  	"github.com/google/fleetspeak/fleetspeak/src/server/db"
    29  
    30  	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
    31  	mpb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak_monitoring"
    32  	spb "github.com/google/fleetspeak/fleetspeak/src/server/proto/fleetspeak_server"
    33  	tspb "google.golang.org/protobuf/types/known/timestamppb"
    34  )
    35  
    36  const (
    37  	bytesToMIB = 1.0 / float64(1<<20)
    38  )
    39  
    40  func (d *Datastore) ListClients(ctx context.Context, ids []common.ClientID) ([]*spb.Client, error) {
    41  	d.l.Lock()
    42  	defer d.l.Unlock()
    43  
    44  	// Return value map, maps string client ids to the return values.
    45  	retm := make(map[string]*spb.Client)
    46  
    47  	h := func(rows *sql.Rows, err error) error {
    48  		if err != nil {
    49  			return err
    50  		}
    51  		defer rows.Close()
    52  		for rows.Next() {
    53  			var sid string
    54  			var timeNS int64
    55  			var addr sql.NullString
    56  			var clockSecs, clockNanos sql.NullInt64
    57  			var streamTo sql.NullString
    58  			var blacklisted bool
    59  			if err := rows.Scan(&sid, &timeNS, &addr, &streamTo, &clockSecs, &clockNanos, &blacklisted); err != nil {
    60  				return err
    61  			}
    62  
    63  			id, err := common.StringToClientID(sid)
    64  			if err != nil {
    65  				return err
    66  			}
    67  
    68  			ts := tspb.New(time.Unix(0, timeNS))
    69  			if err := ts.CheckValid(); err != nil {
    70  				return err
    71  			}
    72  
    73  			if !addr.Valid {
    74  				addr.String = ""
    75  			}
    76  
    77  			if !streamTo.Valid {
    78  				streamTo.String = ""
    79  			}
    80  
    81  			var lastClock *tspb.Timestamp
    82  			if clockSecs.Valid && clockNanos.Valid {
    83  				lastClock = &tspb.Timestamp{
    84  					Seconds: clockSecs.Int64,
    85  					Nanos:   int32(clockNanos.Int64),
    86  				}
    87  			}
    88  			retm[sid] = &spb.Client{
    89  				ClientId:               id.Bytes(),
    90  				LastContactTime:        ts,
    91  				LastContactAddress:     addr.String,
    92  				LastContactStreamingTo: streamTo.String,
    93  				LastClock:              lastClock,
    94  				Blacklisted:            blacklisted,
    95  			}
    96  		}
    97  		return rows.Err()
    98  	}
    99  
   100  	j := func(rows *sql.Rows, err error) error {
   101  		if err != nil {
   102  			return err
   103  		}
   104  		defer rows.Close()
   105  
   106  		for rows.Next() {
   107  			var sid string
   108  			l := &fspb.Label{}
   109  			if err := rows.Scan(&sid, &l.ServiceName, &l.Label); err != nil {
   110  				return err
   111  			}
   112  
   113  			retm[sid].Labels = append(retm[sid].Labels, l)
   114  		}
   115  		return nil
   116  	}
   117  
   118  	err := d.runInTx(func(tx *sql.Tx) error {
   119  		if len(ids) == 0 {
   120  			if err := h(tx.QueryContext(ctx, "SELECT client_id, last_contact_time, last_contact_address, last_contact_streaming_to, last_clock_seconds, last_clock_nanos, blacklisted FROM clients")); err != nil {
   121  				return err
   122  			}
   123  		} else {
   124  			for _, id := range ids {
   125  				if err := h(tx.QueryContext(ctx, "SELECT client_id, last_contact_time, last_contact_address, last_contact_streaming_to, last_clock_seconds, last_clock_nanos, blacklisted FROM clients WHERE client_id = ?", id.String())); err != nil {
   126  					return err
   127  				}
   128  			}
   129  		}
   130  
   131  		// Match all the labels in the database with the client ids noted in the
   132  		// previous step. Note that clients.client_id is a foreign key of
   133  		// client_labels.
   134  		if len(ids) == 0 {
   135  			if err := j(tx.QueryContext(ctx, "SELECT client_id, service_name, label FROM client_labels")); err != nil {
   136  				return err
   137  			}
   138  		} else {
   139  			for _, id := range ids {
   140  				if err := j(tx.QueryContext(ctx, "SELECT client_id, service_name, label FROM client_labels WHERE client_id = ?", id.String())); err != nil {
   141  					return err
   142  				}
   143  			}
   144  		}
   145  		return nil
   146  	})
   147  
   148  	var ret []*spb.Client
   149  	for _, v := range retm {
   150  		ret = append(ret, v)
   151  	}
   152  
   153  	return ret, err
   154  }
   155  
   156  func (d *Datastore) StreamClientIds(ctx context.Context, includeBlacklisted bool, lastContactAfter *time.Time, callback func(common.ClientID) error) error {
   157  	d.l.Lock()
   158  	defer d.l.Unlock()
   159  	return d.runInTx(func(tx *sql.Tx) error {
   160  		args := []any{}
   161  		query := "SELECT client_id FROM clients"
   162  
   163  		conditions := []string{}
   164  		if !includeBlacklisted {
   165  			conditions = append(conditions, "NOT blacklisted")
   166  		}
   167  
   168  		if lastContactAfter != nil {
   169  			conditions = append(conditions, "last_contact_time > ?")
   170  			args = append(args, lastContactAfter.UnixNano())
   171  		}
   172  
   173  		if len(conditions) > 0 {
   174  			query = fmt.Sprintf("%s WHERE %s", query, strings.Join(conditions, " AND "))
   175  		}
   176  
   177  		rs, err := tx.QueryContext(ctx, query, args...)
   178  		if err != nil {
   179  			return err
   180  		}
   181  		defer rs.Close()
   182  		for rs.Next() {
   183  			var sid string
   184  			err := rs.Scan(&sid)
   185  			if err != nil {
   186  				return err
   187  			}
   188  			id, err := common.StringToClientID(sid)
   189  			if err != nil {
   190  				return err
   191  			}
   192  			err = callback(id)
   193  			if err != nil {
   194  				return err
   195  			}
   196  		}
   197  		return nil
   198  	})
   199  }
   200  
   201  func (d *Datastore) GetClientData(ctx context.Context, id common.ClientID) (*db.ClientData, error) {
   202  	d.l.Lock()
   203  	defer d.l.Unlock()
   204  	var cd *db.ClientData
   205  	err := d.runInTx(func(tx *sql.Tx) error {
   206  		sid := id.String()
   207  
   208  		r := tx.QueryRowContext(ctx, "SELECT client_key, blacklisted FROM clients WHERE client_id=?", sid)
   209  		var c db.ClientData
   210  
   211  		err := r.Scan(&c.Key, &c.Blacklisted)
   212  		if err != nil {
   213  			return err
   214  		}
   215  
   216  		rs, err := tx.QueryContext(ctx, "SELECT service_name, label FROM client_labels WHERE client_id=?", sid)
   217  		if err != nil {
   218  			return err
   219  		}
   220  		defer rs.Close()
   221  		for rs.Next() {
   222  			l := &fspb.Label{}
   223  			err = rs.Scan(&l.ServiceName, &l.Label)
   224  			if err != nil {
   225  				return err
   226  			}
   227  			c.Labels = append(c.Labels, l)
   228  		}
   229  		if err := rs.Err(); err != nil {
   230  			return err
   231  		}
   232  		cd = &c
   233  		return nil
   234  	})
   235  	return cd, err
   236  }
   237  
   238  func (d *Datastore) AddClient(ctx context.Context, id common.ClientID, data *db.ClientData) error {
   239  	d.l.Lock()
   240  	defer d.l.Unlock()
   241  	return d.runInTx(func(tx *sql.Tx) error {
   242  		sid := id.String()
   243  		if _, err := tx.ExecContext(ctx, "INSERT INTO clients(client_id, client_key, blacklisted, last_contact_time) VALUES(?, ?, ?, ?)", sid, data.Key, data.Blacklisted, db.Now().UnixNano()); err != nil {
   244  			return err
   245  		}
   246  		for _, l := range data.Labels {
   247  			if _, err := tx.ExecContext(ctx, "INSERT INTO client_labels(client_id, service_name, label) VALUES(?, ?, ?)", sid, l.ServiceName, l.Label); err != nil {
   248  				return err
   249  			}
   250  		}
   251  		return nil
   252  	})
   253  }
   254  
   255  func (d *Datastore) AddClientLabel(ctx context.Context, id common.ClientID, l *fspb.Label) error {
   256  	d.l.Lock()
   257  	defer d.l.Unlock()
   258  	_, err := d.db.ExecContext(ctx, "INSERT INTO client_labels(client_id, service_name, label) VALUES(?, ?, ?)", id.String(), l.ServiceName, l.Label)
   259  	return err
   260  }
   261  
   262  func (d *Datastore) RemoveClientLabel(ctx context.Context, id common.ClientID, l *fspb.Label) error {
   263  	d.l.Lock()
   264  	defer d.l.Unlock()
   265  	_, err := d.db.ExecContext(ctx, "DELETE FROM client_labels WHERE client_id=? AND service_name=? AND label=?", id.String(), l.ServiceName, l.Label)
   266  	return err
   267  }
   268  
   269  func (d *Datastore) BlacklistClient(ctx context.Context, id common.ClientID) error {
   270  	d.l.Lock()
   271  	defer d.l.Unlock()
   272  	_, err := d.db.ExecContext(ctx, "UPDATE clients SET blacklisted=1 WHERE client_id=?", id.String())
   273  	return err
   274  }
   275  
   276  func (d *Datastore) RecordClientContact(ctx context.Context, data db.ContactData) (db.ContactID, error) {
   277  	d.l.Lock()
   278  	defer d.l.Unlock()
   279  
   280  	var res db.ContactID
   281  	err := d.runInTx(func(tx *sql.Tx) error {
   282  		n := db.Now().UnixNano()
   283  		r, err := tx.ExecContext(ctx, "INSERT INTO client_contacts(client_id, time, sent_nonce, received_nonce, address) VALUES(?, ?, ?, ?, ?)",
   284  			data.ClientID.String(), n, strconv.FormatUint(data.NonceSent, 16), strconv.FormatUint(data.NonceReceived, 16), data.Addr)
   285  		if err != nil {
   286  			return err
   287  		}
   288  		id, err := r.LastInsertId()
   289  		if err != nil {
   290  			return err
   291  		}
   292  		var lcs, lcn sql.NullInt64
   293  		if data.ClientClock != nil {
   294  			lcs.Int64, lcs.Valid = data.ClientClock.Seconds, true
   295  			lcn.Int64, lcn.Valid = int64(data.ClientClock.Nanos), true
   296  		}
   297  		var lcst sql.NullString
   298  		if data.StreamingTo != "" {
   299  			lcst.String, lcst.Valid = data.StreamingTo, true
   300  		}
   301  		if _, err := tx.ExecContext(ctx, "UPDATE clients SET last_contact_time = ?, last_contact_streaming_to = ?, last_contact_address = ?, last_clock_seconds = ?, last_clock_nanos = ? WHERE client_id = ?", n, lcst, data.Addr, lcs, lcn, data.ClientID.String()); err != nil {
   302  			return err
   303  		}
   304  		res = db.ContactID(strconv.FormatUint(uint64(id), 16))
   305  		return nil
   306  	})
   307  	return res, err
   308  }
   309  
   310  func (d *Datastore) StreamClientContacts(ctx context.Context, id common.ClientID, callback func(*spb.ClientContact) error) error {
   311  	d.l.Lock()
   312  	defer d.l.Unlock()
   313  
   314  	if err := d.runInTx(func(tx *sql.Tx) error {
   315  		rows, err := tx.QueryContext(
   316  			ctx,
   317  			"SELECT time, sent_nonce, received_nonce, address FROM client_contacts WHERE client_id = ?",
   318  			id.String())
   319  		if err != nil {
   320  			return err
   321  		}
   322  		defer rows.Close()
   323  		for rows.Next() {
   324  			var addr sql.NullString
   325  			var timeNS int64
   326  			var sn, rn string
   327  			c := &spb.ClientContact{}
   328  			if err := rows.Scan(&timeNS, &sn, &rn, &addr); err != nil {
   329  				return err
   330  			}
   331  
   332  			c.SentNonce, err = strconv.ParseUint(sn, 16, 64)
   333  			if err != nil {
   334  				return fmt.Errorf("Unable to parse sent_nonce in db: %v", err)
   335  			}
   336  			c.ReceivedNonce, err = strconv.ParseUint(rn, 16, 64)
   337  			if err != nil {
   338  				return fmt.Errorf("Unable to parse received_nonce in db: %v", err)
   339  			}
   340  
   341  			if addr.Valid {
   342  				c.ObservedAddress = addr.String
   343  			}
   344  
   345  			ts := tspb.New(time.Unix(0, timeNS))
   346  			if err := ts.CheckValid(); err != nil {
   347  				return err
   348  			}
   349  			c.Timestamp = ts
   350  
   351  			err = callback(c)
   352  			if err != nil {
   353  				return err
   354  			}
   355  		}
   356  		return nil
   357  	}); err != nil {
   358  		return err
   359  	}
   360  
   361  	return nil
   362  }
   363  
   364  func (d *Datastore) ListClientContacts(ctx context.Context, id common.ClientID) ([]*spb.ClientContact, error) {
   365  	var res []*spb.ClientContact
   366  	callback := func(c *spb.ClientContact) error {
   367  		res = append(res, c)
   368  		return nil
   369  	}
   370  	return res, d.StreamClientContacts(ctx, id, callback)
   371  }
   372  
   373  func (d *Datastore) LinkMessagesToContact(ctx context.Context, contact db.ContactID, ids []common.MessageID) error {
   374  	c, err := strconv.ParseUint(string(contact), 16, 64)
   375  	if err != nil {
   376  		e := fmt.Errorf("unable to parse ContactID [%v]: %v", contact, err)
   377  		log.Error(e)
   378  		return e
   379  	}
   380  	d.l.Lock()
   381  	defer d.l.Unlock()
   382  	return d.runInTx(func(tx *sql.Tx) error {
   383  		for _, id := range ids {
   384  			if _, err := tx.ExecContext(ctx, "INSERT INTO client_contact_messages(client_contact_id, message_id) VALUES (?, ?)", c, id.String()); err != nil {
   385  				return err
   386  			}
   387  		}
   388  		return nil
   389  	})
   390  }
   391  
   392  func (d *Datastore) RecordResourceUsageData(ctx context.Context, id common.ClientID, rud *mpb.ResourceUsageData) error {
   393  	d.l.Lock()
   394  	defer d.l.Unlock()
   395  	if err := rud.ProcessStartTime.CheckValid(); err != nil {
   396  		return fmt.Errorf("failed to parse process start time: %v", err)
   397  	}
   398  	processStartTime := rud.ProcessStartTime.AsTime()
   399  	if err := rud.DataTimestamp.CheckValid(); err != nil {
   400  		return fmt.Errorf("failed to parse data timestamp: %v", err)
   401  	}
   402  	clientTimestamp := rud.DataTimestamp.AsTime()
   403  	return d.runInTx(func(tx *sql.Tx) error {
   404  		_, err := tx.ExecContext(
   405  			ctx,
   406  			"INSERT INTO client_resource_usage_records VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
   407  			id.String(),
   408  			rud.Scope,
   409  			rud.Pid,
   410  			processStartTime.UnixNano(),
   411  			clientTimestamp.UnixNano(),
   412  			db.Now().UnixNano(),
   413  			rud.ProcessTerminated,
   414  			rud.ResourceUsage.MeanUserCpuRate,
   415  			rud.ResourceUsage.MaxUserCpuRate,
   416  			rud.ResourceUsage.MeanSystemCpuRate,
   417  			rud.ResourceUsage.MaxSystemCpuRate,
   418  			int32(rud.ResourceUsage.MeanResidentMemory*bytesToMIB),
   419  			int32(float64(rud.ResourceUsage.MaxResidentMemory)*bytesToMIB),
   420  			int32(rud.ResourceUsage.MeanNumFds),
   421  			rud.ResourceUsage.MaxNumFds)
   422  		return err
   423  	})
   424  }
   425  
   426  func (d *Datastore) FetchResourceUsageRecords(ctx context.Context, id common.ClientID, startTimestamp, endTimestamp *tspb.Timestamp) ([]*spb.ClientResourceUsageRecord, error) {
   427  	d.l.Lock()
   428  	defer d.l.Unlock()
   429  	if err := startTimestamp.CheckValid(); err != nil {
   430  		return nil, err
   431  	}
   432  	startTimeRange := startTimestamp.AsTime()
   433  	if err := endTimestamp.CheckValid(); err != nil {
   434  		return nil, err
   435  	}
   436  	endTimeRange := endTimestamp.AsTime()
   437  	if startTimeRange.After(endTimeRange) {
   438  		return nil, fmt.Errorf("timerange is invalid: start timestamp is after end timestamp")
   439  	}
   440  	var records []*spb.ClientResourceUsageRecord
   441  	err := d.runInTx(func(tx *sql.Tx) error {
   442  		rows, err := tx.QueryContext(
   443  			ctx,
   444  			"SELECT "+
   445  				"scope, pid, process_start_time, client_timestamp, server_timestamp, "+
   446  				"process_terminated, mean_user_cpu_rate, max_user_cpu_rate, mean_system_cpu_rate, "+
   447  				"max_system_cpu_rate, mean_resident_memory_mib, max_resident_memory_mib, "+
   448  				"mean_num_fds, max_num_fds "+
   449  				"FROM client_resource_usage_records WHERE client_id=? "+
   450  				"AND server_timestamp >= ? AND server_timestamp < ?",
   451  			id.String(),
   452  			startTimeRange.UnixNano(),
   453  			endTimeRange.UnixNano())
   454  
   455  		if err != nil {
   456  			return err
   457  		}
   458  
   459  		defer rows.Close()
   460  
   461  		for rows.Next() {
   462  			record := &spb.ClientResourceUsageRecord{}
   463  			var processStartTime, clientTimestamp, serverTimestamp int64
   464  			err := rows.Scan(
   465  				&record.Scope, &record.Pid, &processStartTime, &clientTimestamp, &serverTimestamp,
   466  				&record.ProcessTerminated, &record.MeanUserCpuRate, &record.MaxUserCpuRate, &record.MeanSystemCpuRate,
   467  				&record.MaxSystemCpuRate, &record.MeanResidentMemoryMib, &record.MaxResidentMemoryMib,
   468  				&record.MeanNumFds, &record.MaxNumFds)
   469  
   470  			if err != nil {
   471  				return err
   472  			}
   473  
   474  			record.ProcessStartTime = timestampProto(processStartTime)
   475  			record.ClientTimestamp = timestampProto(clientTimestamp)
   476  			record.ServerTimestamp = timestampProto(serverTimestamp)
   477  			records = append(records, record)
   478  		}
   479  
   480  		return nil
   481  	})
   482  	if err != nil {
   483  		return nil, err
   484  	}
   485  	return records, nil
   486  }
   487  
   488  func timestampProto(nanos int64) *tspb.Timestamp {
   489  	return &tspb.Timestamp{
   490  		Seconds: nanos / time.Second.Nanoseconds(),
   491  		Nanos:   int32(nanos % time.Second.Nanoseconds()),
   492  	}
   493  }