github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/libkb/connmgr.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package libkb 5 6 import ( 7 "fmt" 8 "sort" 9 "sync" 10 "time" 11 12 keybase1 "github.com/keybase/client/go/protocol/keybase1" 13 "github.com/keybase/go-framed-msgpack-rpc/rpc" 14 ) 15 16 // ConnectionID is a sequential integer assigned to each RPC connection 17 // that this process serves. No IDs are reused. 18 type ConnectionID int 19 20 // ApplyFn can be applied to every connection. It is called with the 21 // RPC transporter, and also the connectionID. It should return a bool 22 // true to keep going and false to stop. 23 type ApplyFn func(i ConnectionID, xp rpc.Transporter) bool 24 25 // ApplyDetailsFn can be applied to every connection. It is called with the 26 // RPC transporter, and also the connectionID. It should return a bool 27 // true to keep going and false to stop. 28 type ApplyDetailsFn func(i ConnectionID, xp rpc.Transporter, details *keybase1.ClientDetails) bool 29 30 // LabelCb is a callback to be run when a client connects and labels itself. 31 type LabelCb func(typ keybase1.ClientType) 32 33 type rpcConnection struct { 34 transporter rpc.Transporter 35 details *keybase1.ClientStatus 36 } 37 38 // ConnectionManager manages all active connections for a given service. 39 // It can be called from multiple goroutines. 40 type ConnectionManager struct { 41 sync.Mutex 42 nxt ConnectionID 43 lookup map[ConnectionID](*rpcConnection) 44 labelCbs []LabelCb 45 } 46 47 // AddConnection adds a new connection to the table of Connection object, with a 48 // related closeListener. We'll listen for a close on that channel, and when one occurs, 49 // we'll remove the connection from the pool. 50 func (c *ConnectionManager) AddConnection(xp rpc.Transporter, closeListener chan error) ConnectionID { 51 c.Lock() 52 c.nxt++ // increment first, since 0 is reserved 53 id := c.nxt 54 c.lookup[id] = &rpcConnection{transporter: xp} 55 c.Unlock() 56 57 if closeListener != nil { 58 go func() { 59 <-closeListener 60 c.removeConnection(id) 61 }() 62 } 63 64 return id 65 } 66 67 func (c *ConnectionManager) removeConnection(id ConnectionID) { 68 c.Lock() 69 delete(c.lookup, id) 70 c.Unlock() 71 } 72 73 // LookupConnection looks up a connection given a connectionID, or returns nil 74 // if no such connection was found. 75 func (c *ConnectionManager) LookupConnection(i ConnectionID) rpc.Transporter { 76 c.Lock() 77 defer c.Unlock() 78 if conn := c.lookup[i]; conn != nil { 79 return conn.transporter 80 } 81 return nil 82 } 83 84 func (c *ConnectionManager) Shutdown() { 85 } 86 87 func (c *ConnectionManager) LookupByClientType(clientType keybase1.ClientType) rpc.Transporter { 88 c.Lock() 89 defer c.Unlock() 90 for _, v := range c.lookup { 91 if v.details != nil && v.details.Details.ClientType == clientType { 92 return v.transporter 93 } 94 } 95 return nil 96 } 97 98 func (c *ConnectionManager) Label(id ConnectionID, d keybase1.ClientDetails) error { 99 c.Lock() 100 defer c.Unlock() 101 102 var err error 103 if conn := c.lookup[id]; conn != nil { 104 conn.details = &keybase1.ClientStatus{ 105 Details: d, 106 ConnectionID: int(id), 107 } 108 } else { 109 err = NotFoundError{Msg: fmt.Sprintf("connection %d not found", id)} 110 } 111 112 // Hit all the callbacks with the client type 113 for _, lloop := range c.labelCbs { 114 go func(l LabelCb) { l(d.ClientType) }(lloop) 115 } 116 117 return err 118 } 119 120 func (c *ConnectionManager) RegisterLabelCallback(f LabelCb) { 121 c.Lock() 122 c.labelCbs = append(c.labelCbs, f) 123 c.Unlock() 124 } 125 126 func (c *ConnectionManager) hasClientType(clientType keybase1.ClientType) bool { 127 for _, con := range c.ListAllLabeledConnections() { 128 if clientType == con.Details.ClientType { 129 return true 130 } 131 } 132 return false 133 } 134 135 // WaitForClientType returns true if client type is connected, or waits until timeout for the connection 136 func (c *ConnectionManager) WaitForClientType(clientType keybase1.ClientType, timeout time.Duration) bool { 137 if c.hasClientType(clientType) { 138 return true 139 } 140 ticker := time.NewTicker(time.Second) 141 deadline := time.After(timeout) 142 defer ticker.Stop() 143 for { 144 select { 145 case <-ticker.C: 146 if c.hasClientType(clientType) { 147 return true 148 } 149 case <-deadline: 150 return false 151 } 152 } 153 } 154 155 func (c *ConnectionManager) ListAllLabeledConnections() (ret []keybase1.ClientStatus) { 156 c.Lock() 157 defer c.Unlock() 158 for _, v := range c.lookup { 159 if v.details != nil { 160 ret = append(ret, *v.details) 161 } 162 } 163 sort.Sort(byClientType(ret)) 164 return ret 165 } 166 167 type byClientType []keybase1.ClientStatus 168 169 func (a byClientType) Len() int { return len(a) } 170 func (a byClientType) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 171 func (a byClientType) Less(i, j int) bool { return a[i].Details.ClientType < a[j].Details.ClientType } 172 173 // ApplyAll applies the given function f to all connections in the table. 174 // If you're going to do something blocking, please do it in a GoRoutine, 175 // since we're holding the lock for all connections as we do this. 176 func (c *ConnectionManager) ApplyAll(f ApplyFn) { 177 c.Lock() 178 defer c.Unlock() 179 for k, v := range c.lookup { 180 if !f(k, v.transporter) { 181 break 182 } 183 } 184 } 185 186 // ApplyAllDetails applies the given function f to all connections in the table. 187 // If you're going to do something blocking, please do it in a GoRoutine, 188 // since we're holding the lock for all connections as we do this. 189 func (c *ConnectionManager) ApplyAllDetails(f ApplyDetailsFn) { 190 c.Lock() 191 defer c.Unlock() 192 for k, v := range c.lookup { 193 status := v.details 194 var details *keybase1.ClientDetails 195 if status != nil { 196 details = &status.Details 197 } 198 if !f(k, v.transporter, details) { 199 break 200 } 201 } 202 } 203 204 // NewConnectionManager makes a new ConnectionManager. 205 func NewConnectionManager() *ConnectionManager { 206 return &ConnectionManager{ 207 lookup: make(map[ConnectionID](*rpcConnection)), 208 } 209 }