github.com/Rookout/GoSDK@v0.1.48/pkg/com_ws/websocket_client.go (about) 1 package com_ws 2 3 import ( 4 "context" 5 "crypto/tls" 6 "net" 7 "net/http" 8 "net/url" 9 "sync" 10 "time" 11 12 "github.com/Rookout/GoSDK/pkg/common" 13 "github.com/Rookout/GoSDK/pkg/config" 14 "github.com/Rookout/GoSDK/pkg/logger" 15 pb "github.com/Rookout/GoSDK/pkg/protobuf" 16 "github.com/Rookout/GoSDK/pkg/rookoutErrors" 17 "github.com/Rookout/GoSDK/pkg/utils" 18 "github.com/go-errors/errors" 19 gorilla "github.com/gorilla/websocket" 20 ) 21 22 var dialer *gorilla.Dialer 23 var dialerOnce sync.Once 24 25 type WebSocketClientCreator func(context.Context, *url.URL, string, *url.URL, *pb.AgentInformation) WebSocketClient 26 27 type WebSocketClient interface { 28 GetConnectionCtx() context.Context 29 Dial(context.Context) error 30 Handshake(context.Context) error 31 Receive(context.Context) ([]byte, error) 32 Send(context.Context, []byte) error 33 Close() 34 } 35 36 type webSocketClient struct { 37 agentURL *url.URL 38 agentInfo *pb.AgentInformation 39 conn *gorilla.Conn 40 token string 41 proxy *url.URL 42 ConnectionCtx context.Context 43 cancelConnectionCtx context.CancelFunc 44 writeMutex sync.Mutex 45 } 46 47 func NewWebSocketClient(ctx context.Context, agentURL *url.URL, token string, proxy *url.URL, agentInfo *pb.AgentInformation) WebSocketClient { 48 client := &webSocketClient{ 49 agentURL: agentURL, 50 agentInfo: agentInfo, 51 conn: &gorilla.Conn{}, 52 token: token, 53 proxy: proxy, 54 } 55 client.ConnectionCtx, client.cancelConnectionCtx = context.WithCancel(ctx) 56 return client 57 } 58 59 func (w *webSocketClient) GetConnectionCtx() context.Context { 60 return w.ConnectionCtx 61 } 62 63 func (w *webSocketClient) Dial(ctx context.Context) error { 64 conn, httpRes, err := w.getWSDialer().DialContext(ctx, w.agentURL.String(), http.Header{"X-Rookout-Token": []string{w.token}}) 65 if err != nil { 66 badToken := isHttpResponseBadToken(httpRes) 67 if badToken { 68 censoredToken := "" 69 if len(w.token) > 5 { 70 censoredToken = w.token[:5] 71 } 72 73 logger.Logger().Errorf("The Rookout token supplied (%s) is not valid; please check the token and try again", censoredToken) 74 return rookoutErrors.NewInvalidTokenError() 75 } else if isHttpResponseBadRequest(httpRes) { 76 return rookoutErrors.NewWebSocketError() 77 } else { 78 logger.Logger().Errorf("Failed to connect to controller (%s). err: %s", w.agentURL, err.Error()) 79 } 80 return err 81 } 82 w.conn = conn 83 84 pingTimeout := config.WebSocketClientConfig().PingTimeout 85 if err = w.conn.SetReadDeadline(time.Now().Add(pingTimeout)); err != nil { 86 logger.Logger().WithError(err).Error("failed to set read deadline, closing connection") 87 w.Close() 88 return err 89 } 90 utils.CreateGoroutine(func() { 91 w.sendPingLoop() 92 }) 93 w.conn.SetPongHandler(func(string) error { 94 err := w.conn.SetReadDeadline(time.Now().Add(pingTimeout)) 95 if err != nil { 96 logger.Logger().WithError(err).Error("Failed to set read deadline on pong, closing connection") 97 w.Close() 98 } 99 100 return nil 101 }) 102 103 return nil 104 } 105 106 func (w *webSocketClient) Handshake(ctx context.Context) error { 107 buf, err := common.WrapMsgInEnvelope(&pb.NewAgentMessage{AgentInfo: w.agentInfo}) 108 if err != nil { 109 return err 110 } 111 112 err = w.Send(ctx, buf) 113 if err != nil { 114 return err 115 } 116 117 return nil 118 } 119 120 func (w *webSocketClient) Receive(ctx context.Context) ([]byte, error) { 121 122 if deadline, ok := ctx.Deadline(); ok { 123 err := w.conn.SetReadDeadline(deadline) 124 if err != nil { 125 return nil, err 126 } 127 } 128 messageType, buf, err := w.conn.ReadMessage() 129 if err != nil { 130 return nil, err 131 } 132 133 if messageType != gorilla.BinaryMessage { 134 return nil, errors.Errorf("unexpected message type, got %d\n", messageType) 135 } 136 137 return buf, nil 138 } 139 140 func (w *webSocketClient) sendPing(ctx context.Context) error { 141 err := w.sendMsg(ctx, gorilla.PingMessage, nil) 142 if err != nil { 143 return err 144 } 145 return nil 146 } 147 148 func (w *webSocketClient) sendPingLoop() { 149 defer w.cancelConnectionCtx() 150 151 pingTimer := time.NewTicker(config.WebSocketClientConfig().PingInterval) 152 defer drainTimer(pingTimer) 153 defer pingTimer.Stop() 154 155 for { 156 select { 157 case <-w.ConnectionCtx.Done(): 158 return 159 case <-pingTimer.C: 160 err := func() error { 161 ctxTimeout, cancelFunc := context.WithTimeout(w.ConnectionCtx, config.WebSocketClientConfig().WriteTimeout) 162 defer cancelFunc() 163 164 return w.sendPing(ctxTimeout) 165 }() 166 if err != nil { 167 logger.Logger().WithError(err).Error("Failed writing ping") 168 return 169 } 170 } 171 } 172 } 173 174 func (w *webSocketClient) sendMsg(ctx context.Context, msgType int, data []byte) error { 175 w.writeMutex.Lock() 176 defer w.writeMutex.Unlock() 177 178 if deadline, hasDeadline := ctx.Deadline(); hasDeadline { 179 err := w.conn.SetWriteDeadline(deadline) 180 if err != nil { 181 return err 182 } 183 } 184 185 if ctx.Err() != nil { 186 return ctx.Err() 187 } 188 189 return w.conn.WriteMessage(msgType, data) 190 } 191 192 func (w *webSocketClient) sendBinary(ctx context.Context, buf []byte) error { 193 err := w.sendMsg(ctx, gorilla.BinaryMessage, buf) 194 if err != nil { 195 return err 196 } 197 return nil 198 } 199 200 func (w *webSocketClient) Send(ctx context.Context, buf []byte) error { 201 if ctx.Err() != nil { 202 return ctx.Err() 203 } 204 205 err := func() error { 206 ctxTimeout, cancelFunc := context.WithTimeout(ctx, config.WebSocketClientConfig().WriteTimeout) 207 defer cancelFunc() 208 209 return w.sendBinary(ctxTimeout, buf) 210 }() 211 if err != nil { 212 logger.Logger().WithError(err).Error("Failed writing message") 213 return err 214 } 215 return nil 216 } 217 218 func (w *webSocketClient) Close() { 219 _ = w.conn.Close() 220 w.cancelConnectionCtx() 221 } 222 223 func isHttpResponseBadToken(httpRes *http.Response) bool { 224 if httpRes == nil { 225 return false 226 } 227 return httpRes.StatusCode == http.StatusForbidden || httpRes.StatusCode == http.StatusUnauthorized 228 } 229 230 func isHttpResponseBadRequest(httpRes *http.Response) bool { 231 if httpRes == nil { 232 return false 233 } 234 return httpRes.StatusCode == http.StatusBadRequest 235 } 236 237 func drainTimer(timer *time.Ticker) { 238 select { 239 case <-timer.C: 240 default: 241 } 242 } 243 244 func (w *webSocketClient) getWSDialer() *gorilla.Dialer { 245 dialerOnce.Do(func() { 246 dialerTemp := *gorilla.DefaultDialer 247 netDialer := net.Dialer{Resolver: &net.Resolver{PreferGo: true}} 248 dialerTemp.NetDial = netDialer.Dial 249 dialer = &dialerTemp 250 dialerTemp.TLSClientConfig = &tls.Config{InsecureSkipVerify: config.WebSocketClientConfig().SkipSSLVerify} 251 }) 252 253 if w.proxy != nil { 254 dialer.Proxy = func(_ *http.Request) (*url.URL, error) { 255 return w.proxy, nil 256 } 257 logger.Logger().Infof("Using proxy: %s", w.proxy.String()) 258 } 259 return dialer 260 }