github.com/ungtb10d/cli/v2@v2.0.0-20221110210412-98537dd9d6a1/pkg/liveshare/test/socket.go (about)

     1  package livesharetest
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/gorilla/websocket"
    10  )
    11  
// socketConn wraps a *websocket.Conn so it can be used as a byte stream:
// Read consumes the payloads of binary websocket messages, Write sends each
// call's bytes as one binary message, and SetDeadline applies a single
// deadline to both reads and writes. Remaining methods come from the
// embedded *websocket.Conn.
type socketConn struct {
	*websocket.Conn

	reader     io.Reader  // message currently being consumed by Read; nil when the next Read must fetch a new message
	writeMutex sync.Mutex // serializes Write calls
	readMutex  sync.Mutex // serializes Read calls and guards reader
}
    19  
    20  func newSocketConn(conn *websocket.Conn) *socketConn {
    21  	return &socketConn{Conn: conn}
    22  }
    23  
    24  func (s *socketConn) Read(b []byte) (int, error) {
    25  	s.readMutex.Lock()
    26  	defer s.readMutex.Unlock()
    27  
    28  	if s.reader == nil {
    29  		msgType, r, err := s.Conn.NextReader()
    30  		if err != nil {
    31  			return 0, fmt.Errorf("error getting next reader: %w", err)
    32  		}
    33  		if msgType != websocket.BinaryMessage {
    34  			return 0, fmt.Errorf("invalid message type")
    35  		}
    36  		s.reader = r
    37  	}
    38  
    39  	bytesRead, err := s.reader.Read(b)
    40  	if err != nil {
    41  		s.reader = nil
    42  
    43  		if err == io.EOF {
    44  			err = nil
    45  		}
    46  	}
    47  
    48  	return bytesRead, err
    49  }
    50  
    51  func (s *socketConn) Write(b []byte) (int, error) {
    52  	s.writeMutex.Lock()
    53  	defer s.writeMutex.Unlock()
    54  
    55  	w, err := s.Conn.NextWriter(websocket.BinaryMessage)
    56  	if err != nil {
    57  		return 0, fmt.Errorf("error getting next writer: %w", err)
    58  	}
    59  
    60  	n, err := w.Write(b)
    61  	if err != nil {
    62  		return 0, fmt.Errorf("error writing: %w", err)
    63  	}
    64  
    65  	if err := w.Close(); err != nil {
    66  		return 0, fmt.Errorf("error closing writer: %w", err)
    67  	}
    68  
    69  	return n, nil
    70  }
    71  
    72  func (s *socketConn) SetDeadline(deadline time.Time) error {
    73  	if err := s.Conn.SetReadDeadline(deadline); err != nil {
    74  		return err
    75  	}
    76  	return s.Conn.SetWriteDeadline(deadline)
    77  }