github.com/Rookout/GoSDK@v0.1.48/pkg/com_ws/agent_com_ws.go (about)

     1  package com_ws
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"net/url"
     8  	"strings"
     9  
    10  	"github.com/Rookout/GoSDK/pkg/common"
    11  	"github.com/Rookout/GoSDK/pkg/config"
    12  	"github.com/Rookout/GoSDK/pkg/information"
    13  	"github.com/Rookout/GoSDK/pkg/logger"
    14  	pb "github.com/Rookout/GoSDK/pkg/protobuf"
    15  	"github.com/Rookout/GoSDK/pkg/rookoutErrors"
    16  	"github.com/Rookout/GoSDK/pkg/utils"
    17  	"github.com/google/uuid"
    18  	"google.golang.org/protobuf/types/known/anypb"
    19  )
    20  
// Callable is a handler for an incoming controller message. It receives the
// message payload as carried in the envelope (an *anypb.Any).
type Callable func(*anypb.Any)
    22  
    23  type AgentCom interface {
    24  	ConnectToAgent() error
    25  	RegisterCallback(string, Callable)
    26  	Send([]byte) rookoutErrors.RookoutError
    27  	Stop()
    28  	Flush()
    29  }
    30  
// messageCallback is one registered handler for a message type.
type messageCallback struct {
	// callback is invoked with the payload of each matching message.
	callback Callable
	// persistent keeps the handler registered after it has run; when false
	// the handler is removed after its first invocation.
	persistent bool
}
    35  
// agentComWs implements AgentCom over a websocket client, reconnecting with
// backoff until Stop is called.
type agentComWs struct {
	// agentID is a random UUID, hex-encoded (see setId).
	agentID string
	// output is given the agent ID and told to stop sending on Stop.
	output Output
	// agentURL is the controller URL built by buildAgentURL.
	agentURL *url.URL
	// proxy is the optional proxy URL; nil when no proxy is configured.
	proxy *url.URL
	// token is passed to clientCreator for every new client.
	token string
	// callbacks maps a message type name to its registered handlers.
	// NOTE(review): read/written by the receive-loop goroutine and by
	// RegisterCallback/connect without a lock — potential data race.
	callbacks map[string][]messageCallback
	// agentInfo is the collected agent information passed to clientCreator.
	agentInfo *pb.AgentInformation
	// printOnInitialConnection requests a one-time "connected" message; it is
	// cleared after the first successful connection.
	printOnInitialConnection bool
	// stopCtx is cancelled by Stop; it ends the connect loop.
	stopCtx       context.Context
	stopCtxCancel context.CancelFunc
	// outgoingChan queues serialized messages until sendLoop writes them.
	outgoingChan *SizeLimitedChannel
	// gotInitialAugs carries the InitAugs signal from the receive side to connect.
	gotInitialAugs chan bool
	// clientCreator builds a new WebSocketClient for each connection attempt.
	clientCreator WebSocketClientCreator
	// client is the most recently created client.
	// NOTE(review): written by connect (connect-loop goroutine) and read by
	// Stop without synchronization.
	client WebSocketClient
	// backoff paces reconnection attempts.
	backoff Backoff
}
    53  
    54  func NewAgentComWs(clientCreator WebSocketClientCreator, output Output, backoff Backoff, agentHost string, agentPort int, proxy string,
    55  	token string, labels map[string]string, printOnInitialConnection bool) (*agentComWs, error) {
    56  	var a agentComWs
    57  	var err error
    58  	a.stopCtx, a.stopCtxCancel = context.WithCancel(context.Background())
    59  	a.setId()
    60  	a.agentURL, err = buildAgentURL(agentHost, agentPort)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	proxyUrl, err := buildProxyURL(proxy)
    65  	if err != nil {
    66  		logger.Logger().Fatalln("Bad proxy address: " + err.Error())
    67  		return nil, err
    68  	}
    69  	a.proxy = proxyUrl
    70  	a.agentInfo, err = information.Collect(labels, "")
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	a.agentInfo.AgentId = a.agentID
    75  	a.token = token
    76  	a.callbacks = map[string][]messageCallback{}
    77  	a.printOnInitialConnection = printOnInitialConnection
    78  	a.outgoingChan = NewSizeLimitedChannel()
    79  	a.gotInitialAugs = make(chan bool, 1)
    80  	a.clientCreator = clientCreator
    81  	a.backoff = backoff
    82  	a.output = output
    83  	a.output.SetAgentID(a.agentID)
    84  
    85  	return &a, nil
    86  }
    87  
    88  func buildProxyURL(proxy string) (*url.URL, error) {
    89  	if proxy == "" {
    90  		return nil, nil
    91  	}
    92  	if !strings.Contains(proxy, "://") {
    93  		proxy = "http://" + proxy
    94  	}
    95  	return url.Parse(proxy)
    96  }
    97  
    98  func buildAgentURL(agentHost string, agentPort int) (*url.URL, error) {
    99  	if agentHost != "" && !strings.Contains(agentHost, "://") {
   100  		agentHost = "ws://" + agentHost
   101  	}
   102  	urlString := fmt.Sprintf("%s:%d/v1", agentHost, agentPort)
   103  	return url.Parse(urlString)
   104  }
   105  
   106  func (a *agentComWs) setId() {
   107  	id, _ := uuid.New().MarshalBinary()
   108  	a.agentID = hex.EncodeToString(id)
   109  }
   110  
   111  func (a *agentComWs) on(messageName string, callback Callable, persistent bool) {
   112  	messageCallback := messageCallback{callback, persistent}
   113  	a.callbacks[messageName] = append(a.callbacks[messageName], messageCallback)
   114  }
   115  
   116  func (a *agentComWs) RegisterCallback(messageName string, callback Callable) {
   117  	a.on(messageName, callback, true)
   118  }
   119  
   120  func (a *agentComWs) ConnectToAgent() error {
   121  	connectionTimeoutCtx, cancelConnectionTimeoutCtx := context.WithTimeout(context.Background(), config.AgentComWsConfig().ConnectionTimeout)
   122  	defer cancelConnectionTimeoutCtx()
   123  	connErrorsChan := make(chan error)
   124  
   125  	utils.CreateRetryingGoroutine(a.stopCtx, func() { a.connectLoop(connErrorsChan) })
   126  
   127  	select {
   128  	case <-connectionTimeoutCtx.Done():
   129  		return rookoutErrors.NewRookConnectToControllerTimeout()
   130  	case err := <-connErrorsChan:
   131  		return err
   132  	}
   133  }
   134  
   135  func (a *agentComWs) Stop() {
   136  	a.output.StopSendingMessages()
   137  	select {
   138  	case <-a.stopCtx.Done():
   139  	default:
   140  		a.stopCtxCancel()
   141  	}
   142  
   143  	if a.client != nil {
   144  		a.client.Close()
   145  	}
   146  }
   147  
   148  func (a *agentComWs) Flush() {
   149  	err := a.outgoingChan.Flush()
   150  	if err != nil {
   151  		logger.Logger().WithError(err).Info("Flush failed")
   152  	}
   153  }
   154  
   155  func (a *agentComWs) connectLoop(connErrorsChan chan error) {
   156  	for {
   157  		if !a.isRunning() {
   158  			return
   159  		}
   160  
   161  		logger.Logger().Info("Connecting to controller.")
   162  		connectionCtx, err := func() (context.Context, error) {
   163  			connectCtx, cancelConnectCtx := context.WithTimeout(a.stopCtx, config.AgentComWsConfig().ConnectTimeout)
   164  			defer cancelConnectCtx()
   165  			return a.connect(connectCtx)
   166  		}()
   167  		if err != nil {
   168  			logger.Logger().WithError(err).Info("Failed to connect to controller")
   169  			select {
   170  			case connErrorsChan <- err:
   171  			default:
   172  			}
   173  			a.backoff.AfterDisconnect(a.stopCtx)
   174  			continue
   175  		}
   176  
   177  		a.backoff.AfterConnect()
   178  		select {
   179  		case connErrorsChan <- nil:
   180  		default:
   181  		}
   182  		if a.printOnInitialConnection {
   183  			a.printOnInitialConnection = false
   184  			logger.QuietPrintln("[Rookout] Successfully connected to controller.")
   185  			logger.Logger().Debug("[Rookout] Agent ID is " + a.agentID)
   186  		}
   187  		logger.Logger().Info("Connected successfully to cloud controller")
   188  		logger.Logger().Info("Finished initialization")
   189  
   190  		select {
   191  		case <-a.stopCtx.Done():
   192  			return
   193  		case <-connectionCtx.Done():
   194  			a.client.Close()
   195  			logger.Logger().Info("Disconnected from controller")
   196  			a.backoff.AfterDisconnect(a.stopCtx)
   197  		}
   198  	}
   199  }
   200  
   201  func (a *agentComWs) connect(ctx context.Context) (context.Context, error) {
   202  	
   203  	a.client = a.clientCreator(a.stopCtx, a.agentURL, a.token, a.proxy, a.agentInfo)
   204  	err := a.dialAndHandshake(ctx, a.client)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	
   210  	a.on(common.MessageTypeInitAugs, func(any *anypb.Any) {
   211  		a.gotInitialAugs <- true
   212  	}, false)
   213  
   214  	connectionCtx, cancelConnectionCtx := context.WithCancel(a.client.GetConnectionCtx())
   215  	utils.CreateGoroutine(func() {
   216  		defer cancelConnectionCtx()
   217  		a.sendLoop(connectionCtx, a.client)
   218  	})
   219  	utils.CreateGoroutine(func() {
   220  		defer cancelConnectionCtx()
   221  		a.receiveLoop(connectionCtx, a.client)
   222  	})
   223  
   224  	select {
   225  	case <-a.gotInitialAugs:
   226  		return connectionCtx, nil
   227  	case <-ctx.Done():
   228  		return nil, rookoutErrors.NewContextEnded(ctx.Err())
   229  	case <-connectionCtx.Done():
   230  		return nil, rookoutErrors.NewContextEnded(connectionCtx.Err())
   231  	}
   232  }
   233  
   234  func (a *agentComWs) sendLoop(ctx context.Context, client WebSocketClient) {
   235  	for {
   236  		
   237  		buf := a.outgoingChan.Poll(ctx)
   238  		if buf == nil {
   239  			return
   240  		}
   241  		err := client.Send(ctx, buf)
   242  		if err != nil {
   243  			logger.Logger().WithError(err).Error("Failed when sending a message")
   244  			
   245  			_ = a.outgoingChan.Offer(buf)
   246  			return
   247  		}
   248  
   249  		select {
   250  		case <-ctx.Done():
   251  			return
   252  		default:
   253  		}
   254  	}
   255  }
   256  
   257  func (a *agentComWs) receiveLoop(ctx context.Context, client WebSocketClient) {
   258  	for {
   259  		
   260  		buf, err := client.Receive(ctx)
   261  		if err != nil {
   262  			logger.Logger().WithError(err).Error("failed when receiving a message")
   263  			return
   264  		}
   265  
   266  		envelope, typeName, err := common.ParseEnvelope(buf)
   267  		if err != nil {
   268  			logger.Logger().WithError(err).Infof("failed to parse message from controller")
   269  			continue
   270  		}
   271  		a.handleIncomingMessage(typeName, envelope)
   272  
   273  		select {
   274  		case <-ctx.Done():
   275  			return
   276  		default:
   277  		}
   278  	}
   279  }
   280  
   281  func (a *agentComWs) Send(buf []byte) rookoutErrors.RookoutError {
   282  	return a.outgoingChan.Offer(buf)
   283  }
   284  
   285  func (a *agentComWs) isRunning() bool {
   286  	select {
   287  	case <-a.stopCtx.Done():
   288  		return false
   289  	default:
   290  		return true
   291  	}
   292  }
   293  
   294  func (a *agentComWs) dialAndHandshake(ctx context.Context, client WebSocketClient) error {
   295  	logger.Logger().Info("Attempting connection to cloud controller")
   296  	err := client.Dial(ctx)
   297  	if err != nil {
   298  		return err
   299  	}
   300  	logger.Logger().Info("Dial to cloud controller returned")
   301  
   302  	logger.Logger().Info("Starting handshake with cloud controller")
   303  	err = client.Handshake(ctx)
   304  	if err != nil {
   305  		logger.Logger().WithError(err).Error("websocket handshake failed")
   306  		client.Close()
   307  		return err
   308  	}
   309  	logger.Logger().Info("Handshake with cloud controller completed successfully")
   310  	return nil
   311  }
   312  
   313  func (a *agentComWs) handleIncomingMessage(typeName string, envelope *pb.Envelope) {
   314  	var persistentCallbacks []messageCallback
   315  	if callbacks, exists := a.callbacks[typeName]; exists {
   316  		for _, messageCB := range callbacks {
   317  			messageCB.callback(envelope.GetMsg())
   318  
   319  			if messageCB.persistent {
   320  				persistentCallbacks = append(persistentCallbacks, messageCB)
   321  			}
   322  		}
   323  		a.callbacks[typeName] = persistentCallbacks
   324  	} else {
   325  		logger.Logger().Infof("Received unknown command: %s", typeName)
   326  	}
   327  }