github.com/getgauge/gauge@v1.6.9/conn/connectionHandler.go (about)

     1  /*----------------------------------------------------------------
     2   *  Copyright (c) ThoughtWorks, Inc.
     3   *  Licensed under the Apache License, Version 2.0
     4   *  See LICENSE in the project root for license information.
     5   *----------------------------------------------------------------*/
     6  
     7  package conn
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"net"
    13  	"time"
    14  
    15  	"github.com/getgauge/gauge/logger"
	// github.com/golang/protobuf/proto is deprecated; however, this package is used by the legacy API,
	// which is presently consumed only by IntelliJ IDEA. Since IDEA does not plan to implement LSP,
	// gauge has to keep this alive. Upgrading to google.golang.org/protobuf/proto is not a drop-in change:
	// the newer proto package has no DecodeVarint (the equivalent is protowire.ConsumeVarint), and the
	// whole message handling would need to be refactored.
    20  	"github.com/golang/protobuf/proto" //nolint:staticcheck
    21  )
    22  
    23  type messageHandler interface {
    24  	MessageBytesReceived([]byte, net.Conn)
    25  }
    26  
// GaugeConnectionHandler accepts TCP connections on a listener and, when a
// messageHandler is set, feeds it the length-prefixed messages read from each
// accepted connection.
type GaugeConnectionHandler struct {
	tcpListener    *net.TCPListener // listener that connections are accepted from; may be nil (see ConnectionPortNumber)
	messageHandler messageHandler   // receives message payloads; when nil, accepted connections are not read here
}
    31  
    32  func NewGaugeConnectionHandler(port int, messageHandler messageHandler) (*GaugeConnectionHandler, error) {
    33  	// port = 0 means GO will find a unused port
    34  	address, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", port))
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	listener, err := net.ListenTCP("tcp", address)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	return &GaugeConnectionHandler{tcpListener: listener, messageHandler: messageHandler}, nil
    44  }
    45  
    46  func (connectionHandler *GaugeConnectionHandler) AcceptConnection(connectionTimeOut time.Duration, errChannel chan error) (net.Conn, error) {
    47  	connectionChannel := make(chan net.Conn)
    48  
    49  	go func() {
    50  		connection, err := connectionHandler.tcpListener.Accept()
    51  		if err != nil {
    52  			errChannel <- err
    53  		}
    54  		if connection != nil {
    55  			connectionChannel <- connection
    56  		}
    57  	}()
    58  
    59  	select {
    60  	case err := <-errChannel:
    61  		return nil, err
    62  	case conn := <-connectionChannel:
    63  		if connectionHandler.messageHandler != nil {
    64  			go connectionHandler.handleConnectionMessages(conn)
    65  		}
    66  		return conn, nil
    67  	case <-time.After(connectionTimeOut):
    68  		return nil, fmt.Errorf("Timed out connecting to %v", connectionHandler.tcpListener.Addr())
    69  	}
    70  }
    71  
    72  func (connectionHandler *GaugeConnectionHandler) acceptConnectionWithoutTimeout() (net.Conn, error) {
    73  	errChannel := make(chan error)
    74  	connectionChannel := make(chan net.Conn)
    75  
    76  	go func() {
    77  		connection, err := connectionHandler.tcpListener.Accept()
    78  		if err != nil {
    79  			errChannel <- err
    80  		}
    81  		if connection != nil {
    82  			connectionChannel <- connection
    83  		}
    84  	}()
    85  
    86  	select {
    87  	case err := <-errChannel:
    88  		return nil, err
    89  	case conn := <-connectionChannel:
    90  		if connectionHandler.messageHandler != nil {
    91  			go connectionHandler.handleConnectionMessages(conn)
    92  		}
    93  		return conn, nil
    94  	}
    95  }
    96  
    97  func (connectionHandler *GaugeConnectionHandler) handleConnectionMessages(conn net.Conn) {
    98  	buffer := new(bytes.Buffer)
    99  	data := make([]byte, 8192)
   100  	for {
   101  		n, err := conn.Read(data)
   102  		if err != nil {
   103  			e := conn.Close()
   104  			if e != nil {
   105  				logger.Debugf(false, "Connection already closed, %s", e.Error())
   106  			}
   107  			logger.Infof(false, "Closing connection [%s] cause: %s", conn.RemoteAddr(), err.Error())
   108  			return
   109  		}
   110  
   111  		_, err = buffer.Write(data[0:n])
   112  		if err != nil {
   113  			logger.Infof(false, "Unable to write to buffer, %s", err.Error())
   114  			return
   115  		}
   116  		connectionHandler.processMessage(buffer, conn)
   117  	}
   118  }
   119  
   120  func (connectionHandler *GaugeConnectionHandler) processMessage(buffer *bytes.Buffer, conn net.Conn) {
   121  	for {
   122  		messageLength, bytesRead := proto.DecodeVarint(buffer.Bytes())
   123  		if messageLength > 0 && messageLength < uint64(buffer.Len()) {
   124  			messageBoundary := int(messageLength) + bytesRead
   125  			receivedBytes := buffer.Bytes()[bytesRead:messageBoundary]
   126  			connectionHandler.messageHandler.MessageBytesReceived(receivedBytes, conn)
   127  			buffer.Next(messageBoundary)
   128  			if buffer.Len() == 0 {
   129  				return
   130  			}
   131  		} else {
   132  			return
   133  		}
   134  	}
   135  }
   136  
   137  // HandleMultipleConnections accepts multiple connections and Handler responds to incoming messages
   138  func (connectionHandler *GaugeConnectionHandler) HandleMultipleConnections() {
   139  	for {
   140  		_, err := connectionHandler.acceptConnectionWithoutTimeout()
   141  		if err != nil {
   142  			logger.Fatalf(true, "Unable to connect to runner: %s", err.Error())
   143  		}
   144  	}
   145  
   146  }
   147  
   148  func (connectionHandler *GaugeConnectionHandler) ConnectionPortNumber() int {
   149  	if connectionHandler.tcpListener != nil {
   150  		return connectionHandler.tcpListener.Addr().(*net.TCPAddr).Port
   151  	}
   152  	return 0
   153  }