github.com/Rookout/GoSDK@v0.1.48/pkg/com_ws/agent_com_ws.go (about) 1 package com_ws 2 3 import ( 4 "context" 5 "encoding/hex" 6 "fmt" 7 "net/url" 8 "strings" 9 10 "github.com/Rookout/GoSDK/pkg/common" 11 "github.com/Rookout/GoSDK/pkg/config" 12 "github.com/Rookout/GoSDK/pkg/information" 13 "github.com/Rookout/GoSDK/pkg/logger" 14 pb "github.com/Rookout/GoSDK/pkg/protobuf" 15 "github.com/Rookout/GoSDK/pkg/rookoutErrors" 16 "github.com/Rookout/GoSDK/pkg/utils" 17 "github.com/google/uuid" 18 "google.golang.org/protobuf/types/known/anypb" 19 ) 20 21 type Callable func(*anypb.Any) 22 23 type AgentCom interface { 24 ConnectToAgent() error 25 RegisterCallback(string, Callable) 26 Send([]byte) rookoutErrors.RookoutError 27 Stop() 28 Flush() 29 } 30 31 type messageCallback struct { 32 callback Callable 33 persistent bool 34 } 35 36 type agentComWs struct { 37 agentID string 38 output Output 39 agentURL *url.URL 40 proxy *url.URL 41 token string 42 callbacks map[string][]messageCallback 43 agentInfo *pb.AgentInformation 44 printOnInitialConnection bool 45 stopCtx context.Context 46 stopCtxCancel context.CancelFunc 47 outgoingChan *SizeLimitedChannel 48 gotInitialAugs chan bool 49 clientCreator WebSocketClientCreator 50 client WebSocketClient 51 backoff Backoff 52 } 53 54 func NewAgentComWs(clientCreator WebSocketClientCreator, output Output, backoff Backoff, agentHost string, agentPort int, proxy string, 55 token string, labels map[string]string, printOnInitialConnection bool) (*agentComWs, error) { 56 var a agentComWs 57 var err error 58 a.stopCtx, a.stopCtxCancel = context.WithCancel(context.Background()) 59 a.setId() 60 a.agentURL, err = buildAgentURL(agentHost, agentPort) 61 if err != nil { 62 return nil, err 63 } 64 proxyUrl, err := buildProxyURL(proxy) 65 if err != nil { 66 logger.Logger().Fatalln("Bad proxy address: " + err.Error()) 67 return nil, err 68 } 69 a.proxy = proxyUrl 70 a.agentInfo, err = information.Collect(labels, "") 71 if err != nil { 72 return nil, err 73 } 74 a.agentInfo.AgentId = a.agentID 75 a.token = token 76 a.callbacks = map[string][]messageCallback{} 77 a.printOnInitialConnection = printOnInitialConnection 78 a.outgoingChan = NewSizeLimitedChannel() 79 a.gotInitialAugs = make(chan bool, 1) 80 a.clientCreator = clientCreator 81 a.backoff = backoff 82 a.output = output 83 a.output.SetAgentID(a.agentID) 84 85 return &a, nil 86 } 87 88 func buildProxyURL(proxy string) (*url.URL, error) { 89 if proxy == "" { 90 return nil, nil 91 } 92 if !strings.Contains(proxy, "://") { 93 proxy = "http://" + proxy 94 } 95 return url.Parse(proxy) 96 } 97 98 func buildAgentURL(agentHost string, agentPort int) (*url.URL, error) { 99 if agentHost != "" && !strings.Contains(agentHost, "://") { 100 agentHost = "ws://" + agentHost 101 } 102 urlString := fmt.Sprintf("%s:%d/v1", agentHost, agentPort) 103 return url.Parse(urlString) 104 } 105 106 func (a *agentComWs) setId() { 107 id, _ := uuid.New().MarshalBinary() 108 a.agentID = hex.EncodeToString(id) 109 } 110 111 func (a *agentComWs) on(messageName string, callback Callable, persistent bool) { 112 messageCallback := messageCallback{callback, persistent} 113 a.callbacks[messageName] = append(a.callbacks[messageName], messageCallback) 114 } 115 116 func (a *agentComWs) RegisterCallback(messageName string, callback Callable) { 117 a.on(messageName, callback, true) 118 } 119 120 func (a *agentComWs) ConnectToAgent() error { 121 connectionTimeoutCtx, cancelConnectionTimeoutCtx := context.WithTimeout(context.Background(), config.AgentComWsConfig().ConnectionTimeout) 122 defer cancelConnectionTimeoutCtx() 123 connErrorsChan := make(chan error) 124 125 utils.CreateRetryingGoroutine(a.stopCtx, func() { a.connectLoop(connErrorsChan) }) 126 127 select { 128 case <-connectionTimeoutCtx.Done(): 129 return rookoutErrors.NewRookConnectToControllerTimeout() 130 case err := <-connErrorsChan: 131 return err 132 } 133 } 134 135 func (a *agentComWs) Stop() { 136 a.output.StopSendingMessages() 137 select { 138 case <-a.stopCtx.Done(): 139 default: 140 a.stopCtxCancel() 141 } 142 143 if a.client != nil { 144 a.client.Close() 145 } 146 } 147 148 func (a *agentComWs) Flush() { 149 err := a.outgoingChan.Flush() 150 if err != nil { 151 logger.Logger().WithError(err).Info("Flush failed") 152 } 153 } 154 155 func (a *agentComWs) connectLoop(connErrorsChan chan error) { 156 for { 157 if !a.isRunning() { 158 return 159 } 160 161 logger.Logger().Info("Connecting to controller.") 162 connectionCtx, err := func() (context.Context, error) { 163 connectCtx, cancelConnectCtx := context.WithTimeout(a.stopCtx, config.AgentComWsConfig().ConnectTimeout) 164 defer cancelConnectCtx() 165 return a.connect(connectCtx) 166 }() 167 if err != nil { 168 logger.Logger().WithError(err).Info("Failed to connect to controller") 169 select { 170 case connErrorsChan <- err: 171 default: 172 } 173 a.backoff.AfterDisconnect(a.stopCtx) 174 continue 175 } 176 177 a.backoff.AfterConnect() 178 select { 179 case connErrorsChan <- nil: 180 default: 181 } 182 if a.printOnInitialConnection { 183 a.printOnInitialConnection = false 184 logger.QuietPrintln("[Rookout] Successfully connected to controller.") 185 logger.Logger().Debug("[Rookout] Agent ID is " + a.agentID) 186 } 187 logger.Logger().Info("Connected successfully to cloud controller") 188 logger.Logger().Info("Finished initialization") 189 190 select { 191 case <-a.stopCtx.Done(): 192 return 193 case <-connectionCtx.Done(): 194 a.client.Close() 195 logger.Logger().Info("Disconnected from controller") 196 a.backoff.AfterDisconnect(a.stopCtx) 197 } 198 } 199 } 200 201 func (a *agentComWs) connect(ctx context.Context) (context.Context, error) { 202 203 a.client = a.clientCreator(a.stopCtx, a.agentURL, a.token, a.proxy, a.agentInfo) 204 err := a.dialAndHandshake(ctx, a.client) 205 if err != nil { 206 return nil, err 207 } 208 209 210 a.on(common.MessageTypeInitAugs, func(any *anypb.Any) { 211 a.gotInitialAugs <- true 212 }, false) 213 214 connectionCtx, cancelConnectionCtx := context.WithCancel(a.client.GetConnectionCtx()) 215 utils.CreateGoroutine(func() { 216 defer cancelConnectionCtx() 217 a.sendLoop(connectionCtx, a.client) 218 }) 219 utils.CreateGoroutine(func() { 220 defer cancelConnectionCtx() 221 a.receiveLoop(connectionCtx, a.client) 222 }) 223 224 select { 225 case <-a.gotInitialAugs: 226 return connectionCtx, nil 227 case <-ctx.Done(): 228 return nil, rookoutErrors.NewContextEnded(ctx.Err()) 229 case <-connectionCtx.Done(): 230 return nil, rookoutErrors.NewContextEnded(connectionCtx.Err()) 231 } 232 } 233 234 func (a *agentComWs) sendLoop(ctx context.Context, client WebSocketClient) { 235 for { 236 237 buf := a.outgoingChan.Poll(ctx) 238 if buf == nil { 239 return 240 } 241 err := client.Send(ctx, buf) 242 if err != nil { 243 logger.Logger().WithError(err).Error("Failed when sending a message") 244 245 _ = a.outgoingChan.Offer(buf) 246 return 247 } 248 249 select { 250 case <-ctx.Done(): 251 return 252 default: 253 } 254 } 255 } 256 257 func (a *agentComWs) receiveLoop(ctx context.Context, client WebSocketClient) { 258 for { 259 260 buf, err := client.Receive(ctx) 261 if err != nil { 262 logger.Logger().WithError(err).Error("failed when receiving a message") 263 return 264 } 265 266 envelope, typeName, err := common.ParseEnvelope(buf) 267 if err != nil { 268 logger.Logger().WithError(err).Infof("failed to parse message from controller") 269 continue 270 } 271 a.handleIncomingMessage(typeName, envelope) 272 273 select { 274 case <-ctx.Done(): 275 return 276 default: 277 } 278 } 279 } 280 281 func (a *agentComWs) Send(buf []byte) rookoutErrors.RookoutError { 282 return a.outgoingChan.Offer(buf) 283 } 284 285 func (a *agentComWs) isRunning() bool { 286 select { 287 case <-a.stopCtx.Done(): 288 return false 289 default: 290 return true 291 } 292 } 293 294 func (a *agentComWs) dialAndHandshake(ctx context.Context, client WebSocketClient) error { 295 logger.Logger().Info("Attempting connection to cloud controller") 296 err := client.Dial(ctx) 297 if err != nil { 298 return err 299 } 300 logger.Logger().Info("Dial to cloud controller returned") 301 302 logger.Logger().Info("Starting handshake with cloud controller") 303 err = client.Handshake(ctx) 304 if err != nil { 305 logger.Logger().WithError(err).Error("websocket handshake failed") 306 client.Close() 307 return err 308 } 309 logger.Logger().Info("Handshake with cloud controller completed successfully") 310 return nil 311 } 312 313 func (a *agentComWs) handleIncomingMessage(typeName string, envelope *pb.Envelope) { 314 var persistentCallbacks []messageCallback 315 if callbacks, exists := a.callbacks[typeName]; exists { 316 for _, messageCB := range callbacks { 317 messageCB.callback(envelope.GetMsg()) 318 319 if messageCB.persistent { 320 persistentCallbacks = append(persistentCallbacks, messageCB) 321 } 322 } 323 a.callbacks[typeName] = persistentCallbacks 324 } else { 325 logger.Logger().Infof("Received unknown command: %s", typeName) 326 } 327 }