go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket_client.go (about)

     1  package onet
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/gorilla/websocket"
    18  	"go.dedis.ch/onet/v3/log"
    19  	"go.dedis.ch/onet/v3/network"
    20  	"go.dedis.ch/protobuf"
    21  	"golang.org/x/xerrors"
    22  )
    23  
    24  // Client is a struct used to communicate with a remote Service running on a
    25  // onet.Server. Using Send it can connect to multiple remote Servers.
    26  type Client struct {
    27  	service         string
    28  	connections     map[destination]*websocket.Conn
    29  	connectionsLock map[destination]*sync.Mutex
    30  	suite           network.Suite
    31  	// if not nil, use TLS
    32  	TLSClientConfig *tls.Config
    33  	// whether to keep the connection
    34  	keep bool
    35  	rx   uint64
    36  	tx   uint64
    37  	// How long to wait for a reply
    38  	ReadTimeout time.Duration
    39  	// How long to wait to open a connection
    40  	HandshakeTimeout time.Duration
    41  	sync.Mutex
    42  }
    43  
    44  // NewClient returns a client using the service s. On the first Send, the
    45  // connection will be started, until Close is called.
    46  func NewClient(suite network.Suite, s string) *Client {
    47  	return &Client{
    48  		service:          s,
    49  		connections:      make(map[destination]*websocket.Conn),
    50  		connectionsLock:  make(map[destination]*sync.Mutex),
    51  		suite:            suite,
    52  		ReadTimeout:      time.Second * 60,
    53  		HandshakeTimeout: time.Second * 5,
    54  	}
    55  }
    56  
    57  // NewClientKeep returns a Client that doesn't close the connection between
    58  // two messages if it's the same server.
    59  func NewClientKeep(suite network.Suite, s string) *Client {
    60  	cl := NewClient(suite, s)
    61  	cl.keep = true
    62  	return cl
    63  }
    64  
    65  // Suite returns the cryptographic suite in use on this connection.
    66  func (c *Client) Suite() network.Suite {
    67  	return c.suite
    68  }
    69  
    70  func (c *Client) closeSingleUseConn(dst *network.ServerIdentity, path string) {
    71  	dest := destination{dst, path}
    72  	if !c.keep {
    73  		if err := c.closeConn(dest); err != nil {
    74  			log.Errorf("error while closing the connection to %v : %+v\n",
    75  				dest, err)
    76  		}
    77  	}
    78  }
    79  
    80  func (c *Client) newConnIfNotExist(dst *network.ServerIdentity, path string) (*websocket.Conn, *sync.Mutex, error) {
    81  	var err error
    82  
    83  	// c.Lock protects the connections and connectionsLock map
    84  	// c.connectionsLock is held as long as the connection is in use - to avoid that two
    85  	// processes send data over the same websocket concurrently.
    86  	dest := destination{dst, path}
    87  	c.Lock()
    88  	connLock, exists := c.connectionsLock[dest]
    89  	if !exists {
    90  		c.connectionsLock[dest] = &sync.Mutex{}
    91  		connLock = c.connectionsLock[dest]
    92  	}
    93  	c.Unlock()
    94  	// if connLock.Lock is done while the c.Lock is still held, the next process trying to
    95  	// use the same connection will deadlock, as it'll wait for connLock to be released,
    96  	// while the other process will wait for c.Unlock to be released.
    97  	connLock.Lock()
    98  	c.Lock()
    99  	conn, connected := c.connections[dest]
   100  	c.Unlock()
   101  
   102  	if !connected {
   103  		d := &websocket.Dialer{}
   104  		d.TLSClientConfig = c.TLSClientConfig
   105  
   106  		var serverURL string
   107  		var header http.Header
   108  
   109  		// If the URL is in the dst, then use it.
   110  		if dst.URL != "" {
   111  			u, err := url.Parse(dst.URL)
   112  			if err != nil {
   113  				connLock.Unlock()
   114  				return nil, nil, xerrors.Errorf("parsing url: %v", err)
   115  			}
   116  			if u.Scheme == "https" {
   117  				u.Scheme = "wss"
   118  			} else {
   119  				u.Scheme = "ws"
   120  			}
   121  			if !strings.HasSuffix(u.Path, "/") {
   122  				u.Path += "/"
   123  			}
   124  			u.Path += c.service + "/" + path
   125  			serverURL = u.String()
   126  			header = http.Header{"Origin": []string{dst.URL}}
   127  		} else {
   128  			// Open connection to service.
   129  			hp, err := getWSHostPort(dst, false)
   130  			if err != nil {
   131  				connLock.Unlock()
   132  				return nil, nil, xerrors.Errorf("parsing port: %v", err)
   133  			}
   134  
   135  			var wsProtocol string
   136  			var protocol string
   137  
   138  			// The old hacky way of deciding if this server has HTTPS or not:
   139  			// the client somehow magically knows and tells onet by setting
   140  			// c.TLSClientConfig to a non-nil value.
   141  			if c.TLSClientConfig != nil {
   142  				wsProtocol = "wss"
   143  				protocol = "https"
   144  			} else {
   145  				wsProtocol = "ws"
   146  				protocol = "http"
   147  			}
   148  			serverURL = fmt.Sprintf("%s://%s/%s/%s", wsProtocol, hp, c.service, path)
   149  			header = http.Header{"Origin": []string{protocol + "://" + hp}}
   150  		}
   151  
   152  		// Re-try to connect in case the websocket is just about to start
   153  		d.HandshakeTimeout = c.HandshakeTimeout
   154  		for a := 0; a < network.MaxRetryConnect; a++ {
   155  			conn, _, err = d.Dial(serverURL, header)
   156  			if err == nil {
   157  				break
   158  			}
   159  			time.Sleep(network.WaitRetry)
   160  		}
   161  		if err != nil {
   162  			connLock.Unlock()
   163  			return nil, nil, xerrors.Errorf("dial: %v", err)
   164  		}
   165  		c.Lock()
   166  		c.connections[dest] = conn
   167  		c.Unlock()
   168  	}
   169  	return conn, connLock, nil
   170  }
   171  
   172  // Send will marshal the message into a ClientRequest message and send it. It has a
   173  // very simple parallel sending mechanism included: if the send goes to a new or an
   174  // idle connection, the message is sent right away. If the current connection is busy,
   175  // it waits for it to be free.
   176  func (c *Client) Send(dst *network.ServerIdentity, path string, buf []byte) ([]byte, error) {
   177  	conn, connLock, err := c.newConnIfNotExist(dst, path)
   178  	if err != nil {
   179  		return nil, xerrors.Errorf("new connection: %v", err)
   180  	}
   181  	defer connLock.Unlock()
   182  
   183  	var rcv []byte
   184  	defer func() {
   185  		c.Lock()
   186  		c.closeSingleUseConn(dst, path)
   187  		c.rx += uint64(len(rcv))
   188  		c.tx += uint64(len(buf))
   189  		c.Unlock()
   190  	}()
   191  
   192  	log.Lvlf4("Sending %x to %s/%s", buf, c.service, path)
   193  	if err := conn.WriteMessage(websocket.BinaryMessage, buf); err != nil {
   194  		return nil, xerrors.Errorf("connection write: %v", err)
   195  	}
   196  
   197  	if err := conn.SetReadDeadline(time.Now().Add(c.ReadTimeout)); err != nil {
   198  		return nil, xerrors.Errorf("read deadline: %v", err)
   199  	}
   200  	_, rcv, err = conn.ReadMessage()
   201  	if err != nil {
   202  		return nil, xerrors.Errorf("connection read: %v", err)
   203  	}
   204  	return rcv, nil
   205  }
   206  
   207  // SendProtobuf wraps protobuf.(En|De)code over the Client.Send-function. It
   208  // takes the destination, a pointer to a msg-structure that will be
   209  // protobuf-encoded and sent over the websocket. If ret is non-nil, it
   210  // has to be a pointer to the struct that is sent back to the
   211  // client. If there is no error, the ret-structure is filled with the
   212  // data from the service.
   213  func (c *Client) SendProtobuf(dst *network.ServerIdentity, msg interface{}, ret interface{}) error {
   214  	buf, err := protobuf.Encode(msg)
   215  	if err != nil {
   216  		return xerrors.Errorf("encoding: %v", err)
   217  	}
   218  	path := strings.Split(reflect.TypeOf(msg).String(), ".")[1]
   219  	reply, err := c.Send(dst, path, buf)
   220  	if err != nil {
   221  		return xerrors.Errorf("sending: %v", err)
   222  	}
   223  	if ret != nil {
   224  		err := protobuf.DecodeWithConstructors(reply, ret, network.DefaultConstructors(c.suite))
   225  		if err != nil {
   226  			return xerrors.Errorf("decoding: %v", err)
   227  		}
   228  	}
   229  	return nil
   230  }
   231  
   232  // ParallelOptions defines how SendProtobufParallel behaves. Each field has a default
   233  // value that will be used if 'nil' is passed to SendProtobufParallel. For integers,
   234  // the default will also be used if the integer = 0.
   235  type ParallelOptions struct {
   236  	// Parallel indicates how many requests are sent in parallel.
   237  	//   Default: half of all nodes in the roster
   238  	Parallel int
   239  	// AskNodes indicates how many requests are sent in total.
   240  	//   Default: all nodes in the roster, except if StartNodes is set > 0
   241  	AskNodes int
   242  	// StartNode indicates where to start in the roster. If StartNode is > 0 and < len(roster),
   243  	// but AskNodes is 0, then AskNodes will be set to len(Roster)-StartNode.
   244  	//   Default: 0
   245  	StartNode int
   246  	// QuitError - if true, the first error received will be returned.
   247  	//   Default: false
   248  	QuitError bool
   249  	// IgnoreNodes is a set of nodes that will not be contacted. They are counted towards
   250  	// AskNodes and StartNode, but not contacted.
   251  	//   Default: false
   252  	IgnoreNodes []*network.ServerIdentity
   253  	// DontShuffle - if true, the nodes will be contacted in the same order as given in the Roster.
   254  	// StartNode will be applied before shuffling.
   255  	//   Default: false
   256  	DontShuffle bool
   257  }
   258  
   259  // GetList returns how many requests to start in parallel and a channel of nodes to be used.
   260  // If po == nil, it uses default values.
   261  func (po *ParallelOptions) GetList(nodes []*network.ServerIdentity) (parallel int, nodesChan chan *network.ServerIdentity) {
   262  	// Default values
   263  	parallel = (len(nodes) + 1) / 2
   264  	askNodes := len(nodes)
   265  	startNode := 0
   266  	var ignoreNodes []*network.ServerIdentity
   267  	var perm []int
   268  	if po != nil {
   269  		if po.Parallel > 0 && po.Parallel < parallel {
   270  			parallel = po.Parallel
   271  		}
   272  		if po.StartNode > 0 && po.StartNode < len(nodes) {
   273  			startNode = po.StartNode
   274  			askNodes -= startNode
   275  		}
   276  		if po.AskNodes > 0 && po.AskNodes < len(nodes) {
   277  			askNodes = po.AskNodes
   278  		}
   279  		if askNodes < parallel {
   280  			parallel = askNodes
   281  		}
   282  		if po.DontShuffle {
   283  			for i := range nodes {
   284  				perm = append(perm, i)
   285  			}
   286  		}
   287  		ignoreNodes = po.IgnoreNodes
   288  	}
   289  	if len(perm) == 0 {
   290  		perm = rand.Perm(len(nodes))
   291  	}
   292  
   293  	nodesChan = make(chan *network.ServerIdentity, askNodes)
   294  	for i := range nodes {
   295  		addNode := true
   296  		node := nodes[(startNode+perm[i])%len(nodes)]
   297  		for _, ignore := range ignoreNodes {
   298  			if node.Equal(ignore) {
   299  				addNode = false
   300  				break
   301  			}
   302  		}
   303  		if addNode {
   304  			nodesChan <- node
   305  		}
   306  		if len(nodesChan) == askNodes {
   307  			break
   308  		}
   309  	}
   310  	return parallel, nodesChan
   311  }
   312  
   313  // Quit return false if po == nil, or the value in po.QuitError.
   314  func (po *ParallelOptions) Quit() bool {
   315  	if po == nil {
   316  		return false
   317  	}
   318  	return po.QuitError
   319  }
   320  
   // Decoder fills target from the raw bytes of a reply, returning an error if
// the bytes cannot be decoded. protobuf.Decode is one such function.
type Decoder func(buf []byte, target interface{}) error
   324  
   325  // SendProtobufParallelWithDecoder sends the msg to a set of nodes in parallel and returns the first successful
   326  // answer. If all nodes return an error, only the first error is returned.
   327  // The behaviour of this method can be changed using the ParallelOptions argument. It is kept
   328  // as a structure for future enhancements. If opt is nil, then standard values will be taken.
   329  func (c *Client) SendProtobufParallelWithDecoder(nodes []*network.ServerIdentity, msg interface{}, ret interface{},
   330  	opt *ParallelOptions, decoder Decoder) (*network.ServerIdentity, error) {
   331  	buf, err := protobuf.Encode(msg)
   332  	if err != nil {
   333  		return nil, xerrors.Errorf("decoding: %v", err)
   334  	}
   335  	path := strings.Split(reflect.TypeOf(msg).String(), ".")[1]
   336  
   337  	parallel, nodesChan := opt.GetList(nodes)
   338  	nodesNbr := len(nodesChan)
   339  	errChan := make(chan error, nodesNbr)
   340  	decodedChan := make(chan *network.ServerIdentity, 1)
   341  	var decoding sync.Mutex
   342  	done := make(chan bool)
   343  
   344  	contactNode := func() bool {
   345  		select {
   346  		case <-done:
   347  			return false
   348  		default:
   349  			select {
   350  			case node := <-nodesChan:
   351  				log.Lvlf3("Asking %T from: %v - %v", msg, node.Address, node.URL)
   352  				reply, err := c.Send(node, path, buf)
   353  				if err != nil {
   354  					log.Lvl2("Error while sending to node:", node, err)
   355  					errChan <- err
   356  				} else {
   357  					log.Lvl3("Done asking node", node, len(reply))
   358  					decoding.Lock()
   359  					select {
   360  					case <-done:
   361  					default:
   362  						if ret != nil {
   363  							err := decoder(reply, ret)
   364  							if err != nil {
   365  								errChan <- err
   366  								break
   367  							}
   368  						}
   369  						decodedChan <- node
   370  						close(done)
   371  					}
   372  					decoding.Unlock()
   373  				}
   374  			default:
   375  				return false
   376  			}
   377  		}
   378  		return true
   379  	}
   380  
   381  	// Producer that puts messages in errChan and replyChan
   382  	for g := 0; g < parallel; g++ {
   383  		go func() {
   384  			for {
   385  				if !contactNode() {
   386  					return
   387  				}
   388  			}
   389  		}()
   390  	}
   391  
   392  	var errs []error
   393  	for len(errs) < nodesNbr {
   394  		select {
   395  		case node := <-decodedChan:
   396  			return node, nil
   397  		case err := <-errChan:
   398  			if opt.Quit() {
   399  				close(done)
   400  				return nil, err
   401  			}
   402  			errs = append(errs, xerrors.Errorf("sending: %v", err))
   403  		}
   404  	}
   405  
   406  	return nil, errs[0]
   407  }
   408  
   409  // SendProtobufParallel sends the msg to a set of nodes in parallel and returns the first successful
   410  // answer. If all nodes return an error, only the first error is returned.
   411  // The behaviour of this method can be changed using the ParallelOptions argument. It is kept
   412  // as a structure for future enhancements. If opt is nil, then standard values will be taken.
   413  func (c *Client) SendProtobufParallel(nodes []*network.ServerIdentity, msg interface{}, ret interface{},
   414  	opt *ParallelOptions) (*network.ServerIdentity, error) {
   415  	si, err := c.SendProtobufParallelWithDecoder(nodes, msg, ret, opt, protobuf.Decode)
   416  	if err != nil {
   417  		return nil, xerrors.Errorf("sending: %v", err)
   418  	}
   419  	return si, nil
   420  }
   421  
   422  // StreamingConn allows clients to read from it without sending additional
   423  // requests.
   424  type StreamingConn struct {
   425  	conn  *websocket.Conn
   426  	suite network.Suite
   427  }
   428  
   // StreamingReadOpts holds the options of StreamingConn.ReadMessageWithOpts.
// Being a struct, it can gain new options later without breaking callers.
type StreamingReadOpts struct {
	// Deadline is the absolute read deadline set on the connection before
	// waiting for the next message.
	Deadline time.Time
}
   434  
   435  // ReadMessage read more data from the connection, it will block if there are
   436  // no messages.
   437  func (c *StreamingConn) ReadMessage(ret interface{}) error {
   438  	opts := StreamingReadOpts{
   439  		Deadline: time.Now().Add(5 * time.Minute),
   440  	}
   441  
   442  	return c.readMsg(ret, opts)
   443  }
   444  
   445  // ReadMessageWithOpts does the same as ReadMessage and allows to pass options.
   446  func (c *StreamingConn) ReadMessageWithOpts(ret interface{}, opts StreamingReadOpts) error {
   447  	return c.readMsg(ret, opts)
   448  }
   449  
   450  func (c *StreamingConn) readMsg(ret interface{}, opts StreamingReadOpts) error {
   451  	if err := c.conn.SetReadDeadline(opts.Deadline); err != nil {
   452  		return xerrors.Errorf("read deadline: %v", err)
   453  	}
   454  	// No need to add bytes to counter here because this function is only
   455  	// called by the client.
   456  	_, buf, err := c.conn.ReadMessage()
   457  	if err != nil {
   458  		return xerrors.Errorf("connection read: %w", err)
   459  	}
   460  	err = protobuf.DecodeWithConstructors(buf, ret, network.DefaultConstructors(c.suite))
   461  	if err != nil {
   462  		return xerrors.Errorf("decoding: %v", err)
   463  	}
   464  	return nil
   465  }
   466  
   467  // Ping sends a ping message. Data can be nil.
   468  func (c *StreamingConn) Ping(data []byte, deadline time.Time) error {
   469  	return c.conn.WriteControl(websocket.PingMessage, data, deadline)
   470  }
   471  
   472  // Stream will send a request to start streaming, it returns a connection where
   473  // the client can continue to read values from it.
   474  func (c *Client) Stream(dst *network.ServerIdentity, msg interface{}) (StreamingConn, error) {
   475  	buf, err := protobuf.Encode(msg)
   476  	if err != nil {
   477  		return StreamingConn{}, err
   478  	}
   479  	path := strings.Split(reflect.TypeOf(msg).String(), ".")[1]
   480  
   481  	conn, connLock, err := c.newConnIfNotExist(dst, path)
   482  	if err != nil {
   483  		return StreamingConn{}, err
   484  	}
   485  	defer connLock.Unlock()
   486  	err = conn.WriteMessage(websocket.BinaryMessage, buf)
   487  	if err != nil {
   488  		return StreamingConn{}, err
   489  	}
   490  	c.Lock()
   491  	c.tx += uint64(len(buf))
   492  	c.Unlock()
   493  	return StreamingConn{conn, c.Suite()}, nil
   494  }
   495  
   496  // SendToAll sends a message to all ServerIdentities of the Roster and returns
   497  // all errors encountered concatenated together as a string.
   498  func (c *Client) SendToAll(dst *Roster, path string, buf []byte) ([][]byte, error) {
   499  	msgs := make([][]byte, len(dst.List))
   500  	var errstrs []string
   501  	for i, e := range dst.List {
   502  		var err error
   503  		msgs[i], err = c.Send(e, path, buf)
   504  		if err != nil {
   505  			errstrs = append(errstrs, fmt.Sprint(e.String(), err.Error()))
   506  		}
   507  	}
   508  	var err error
   509  	if len(errstrs) > 0 {
   510  		err = xerrors.New(strings.Join(errstrs, "\n"))
   511  	}
   512  	return msgs, err
   513  }
   514  
   515  // Close sends a close-command to all open connections and returns nil if no
   516  // errors occurred or all errors encountered concatenated together as a string.
   517  func (c *Client) Close() error {
   518  	c.Lock()
   519  	defer c.Unlock()
   520  	var errstrs []string
   521  	for dest := range c.connections {
   522  		connLock := c.connectionsLock[dest]
   523  		c.Unlock()
   524  		connLock.Lock()
   525  		c.Lock()
   526  		if err := c.closeConn(dest); err != nil {
   527  			errstrs = append(errstrs, err.Error())
   528  		}
   529  		connLock.Unlock()
   530  	}
   531  	var err error
   532  	if len(errstrs) > 0 {
   533  		err = xerrors.New(strings.Join(errstrs, "\n"))
   534  	}
   535  	return err
   536  }
   537  
   538  // closeConn sends a close-command to the connection. Correct locking must be done
   539  // befor calling this method.
   540  func (c *Client) closeConn(dst destination) error {
   541  	conn, ok := c.connections[dst]
   542  	if ok {
   543  		delete(c.connections, dst)
   544  		err := conn.WriteMessage(websocket.CloseMessage,
   545  			websocket.FormatCloseMessage(websocket.CloseNormalClosure, "client closed"))
   546  		if err != nil {
   547  			log.Error("Error while sending closing type:", err)
   548  		}
   549  		return conn.Close()
   550  	}
   551  	return nil
   552  }
   553  
   554  // Tx returns the number of bytes transmitted by this Client. It implements
   555  // the monitor.CounterIOMeasure interface.
   556  func (c *Client) Tx() uint64 {
   557  	c.Lock()
   558  	defer c.Unlock()
   559  	return c.tx
   560  }
   561  
   562  // Rx returns the number of bytes read by this Client. It implements
   563  // the monitor.CounterIOMeasure interface.
   564  func (c *Client) Rx() uint64 {
   565  	c.Lock()
   566  	defer c.Unlock()
   567  	return c.rx
   568  }
   569  
   // schemeToPort maps a URL scheme to its default port, much like netdb does.
// Only "http" (80) and "https" (443) are known; any other scheme is an error.
func schemeToPort(name string) (uint16, error) {
	var port uint16
	switch name {
	case "http":
		port = 80
	case "https":
		port = 443
	default:
		return 0, fmt.Errorf("no such scheme: %v", name)
	}
	return port, nil
}
   581  
   582  // getWSHostPort returns the hostname:port to bind to with WebSocket.
   583  // If global is true, the hostname is set to the unspecified 0.0.0.0-address.
   584  // If si.URL is "", the url uses the hostname and port+1 of si.Address.
   585  func getWSHostPort(si *network.ServerIdentity, global bool) (string, error) {
   586  	const portBitSize = 16
   587  	const portNumericBase = 10
   588  
   589  	var hostname string
   590  	var port uint16
   591  
   592  	if si.URL != "" {
   593  		url, err := url.Parse(si.URL)
   594  		if err != nil {
   595  			return "", fmt.Errorf("unable to parse URL: %v", err)
   596  		}
   597  		if !url.IsAbs() {
   598  			return "", errors.New("URL is not absolute")
   599  		}
   600  
   601  		protocolPort, err := schemeToPort(url.Scheme)
   602  		if err != nil {
   603  			return "", fmt.Errorf("unable to translate URL' scheme to port: %v", err)
   604  		}
   605  
   606  		portStr := url.Port()
   607  		if portStr == "" {
   608  			port = protocolPort
   609  		} else {
   610  			portRaw, err := strconv.ParseUint(portStr, portNumericBase, portBitSize)
   611  			if err != nil {
   612  				return "", fmt.Errorf("URL doesn't contain a valid port: %v", err)
   613  			}
   614  			port = uint16(portRaw)
   615  		}
   616  		hostname = url.Hostname()
   617  	} else {
   618  		portRaw, err := strconv.ParseUint(si.Address.Port(), portNumericBase, portBitSize)
   619  		if err != nil {
   620  			return "", fmt.Errorf("unable to parse port of Address as int: %v", err)
   621  		}
   622  		port = uint16(portRaw + 1)
   623  		hostname = si.Address.Host()
   624  	}
   625  
   626  	if global {
   627  		hostname = "0.0.0.0"
   628  	}
   629  
   630  	portFormatted := strconv.FormatUint(uint64(port), 10)
   631  	return net.JoinHostPort(hostname, portFormatted), nil
   632  }