github.com/getgauge/gauge@v1.6.9/conn/network.go (about)

     1  /*----------------------------------------------------------------
     2   *  Copyright (c) ThoughtWorks, Inc.
     3   *  Licensed under the Apache License, Version 2.0
     4   *  See LICENSE in the project root for license information.
     5   *----------------------------------------------------------------*/
     6  
     7  package conn
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"net"
    13  	"os"
    14  	"strconv"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/getgauge/common"
    19  	"github.com/getgauge/gauge-proto/go/gauge_messages"
    20  	"github.com/getgauge/gauge/logger"
    21  
    22  	// github.com/golang/protobuf/proto is deprecated, however this package is used by the legacy API
    23  	// which is consumed only by IntelliJ IDEA presently. Since IDEA does not plan to implement LSP
    24  	// gauge will have to keep this alive. Upgrading to google.golang.org/protobuf/proto is not a drop in change
    25  	// since the newer library does not support DecodeVarint. The whole message handling will need to be refactored.
    26  	"github.com/golang/protobuf/proto" //nolint:staticcheck
    27  )
    28  
    29  type response struct {
    30  	result chan *gauge_messages.Message
    31  	err    chan error
    32  	timer  *time.Timer
    33  }
    34  
    35  func (r *response) stopTimer() {
    36  	if r.timer != nil {
    37  		r.timer.Stop()
    38  	}
    39  }
    40  
    41  func (r *response) addTimer(timeout time.Duration, message *gauge_messages.Message) {
    42  	if timeout > 0 {
    43  		r.timer = time.AfterFunc(timeout, func() {
    44  			r.err <- fmt.Errorf("Request timed out for Message with ID => %v and Type => %s", message.GetMessageId(), message.GetMessageType().String())
    45  		})
    46  	}
    47  }
    48  
    49  type messages struct {
    50  	m map[int64]response
    51  	sync.Mutex
    52  }
    53  
    54  func (m *messages) get(k int64) response {
    55  	m.Lock()
    56  	defer m.Unlock()
    57  	return m.m[k]
    58  }
    59  
    60  func (m *messages) put(k int64, res response) {
    61  	m.Lock()
    62  	defer m.Unlock()
    63  	m.m[k] = res
    64  }
    65  
    66  func (m *messages) delete(k int64) {
    67  	m.Lock()
    68  	defer m.Unlock()
    69  	delete(m.m, k)
    70  }
    71  
    72  var m = &messages{m: make(map[int64]response)}
    73  
    74  func writeDataAndGetResponse(conn net.Conn, messageBytes []byte) ([]byte, error) {
    75  	if err := Write(conn, messageBytes); err != nil {
    76  		return nil, err
    77  	}
    78  	return readResponse(conn)
    79  }
    80  
    81  func readResponse(conn net.Conn) ([]byte, error) {
    82  	buffer := new(bytes.Buffer)
    83  	data := make([]byte, 8192)
    84  	for {
    85  		n, err := conn.Read(data)
    86  		if err != nil {
    87  			e := conn.Close()
    88  			if e != nil {
    89  				logger.Debugf(false, "Connection already closed, %s", e.Error())
    90  			}
    91  			return nil, fmt.Errorf("connection closed [%s] cause: %s", conn.RemoteAddr(), err.Error())
    92  		}
    93  
    94  		_, err = buffer.Write(data[0:n])
    95  		if err != nil {
    96  			return nil, fmt.Errorf("unable to write to buffer, %s", err.Error())
    97  		}
    98  		messageLength, bytesRead := proto.DecodeVarint(buffer.Bytes())
    99  		if (messageLength > 0 && messageLength < uint64(buffer.Len())) && ((messageLength + uint64(bytesRead)) <= uint64(buffer.Len())) {
   100  			return buffer.Bytes()[bytesRead : messageLength+uint64(bytesRead)], nil
   101  		}
   102  	}
   103  }
   104  
   105  func Write(conn net.Conn, messageBytes []byte) error {
   106  	messageLen := proto.EncodeVarint(uint64(len(messageBytes)))
   107  	data := append(messageLen, messageBytes...)
   108  	_, err := conn.Write(data)
   109  	return err
   110  }
   111  
   112  func WriteGaugeMessage(message *gauge_messages.Message, conn net.Conn) error {
   113  	messageID := common.GetUniqueID()
   114  	message.MessageId = messageID
   115  
   116  	data, err := proto.Marshal(message)
   117  	if err != nil {
   118  		return err
   119  	}
   120  	return Write(conn, data)
   121  }
   122  
   123  func getResponseForGaugeMessage(message *gauge_messages.Message, conn net.Conn, res response, timeout time.Duration) {
   124  	message.MessageId = common.GetUniqueID()
   125  	res.addTimer(timeout, message)
   126  	handle := func(err error) {
   127  		if err != nil {
   128  			res.stopTimer()
   129  			res.err <- err
   130  		}
   131  	}
   132  
   133  	data, err := proto.Marshal(message)
   134  
   135  	handle(err)
   136  	m.put(message.GetMessageId(), res)
   137  
   138  	responseBytes, err := writeDataAndGetResponse(conn, data)
   139  	handle(err)
   140  
   141  	responseMessage := &gauge_messages.Message{}
   142  	err = proto.Unmarshal(responseBytes, responseMessage)
   143  	handle(err)
   144  
   145  	err = checkUnsupportedResponseMessage(responseMessage)
   146  	handle(err)
   147  
   148  	responseRes := m.get(responseMessage.GetMessageId())
   149  	responseRes.stopTimer()
   150  	responseRes.result <- responseMessage
   151  	m.delete(responseMessage.GetMessageId())
   152  }
   153  
   154  func checkUnsupportedResponseMessage(message *gauge_messages.Message) error {
   155  	if message.GetMessageType() == gauge_messages.Message_UnsupportedMessageResponse {
   156  		return fmt.Errorf("Unsupported Message response received. Message not supported. %s", message.GetUnsupportedMessageResponse().GetMessage())
   157  	}
   158  	return nil
   159  }
   160  
   161  // Sends request to plugin for a message. If response is not received for the given message within the configured timeout, an error is thrown
   162  // To wait indefinitely for the response from the plugin, set timeout value as 0.
   163  func GetResponseForMessageWithTimeout(message *gauge_messages.Message, conn net.Conn, timeout time.Duration) (*gauge_messages.Message, error) {
   164  	res := response{result: make(chan *gauge_messages.Message), err: make(chan error)}
   165  	go getResponseForGaugeMessage(message, conn, res, timeout)
   166  	select {
   167  	case err := <-res.err:
   168  		return nil, err
   169  	case res := <-res.result:
   170  		return res, nil
   171  	}
   172  }
   173  
   174  func GetPortFromEnvironmentVariable(portEnvVariable string) (int, error) {
   175  	if port := os.Getenv(portEnvVariable); port != "" {
   176  		gport, err := strconv.Atoi(port)
   177  		if err != nil {
   178  			return 0, fmt.Errorf("%s is not a valid port", port)
   179  		}
   180  		return gport, nil
   181  	}
   182  	return 0, fmt.Errorf("%s Environment variable not set", portEnvVariable)
   183  }
   184  
   185  // SendProcessKillMessage sends a KillProcessRequest message through the connection.
   186  func SendProcessKillMessage(connection net.Conn) error {
   187  	id := common.GetUniqueID()
   188  	message := &gauge_messages.Message{MessageId: id, MessageType: gauge_messages.Message_KillProcessRequest,
   189  		KillProcessRequest: &gauge_messages.KillProcessRequest{}}
   190  
   191  	return WriteGaugeMessage(message, connection)
   192  }