github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/libkb/connmgr.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
     6  import (
     7  	"fmt"
     8  	"sort"
     9  	"sync"
    10  	"time"
    11  
    12  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    13  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    14  )
    15  
    16  // ConnectionID is a sequential integer assigned to each RPC connection
    17  // that this process serves. No IDs are reused.
    18  type ConnectionID int
    19  
    20  // ApplyFn can be applied to every connection. It is called with the
    21  // RPC transporter, and also the connectionID. It should return a bool
    22  // true to keep going and false to stop.
    23  type ApplyFn func(i ConnectionID, xp rpc.Transporter) bool
    24  
    25  // ApplyDetailsFn can be applied to every connection. It is called with the
    26  // RPC transporter, and also the connectionID. It should return a bool
    27  // true to keep going and false to stop.
    28  type ApplyDetailsFn func(i ConnectionID, xp rpc.Transporter, details *keybase1.ClientDetails) bool
    29  
    30  // LabelCb is a callback to be run when a client connects and labels itself.
    31  type LabelCb func(typ keybase1.ClientType)
    32  
    33  type rpcConnection struct {
    34  	transporter rpc.Transporter
    35  	details     *keybase1.ClientStatus
    36  }
    37  
    38  // ConnectionManager manages all active connections for a given service.
    39  // It can be called from multiple goroutines.
    40  type ConnectionManager struct {
    41  	sync.Mutex
    42  	nxt      ConnectionID
    43  	lookup   map[ConnectionID](*rpcConnection)
    44  	labelCbs []LabelCb
    45  }
    46  
    47  // AddConnection adds a new connection to the table of Connection object, with a
    48  // related closeListener. We'll listen for a close on that channel, and when one occurs,
    49  // we'll remove the connection from the pool.
    50  func (c *ConnectionManager) AddConnection(xp rpc.Transporter, closeListener chan error) ConnectionID {
    51  	c.Lock()
    52  	c.nxt++ // increment first, since 0 is reserved
    53  	id := c.nxt
    54  	c.lookup[id] = &rpcConnection{transporter: xp}
    55  	c.Unlock()
    56  
    57  	if closeListener != nil {
    58  		go func() {
    59  			<-closeListener
    60  			c.removeConnection(id)
    61  		}()
    62  	}
    63  
    64  	return id
    65  }
    66  
    67  func (c *ConnectionManager) removeConnection(id ConnectionID) {
    68  	c.Lock()
    69  	delete(c.lookup, id)
    70  	c.Unlock()
    71  }
    72  
    73  // LookupConnection looks up a connection given a connectionID, or returns nil
    74  // if no such connection was found.
    75  func (c *ConnectionManager) LookupConnection(i ConnectionID) rpc.Transporter {
    76  	c.Lock()
    77  	defer c.Unlock()
    78  	if conn := c.lookup[i]; conn != nil {
    79  		return conn.transporter
    80  	}
    81  	return nil
    82  }
    83  
    84  func (c *ConnectionManager) Shutdown() {
    85  }
    86  
    87  func (c *ConnectionManager) LookupByClientType(clientType keybase1.ClientType) rpc.Transporter {
    88  	c.Lock()
    89  	defer c.Unlock()
    90  	for _, v := range c.lookup {
    91  		if v.details != nil && v.details.Details.ClientType == clientType {
    92  			return v.transporter
    93  		}
    94  	}
    95  	return nil
    96  }
    97  
    98  func (c *ConnectionManager) Label(id ConnectionID, d keybase1.ClientDetails) error {
    99  	c.Lock()
   100  	defer c.Unlock()
   101  
   102  	var err error
   103  	if conn := c.lookup[id]; conn != nil {
   104  		conn.details = &keybase1.ClientStatus{
   105  			Details:      d,
   106  			ConnectionID: int(id),
   107  		}
   108  	} else {
   109  		err = NotFoundError{Msg: fmt.Sprintf("connection %d not found", id)}
   110  	}
   111  
   112  	// Hit all the callbacks with the client type
   113  	for _, lloop := range c.labelCbs {
   114  		go func(l LabelCb) { l(d.ClientType) }(lloop)
   115  	}
   116  
   117  	return err
   118  }
   119  
   120  func (c *ConnectionManager) RegisterLabelCallback(f LabelCb) {
   121  	c.Lock()
   122  	c.labelCbs = append(c.labelCbs, f)
   123  	c.Unlock()
   124  }
   125  
   126  func (c *ConnectionManager) hasClientType(clientType keybase1.ClientType) bool {
   127  	for _, con := range c.ListAllLabeledConnections() {
   128  		if clientType == con.Details.ClientType {
   129  			return true
   130  		}
   131  	}
   132  	return false
   133  }
   134  
   135  // WaitForClientType returns true if client type is connected, or waits until timeout for the connection
   136  func (c *ConnectionManager) WaitForClientType(clientType keybase1.ClientType, timeout time.Duration) bool {
   137  	if c.hasClientType(clientType) {
   138  		return true
   139  	}
   140  	ticker := time.NewTicker(time.Second)
   141  	deadline := time.After(timeout)
   142  	defer ticker.Stop()
   143  	for {
   144  		select {
   145  		case <-ticker.C:
   146  			if c.hasClientType(clientType) {
   147  				return true
   148  			}
   149  		case <-deadline:
   150  			return false
   151  		}
   152  	}
   153  }
   154  
   155  func (c *ConnectionManager) ListAllLabeledConnections() (ret []keybase1.ClientStatus) {
   156  	c.Lock()
   157  	defer c.Unlock()
   158  	for _, v := range c.lookup {
   159  		if v.details != nil {
   160  			ret = append(ret, *v.details)
   161  		}
   162  	}
   163  	sort.Sort(byClientType(ret))
   164  	return ret
   165  }
   166  
   167  type byClientType []keybase1.ClientStatus
   168  
   169  func (a byClientType) Len() int           { return len(a) }
   170  func (a byClientType) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
   171  func (a byClientType) Less(i, j int) bool { return a[i].Details.ClientType < a[j].Details.ClientType }
   172  
   173  // ApplyAll applies the given function f to all connections in the table.
   174  // If you're going to do something blocking, please do it in a GoRoutine,
   175  // since we're holding the lock for all connections as we do this.
   176  func (c *ConnectionManager) ApplyAll(f ApplyFn) {
   177  	c.Lock()
   178  	defer c.Unlock()
   179  	for k, v := range c.lookup {
   180  		if !f(k, v.transporter) {
   181  			break
   182  		}
   183  	}
   184  }
   185  
   186  // ApplyAllDetails applies the given function f to all connections in the table.
   187  // If you're going to do something blocking, please do it in a GoRoutine,
   188  // since we're holding the lock for all connections as we do this.
   189  func (c *ConnectionManager) ApplyAllDetails(f ApplyDetailsFn) {
   190  	c.Lock()
   191  	defer c.Unlock()
   192  	for k, v := range c.lookup {
   193  		status := v.details
   194  		var details *keybase1.ClientDetails
   195  		if status != nil {
   196  			details = &status.Details
   197  		}
   198  		if !f(k, v.transporter, details) {
   199  			break
   200  		}
   201  	}
   202  }
   203  
   204  // NewConnectionManager makes a new ConnectionManager.
   205  func NewConnectionManager() *ConnectionManager {
   206  	return &ConnectionManager{
   207  		lookup: make(map[ConnectionID](*rpcConnection)),
   208  	}
   209  }