github.com/Rookout/GoSDK@v0.1.48/pkg/com_ws/websocket_client.go (about)

     1  package com_ws
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"net"
     7  	"net/http"
     8  	"net/url"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/Rookout/GoSDK/pkg/common"
    13  	"github.com/Rookout/GoSDK/pkg/config"
    14  	"github.com/Rookout/GoSDK/pkg/logger"
    15  	pb "github.com/Rookout/GoSDK/pkg/protobuf"
    16  	"github.com/Rookout/GoSDK/pkg/rookoutErrors"
    17  	"github.com/Rookout/GoSDK/pkg/utils"
    18  	"github.com/go-errors/errors"
    19  	gorilla "github.com/gorilla/websocket"
    20  )
    21  
    22  var dialer *gorilla.Dialer
    23  var dialerOnce sync.Once
    24  
    25  type WebSocketClientCreator func(context.Context, *url.URL, string, *url.URL, *pb.AgentInformation) WebSocketClient
    26  
    27  type WebSocketClient interface {
    28  	GetConnectionCtx() context.Context
    29  	Dial(context.Context) error
    30  	Handshake(context.Context) error
    31  	Receive(context.Context) ([]byte, error)
    32  	Send(context.Context, []byte) error
    33  	Close()
    34  }
    35  
// webSocketClient implements WebSocketClient on top of a gorilla/websocket
// connection. Instances are created by NewWebSocketClient.
type webSocketClient struct {
	// agentURL is the endpoint dialed by Dial.
	agentURL            *url.URL
	// agentInfo is sent inside the NewAgentMessage by Handshake.
	agentInfo           *pb.AgentInformation
	// conn is the live connection after a successful Dial; until then it is
	// the zero-value placeholder installed by NewWebSocketClient.
	conn                *gorilla.Conn
	// token is sent as the X-Rookout-Token header when dialing.
	token               string
	// proxy is consulted by getWSDialer; nil means no proxy was configured
	// for this client.
	proxy               *url.URL
	// ConnectionCtx is cancelled by Close and when sendPingLoop exits.
	ConnectionCtx       context.Context
	// cancelConnectionCtx cancels ConnectionCtx.
	cancelConnectionCtx context.CancelFunc
	// writeMutex serializes all writes to conn (taken in sendMsg).
	writeMutex          sync.Mutex
}
    46  
    47  func NewWebSocketClient(ctx context.Context, agentURL *url.URL, token string, proxy *url.URL, agentInfo *pb.AgentInformation) WebSocketClient {
    48  	client := &webSocketClient{
    49  		agentURL:  agentURL,
    50  		agentInfo: agentInfo,
    51  		conn:      &gorilla.Conn{},
    52  		token:     token,
    53  		proxy:     proxy,
    54  	}
    55  	client.ConnectionCtx, client.cancelConnectionCtx = context.WithCancel(ctx)
    56  	return client
    57  }
    58  
// GetConnectionCtx returns the context tied to this connection's lifetime.
// It is derived from the ctx given to NewWebSocketClient and is cancelled by
// Close and when sendPingLoop exits.
func (w *webSocketClient) GetConnectionCtx() context.Context {
	return w.ConnectionCtx
}
    62  
    63  func (w *webSocketClient) Dial(ctx context.Context) error {
    64  	conn, httpRes, err := w.getWSDialer().DialContext(ctx, w.agentURL.String(), http.Header{"X-Rookout-Token": []string{w.token}})
    65  	if err != nil {
    66  		badToken := isHttpResponseBadToken(httpRes)
    67  		if badToken {
    68  			censoredToken := ""
    69  			if len(w.token) > 5 {
    70  				censoredToken = w.token[:5]
    71  			}
    72  
    73  			logger.Logger().Errorf("The Rookout token supplied (%s) is not valid; please check the token and try again", censoredToken)
    74  			return rookoutErrors.NewInvalidTokenError()
    75  		} else if isHttpResponseBadRequest(httpRes) {
    76  			return rookoutErrors.NewWebSocketError()
    77  		} else {
    78  			logger.Logger().Errorf("Failed to connect to controller (%s). err: %s", w.agentURL, err.Error())
    79  		}
    80  		return err
    81  	}
    82  	w.conn = conn
    83  
    84  	pingTimeout := config.WebSocketClientConfig().PingTimeout
    85  	if err = w.conn.SetReadDeadline(time.Now().Add(pingTimeout)); err != nil {
    86  		logger.Logger().WithError(err).Error("failed to set read deadline, closing connection")
    87  		w.Close()
    88  		return err
    89  	}
    90  	utils.CreateGoroutine(func() {
    91  		w.sendPingLoop()
    92  	})
    93  	w.conn.SetPongHandler(func(string) error {
    94  		err := w.conn.SetReadDeadline(time.Now().Add(pingTimeout))
    95  		if err != nil {
    96  			logger.Logger().WithError(err).Error("Failed to set read deadline on pong, closing connection")
    97  			w.Close()
    98  		}
    99  
   100  		return nil
   101  	})
   102  
   103  	return nil
   104  }
   105  
   106  func (w *webSocketClient) Handshake(ctx context.Context) error {
   107  	buf, err := common.WrapMsgInEnvelope(&pb.NewAgentMessage{AgentInfo: w.agentInfo})
   108  	if err != nil {
   109  		return err
   110  	}
   111  
   112  	err = w.Send(ctx, buf)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func (w *webSocketClient) Receive(ctx context.Context) ([]byte, error) {
   121  	
   122  	if deadline, ok := ctx.Deadline(); ok {
   123  		err := w.conn.SetReadDeadline(deadline)
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  	}
   128  	messageType, buf, err := w.conn.ReadMessage()
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	if messageType != gorilla.BinaryMessage {
   134  		return nil, errors.Errorf("unexpected message type, got %d\n", messageType)
   135  	}
   136  
   137  	return buf, nil
   138  }
   139  
   140  func (w *webSocketClient) sendPing(ctx context.Context) error {
   141  	err := w.sendMsg(ctx, gorilla.PingMessage, nil)
   142  	if err != nil {
   143  		return err
   144  	}
   145  	return nil
   146  }
   147  
   148  func (w *webSocketClient) sendPingLoop() {
   149  	defer w.cancelConnectionCtx()
   150  
   151  	pingTimer := time.NewTicker(config.WebSocketClientConfig().PingInterval)
   152  	defer drainTimer(pingTimer)
   153  	defer pingTimer.Stop()
   154  
   155  	for {
   156  		select {
   157  		case <-w.ConnectionCtx.Done():
   158  			return
   159  		case <-pingTimer.C:
   160  			err := func() error {
   161  				ctxTimeout, cancelFunc := context.WithTimeout(w.ConnectionCtx, config.WebSocketClientConfig().WriteTimeout)
   162  				defer cancelFunc()
   163  
   164  				return w.sendPing(ctxTimeout)
   165  			}()
   166  			if err != nil {
   167  				logger.Logger().WithError(err).Error("Failed writing ping")
   168  				return
   169  			}
   170  		}
   171  	}
   172  }
   173  
   174  func (w *webSocketClient) sendMsg(ctx context.Context, msgType int, data []byte) error {
   175  	w.writeMutex.Lock()
   176  	defer w.writeMutex.Unlock()
   177  
   178  	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
   179  		err := w.conn.SetWriteDeadline(deadline)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  
   185  	if ctx.Err() != nil {
   186  		return ctx.Err()
   187  	}
   188  
   189  	return w.conn.WriteMessage(msgType, data)
   190  }
   191  
   192  func (w *webSocketClient) sendBinary(ctx context.Context, buf []byte) error {
   193  	err := w.sendMsg(ctx, gorilla.BinaryMessage, buf)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	return nil
   198  }
   199  
   200  func (w *webSocketClient) Send(ctx context.Context, buf []byte) error {
   201  	if ctx.Err() != nil {
   202  		return ctx.Err()
   203  	}
   204  
   205  	err := func() error {
   206  		ctxTimeout, cancelFunc := context.WithTimeout(ctx, config.WebSocketClientConfig().WriteTimeout)
   207  		defer cancelFunc()
   208  
   209  		return w.sendBinary(ctxTimeout, buf)
   210  	}()
   211  	if err != nil {
   212  		logger.Logger().WithError(err).Error("Failed writing message")
   213  		return err
   214  	}
   215  	return nil
   216  }
   217  
   218  func (w *webSocketClient) Close() {
   219  	_ = w.conn.Close()
   220  	w.cancelConnectionCtx()
   221  }
   222  
   223  func isHttpResponseBadToken(httpRes *http.Response) bool {
   224  	if httpRes == nil {
   225  		return false
   226  	}
   227  	return httpRes.StatusCode == http.StatusForbidden || httpRes.StatusCode == http.StatusUnauthorized
   228  }
   229  
   230  func isHttpResponseBadRequest(httpRes *http.Response) bool {
   231  	if httpRes == nil {
   232  		return false
   233  	}
   234  	return httpRes.StatusCode == http.StatusBadRequest
   235  }
   236  
   237  func drainTimer(timer *time.Ticker) {
   238  	select {
   239  	case <-timer.C:
   240  	default:
   241  	}
   242  }
   243  
   244  func (w *webSocketClient) getWSDialer() *gorilla.Dialer {
   245  	dialerOnce.Do(func() {
   246  		dialerTemp := *gorilla.DefaultDialer
   247  		netDialer := net.Dialer{Resolver: &net.Resolver{PreferGo: true}}
   248  		dialerTemp.NetDial = netDialer.Dial
   249  		dialer = &dialerTemp
   250  		dialerTemp.TLSClientConfig = &tls.Config{InsecureSkipVerify: config.WebSocketClientConfig().SkipSSLVerify}
   251  	})
   252  
   253  	if w.proxy != nil {
   254  		dialer.Proxy = func(_ *http.Request) (*url.URL, error) {
   255  			return w.proxy, nil
   256  		}
   257  		logger.Logger().Infof("Using proxy: %s", w.proxy.String())
   258  	}
   259  	return dialer
   260  }