go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket_client.go (about) 1 package onet 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "math/rand" 8 "net" 9 "net/http" 10 "net/url" 11 "reflect" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/gorilla/websocket" 18 "go.dedis.ch/onet/v3/log" 19 "go.dedis.ch/onet/v3/network" 20 "go.dedis.ch/protobuf" 21 "golang.org/x/xerrors" 22 ) 23 24 // Client is a struct used to communicate with a remote Service running on a 25 // onet.Server. Using Send it can connect to multiple remote Servers. 26 type Client struct { 27 service string 28 connections map[destination]*websocket.Conn 29 connectionsLock map[destination]*sync.Mutex 30 suite network.Suite 31 // if not nil, use TLS 32 TLSClientConfig *tls.Config 33 // whether to keep the connection 34 keep bool 35 rx uint64 36 tx uint64 37 // How long to wait for a reply 38 ReadTimeout time.Duration 39 // How long to wait to open a connection 40 HandshakeTimeout time.Duration 41 sync.Mutex 42 } 43 44 // NewClient returns a client using the service s. On the first Send, the 45 // connection will be started, until Close is called. 46 func NewClient(suite network.Suite, s string) *Client { 47 return &Client{ 48 service: s, 49 connections: make(map[destination]*websocket.Conn), 50 connectionsLock: make(map[destination]*sync.Mutex), 51 suite: suite, 52 ReadTimeout: time.Second * 60, 53 HandshakeTimeout: time.Second * 5, 54 } 55 } 56 57 // NewClientKeep returns a Client that doesn't close the connection between 58 // two messages if it's the same server. 59 func NewClientKeep(suite network.Suite, s string) *Client { 60 cl := NewClient(suite, s) 61 cl.keep = true 62 return cl 63 } 64 65 // Suite returns the cryptographic suite in use on this connection. 66 func (c *Client) Suite() network.Suite { 67 return c.suite 68 } 69 70 func (c *Client) closeSingleUseConn(dst *network.ServerIdentity, path string) { 71 dest := destination{dst, path} 72 if !c.keep { 73 if err := c.closeConn(dest); err != nil { 74 log.Errorf("error while closing the connection to %v : %+v\n", 75 dest, err) 76 } 77 } 78 } 79 80 func (c *Client) newConnIfNotExist(dst *network.ServerIdentity, path string) (*websocket.Conn, *sync.Mutex, error) { 81 var err error 82 83 // c.Lock protects the connections and connectionsLock map 84 // c.connectionsLock is held as long as the connection is in use - to avoid that two 85 // processes send data over the same websocket concurrently. 86 dest := destination{dst, path} 87 c.Lock() 88 connLock, exists := c.connectionsLock[dest] 89 if !exists { 90 c.connectionsLock[dest] = &sync.Mutex{} 91 connLock = c.connectionsLock[dest] 92 } 93 c.Unlock() 94 // if connLock.Lock is done while the c.Lock is still held, the next process trying to 95 // use the same connection will deadlock, as it'll wait for connLock to be released, 96 // while the other process will wait for c.Unlock to be released. 97 connLock.Lock() 98 c.Lock() 99 conn, connected := c.connections[dest] 100 c.Unlock() 101 102 if !connected { 103 d := &websocket.Dialer{} 104 d.TLSClientConfig = c.TLSClientConfig 105 106 var serverURL string 107 var header http.Header 108 109 // If the URL is in the dst, then use it. 110 if dst.URL != "" { 111 u, err := url.Parse(dst.URL) 112 if err != nil { 113 connLock.Unlock() 114 return nil, nil, xerrors.Errorf("parsing url: %v", err) 115 } 116 if u.Scheme == "https" { 117 u.Scheme = "wss" 118 } else { 119 u.Scheme = "ws" 120 } 121 if !strings.HasSuffix(u.Path, "/") { 122 u.Path += "/" 123 } 124 u.Path += c.service + "/" + path 125 serverURL = u.String() 126 header = http.Header{"Origin": []string{dst.URL}} 127 } else { 128 // Open connection to service. 129 hp, err := getWSHostPort(dst, false) 130 if err != nil { 131 connLock.Unlock() 132 return nil, nil, xerrors.Errorf("parsing port: %v", err) 133 } 134 135 var wsProtocol string 136 var protocol string 137 138 // The old hacky way of deciding if this server has HTTPS or not: 139 // the client somehow magically knows and tells onet by setting 140 // c.TLSClientConfig to a non-nil value. 141 if c.TLSClientConfig != nil { 142 wsProtocol = "wss" 143 protocol = "https" 144 } else { 145 wsProtocol = "ws" 146 protocol = "http" 147 } 148 serverURL = fmt.Sprintf("%s://%s/%s/%s", wsProtocol, hp, c.service, path) 149 header = http.Header{"Origin": []string{protocol + "://" + hp}} 150 } 151 152 // Re-try to connect in case the websocket is just about to start 153 d.HandshakeTimeout = c.HandshakeTimeout 154 for a := 0; a < network.MaxRetryConnect; a++ { 155 conn, _, err = d.Dial(serverURL, header) 156 if err == nil { 157 break 158 } 159 time.Sleep(network.WaitRetry) 160 } 161 if err != nil { 162 connLock.Unlock() 163 return nil, nil, xerrors.Errorf("dial: %v", err) 164 } 165 c.Lock() 166 c.connections[dest] = conn 167 c.Unlock() 168 } 169 return conn, connLock, nil 170 } 171 172 // Send will marshal the message into a ClientRequest message and send it. It has a 173 // very simple parallel sending mechanism included: if the send goes to a new or an 174 // idle connection, the message is sent right away. If the current connection is busy, 175 // it waits for it to be free. 176 func (c *Client) Send(dst *network.ServerIdentity, path string, buf []byte) ([]byte, error) { 177 conn, connLock, err := c.newConnIfNotExist(dst, path) 178 if err != nil { 179 return nil, xerrors.Errorf("new connection: %v", err) 180 } 181 defer connLock.Unlock() 182 183 var rcv []byte 184 defer func() { 185 c.Lock() 186 c.closeSingleUseConn(dst, path) 187 c.rx += uint64(len(rcv)) 188 c.tx += uint64(len(buf)) 189 c.Unlock() 190 }() 191 192 log.Lvlf4("Sending %x to %s/%s", buf, c.service, path) 193 if err := conn.WriteMessage(websocket.BinaryMessage, buf); err != nil { 194 return nil, xerrors.Errorf("connection write: %v", err) 195 } 196 197 if err := conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil { 198 return nil, xerrors.Errorf("read deadline: %v", err) 199 } 200 _, rcv, err = conn.ReadMessage() 201 if err != nil { 202 return nil, xerrors.Errorf("connection read: %v", err) 203 } 204 return rcv, nil 205 } 206 207 // SendProtobuf wraps protobuf.(En|De)code over the Client.Send-function. It 208 // takes the destination, a pointer to a msg-structure that will be 209 // protobuf-encoded and sent over the websocket. If ret is non-nil, it 210 // has to be a pointer to the struct that is sent back to the 211 // client. If there is no error, the ret-structure is filled with the 212 // data from the service. 213 func (c *Client) SendProtobuf(dst *network.ServerIdentity, msg interface{}, ret interface{}) error { 214 buf, err := protobuf.Encode(msg) 215 if err != nil { 216 return xerrors.Errorf("encoding: %v", err) 217 } 218 path := strings.Split(reflect.TypeOf(msg).String(), ".")[1] 219 reply, err := c.Send(dst, path, buf) 220 if err != nil { 221 return xerrors.Errorf("sending: %v", err) 222 } 223 if ret != nil { 224 err := protobuf.DecodeWithConstructors(reply, ret, network.DefaultConstructors(c.suite)) 225 if err != nil { 226 return xerrors.Errorf("decoding: %v", err) 227 } 228 } 229 return nil 230 } 231 232 // ParallelOptions defines how SendProtobufParallel behaves. Each field has a default 233 // value that will be used if 'nil' is passed to SendProtobufParallel. For integers, 234 // the default will also be used if the integer = 0. 235 type ParallelOptions struct { 236 // Parallel indicates how many requests are sent in parallel. 237 // Default: half of all nodes in the roster 238 Parallel int 239 // AskNodes indicates how many requests are sent in total. 240 // Default: all nodes in the roster, except if StartNodes is set > 0 241 AskNodes int 242 // StartNode indicates where to start in the roster. If StartNode is > 0 and < len(roster), 243 // but AskNodes is 0, then AskNodes will be set to len(Roster)-StartNode. 244 // Default: 0 245 StartNode int 246 // QuitError - if true, the first error received will be returned. 247 // Default: false 248 QuitError bool 249 // IgnoreNodes is a set of nodes that will not be contacted. They are counted towards 250 // AskNodes and StartNode, but not contacted. 251 // Default: false 252 IgnoreNodes []*network.ServerIdentity 253 // DontShuffle - if true, the nodes will be contacted in the same order as given in the Roster. 254 // StartNode will be applied before shuffling. 255 // Default: false 256 DontShuffle bool 257 } 258 259 // GetList returns how many requests to start in parallel and a channel of nodes to be used. 260 // If po == nil, it uses default values. 261 func (po *ParallelOptions) GetList(nodes []*network.ServerIdentity) (parallel int, nodesChan chan *network.ServerIdentity) { 262 // Default values 263 parallel = (len(nodes) + 1) / 2 264 askNodes := len(nodes) 265 startNode := 0 266 var ignoreNodes []*network.ServerIdentity 267 var perm []int 268 if po != nil { 269 if po.Parallel > 0 && po.Parallel < parallel { 270 parallel = po.Parallel 271 } 272 if po.StartNode > 0 && po.StartNode < len(nodes) { 273 startNode = po.StartNode 274 askNodes -= startNode 275 } 276 if po.AskNodes > 0 && po.AskNodes < len(nodes) { 277 askNodes = po.AskNodes 278 } 279 if askNodes < parallel { 280 parallel = askNodes 281 } 282 if po.DontShuffle { 283 for i := range nodes { 284 perm = append(perm, i) 285 } 286 } 287 ignoreNodes = po.IgnoreNodes 288 } 289 if len(perm) == 0 { 290 perm = rand.Perm(len(nodes)) 291 } 292 293 nodesChan = make(chan *network.ServerIdentity, askNodes) 294 for i := range nodes { 295 addNode := true 296 node := nodes[(startNode+perm[i])%len(nodes)] 297 for _, ignore := range ignoreNodes { 298 if node.Equal(ignore) { 299 addNode = false 300 break 301 } 302 } 303 if addNode { 304 nodesChan <- node 305 } 306 if len(nodesChan) == askNodes { 307 break 308 } 309 } 310 return parallel, nodesChan 311 } 312 313 // Quit return false if po == nil, or the value in po.QuitError. 314 func (po *ParallelOptions) Quit() bool { 315 if po == nil { 316 return false 317 } 318 return po.QuitError 319 } 320 321 // Decoder is a function that takes the data and the interface to fill in 322 // as input and decodes the message. 323 type Decoder func(data []byte, ret interface{}) error 324 325 // SendProtobufParallelWithDecoder sends the msg to a set of nodes in parallel and returns the first successful 326 // answer. If all nodes return an error, only the first error is returned. 327 // The behaviour of this method can be changed using the ParallelOptions argument. It is kept 328 // as a structure for future enhancements. If opt is nil, then standard values will be taken. 329 func (c *Client) SendProtobufParallelWithDecoder(nodes []*network.ServerIdentity, msg interface{}, ret interface{}, 330 opt *ParallelOptions, decoder Decoder) (*network.ServerIdentity, error) { 331 buf, err := protobuf.Encode(msg) 332 if err != nil { 333 return nil, xerrors.Errorf("decoding: %v", err) 334 } 335 path := strings.Split(reflect.TypeOf(msg).String(), ".")[1] 336 337 parallel, nodesChan := opt.GetList(nodes) 338 nodesNbr := len(nodesChan) 339 errChan := make(chan error, nodesNbr) 340 decodedChan := make(chan *network.ServerIdentity, 1) 341 var decoding sync.Mutex 342 done := make(chan bool) 343 344 contactNode := func() bool { 345 select { 346 case <-done: 347 return false 348 default: 349 select { 350 case node := <-nodesChan: 351 log.Lvlf3("Asking %T from: %v - %v", msg, node.Address, node.URL) 352 reply, err := c.Send(node, path, buf) 353 if err != nil { 354 log.Lvl2("Error while sending to node:", node, err) 355 errChan <- err 356 } else { 357 log.Lvl3("Done asking node", node, len(reply)) 358 decoding.Lock() 359 select { 360 case <-done: 361 default: 362 if ret != nil { 363 err := decoder(reply, ret) 364 if err != nil { 365 errChan <- err 366 break 367 } 368 } 369 decodedChan <- node 370 close(done) 371 } 372 decoding.Unlock() 373 } 374 default: 375 return false 376 } 377 } 378 return true 379 } 380 381 // Producer that puts messages in errChan and replyChan 382 for g := 0; g < parallel; g++ { 383 go func() { 384 for { 385 if !contactNode() { 386 return 387 } 388 } 389 }() 390 } 391 392 var errs []error 393 for len(errs) < nodesNbr { 394 select { 395 case node := <-decodedChan: 396 return node, nil 397 case err := <-errChan: 398 if opt.Quit() { 399 close(done) 400 return nil, err 401 } 402 errs = append(errs, xerrors.Errorf("sending: %v", err)) 403 } 404 } 405 406 return nil, errs[0] 407 } 408 409 // SendProtobufParallel sends the msg to a set of nodes in parallel and returns the first successful 410 // answer. If all nodes return an error, only the first error is returned. 411 // The behaviour of this method can be changed using the ParallelOptions argument. It is kept 412 // as a structure for future enhancements. If opt is nil, then standard values will be taken. 413 func (c *Client) SendProtobufParallel(nodes []*network.ServerIdentity, msg interface{}, ret interface{}, 414 opt *ParallelOptions) (*network.ServerIdentity, error) { 415 si, err := c.SendProtobufParallelWithDecoder(nodes, msg, ret, opt, protobuf.Decode) 416 if err != nil { 417 return nil, xerrors.Errorf("sending: %v", err) 418 } 419 return si, nil 420 } 421 422 // StreamingConn allows clients to read from it without sending additional 423 // requests. 424 type StreamingConn struct { 425 conn *websocket.Conn 426 suite network.Suite 427 } 428 429 // StreamingReadOpts contains options for the ReadMessageWithOpts. It allows us 430 // to add new options in the future without making breaking changes. 431 type StreamingReadOpts struct { 432 Deadline time.Time 433 } 434 435 // ReadMessage read more data from the connection, it will block if there are 436 // no messages. 437 func (c *StreamingConn) ReadMessage(ret interface{}) error { 438 opts := StreamingReadOpts{ 439 Deadline: time.Now().Add(5 * time.Minute), 440 } 441 442 return c.readMsg(ret, opts) 443 } 444 445 // ReadMessageWithOpts does the same as ReadMessage and allows to pass options. 446 func (c *StreamingConn) ReadMessageWithOpts(ret interface{}, opts StreamingReadOpts) error { 447 return c.readMsg(ret, opts) 448 } 449 450 func (c *StreamingConn) readMsg(ret interface{}, opts StreamingReadOpts) error { 451 if err := c.conn.SetReadDeadline(opts.Deadline); err != nil { 452 return xerrors.Errorf("read deadline: %v", err) 453 } 454 // No need to add bytes to counter here because this function is only 455 // called by the client. 456 _, buf, err := c.conn.ReadMessage() 457 if err != nil { 458 return xerrors.Errorf("connection read: %w", err) 459 } 460 err = protobuf.DecodeWithConstructors(buf, ret, network.DefaultConstructors(c.suite)) 461 if err != nil { 462 return xerrors.Errorf("decoding: %v", err) 463 } 464 return nil 465 } 466 467 // Ping sends a ping message. Data can be nil. 468 func (c *StreamingConn) Ping(data []byte, deadline time.Time) error { 469 return c.conn.WriteControl(websocket.PingMessage, data, deadline) 470 } 471 472 // Stream will send a request to start streaming, it returns a connection where 473 // the client can continue to read values from it. 474 func (c *Client) Stream(dst *network.ServerIdentity, msg interface{}) (StreamingConn, error) { 475 buf, err := protobuf.Encode(msg) 476 if err != nil { 477 return StreamingConn{}, err 478 } 479 path := strings.Split(reflect.TypeOf(msg).String(), ".")[1] 480 481 conn, connLock, err := c.newConnIfNotExist(dst, path) 482 if err != nil { 483 return StreamingConn{}, err 484 } 485 defer connLock.Unlock() 486 err = conn.WriteMessage(websocket.BinaryMessage, buf) 487 if err != nil { 488 return StreamingConn{}, err 489 } 490 c.Lock() 491 c.tx += uint64(len(buf)) 492 c.Unlock() 493 return StreamingConn{conn, c.Suite()}, nil 494 } 495 496 // SendToAll sends a message to all ServerIdentities of the Roster and returns 497 // all errors encountered concatenated together as a string. 498 func (c *Client) SendToAll(dst *Roster, path string, buf []byte) ([][]byte, error) { 499 msgs := make([][]byte, len(dst.List)) 500 var errstrs []string 501 for i, e := range dst.List { 502 var err error 503 msgs[i], err = c.Send(e, path, buf) 504 if err != nil { 505 errstrs = append(errstrs, fmt.Sprint(e.String(), err.Error())) 506 } 507 } 508 var err error 509 if len(errstrs) > 0 { 510 err = xerrors.New(strings.Join(errstrs, "\n")) 511 } 512 return msgs, err 513 } 514 515 // Close sends a close-command to all open connections and returns nil if no 516 // errors occurred or all errors encountered concatenated together as a string. 517 func (c *Client) Close() error { 518 c.Lock() 519 defer c.Unlock() 520 var errstrs []string 521 for dest := range c.connections { 522 connLock := c.connectionsLock[dest] 523 c.Unlock() 524 connLock.Lock() 525 c.Lock() 526 if err := c.closeConn(dest); err != nil { 527 errstrs = append(errstrs, err.Error()) 528 } 529 connLock.Unlock() 530 } 531 var err error 532 if len(errstrs) > 0 { 533 err = xerrors.New(strings.Join(errstrs, "\n")) 534 } 535 return err 536 } 537 538 // closeConn sends a close-command to the connection. Correct locking must be done 539 // befor calling this method. 540 func (c *Client) closeConn(dst destination) error { 541 conn, ok := c.connections[dst] 542 if ok { 543 delete(c.connections, dst) 544 err := conn.WriteMessage(websocket.CloseMessage, 545 websocket.FormatCloseMessage(websocket.CloseNormalClosure, "client closed")) 546 if err != nil { 547 log.Error("Error while sending closing type:", err) 548 } 549 return conn.Close() 550 } 551 return nil 552 } 553 554 // Tx returns the number of bytes transmitted by this Client. It implements 555 // the monitor.CounterIOMeasure interface. 556 func (c *Client) Tx() uint64 { 557 c.Lock() 558 defer c.Unlock() 559 return c.tx 560 } 561 562 // Rx returns the number of bytes read by this Client. It implements 563 // the monitor.CounterIOMeasure interface. 564 func (c *Client) Rx() uint64 { 565 c.Lock() 566 defer c.Unlock() 567 return c.rx 568 } 569 570 // schemeToPort returns the port corresponding to the given scheme, much like netdb. 571 func schemeToPort(name string) (uint16, error) { 572 switch name { 573 case "http": 574 return 80, nil 575 case "https": 576 return 443, nil 577 default: 578 return 0, fmt.Errorf("no such scheme: %v", name) 579 } 580 } 581 582 // getWSHostPort returns the hostname:port to bind to with WebSocket. 583 // If global is true, the hostname is set to the unspecified 0.0.0.0-address. 584 // If si.URL is "", the url uses the hostname and port+1 of si.Address. 585 func getWSHostPort(si *network.ServerIdentity, global bool) (string, error) { 586 const portBitSize = 16 587 const portNumericBase = 10 588 589 var hostname string 590 var port uint16 591 592 if si.URL != "" { 593 url, err := url.Parse(si.URL) 594 if err != nil { 595 return "", fmt.Errorf("unable to parse URL: %v", err) 596 } 597 if !url.IsAbs() { 598 return "", errors.New("URL is not absolute") 599 } 600 601 protocolPort, err := schemeToPort(url.Scheme) 602 if err != nil { 603 return "", fmt.Errorf("unable to translate URL' scheme to port: %v", err) 604 } 605 606 portStr := url.Port() 607 if portStr == "" { 608 port = protocolPort 609 } else { 610 portRaw, err := strconv.ParseUint(portStr, portNumericBase, portBitSize) 611 if err != nil { 612 return "", fmt.Errorf("URL doesn't contain a valid port: %v", err) 613 } 614 port = uint16(portRaw) 615 } 616 hostname = url.Hostname() 617 } else { 618 portRaw, err := strconv.ParseUint(si.Address.Port(), portNumericBase, portBitSize) 619 if err != nil { 620 return "", fmt.Errorf("unable to parse port of Address as int: %v", err) 621 } 622 port = uint16(portRaw + 1) 623 hostname = si.Address.Host() 624 } 625 626 if global { 627 hostname = "0.0.0.0" 628 } 629 630 portFormatted := strconv.FormatUint(uint64(port), 10) 631 return net.JoinHostPort(hostname, portFormatted), nil 632 }