github.com/ActiveState/cli@v0.0.0-20240508170324-6801f60cd051/internal/testhelpers/wsmock/wsmock.go (about)

     1  package wsmock
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"path/filepath"
     7  	"runtime"
     8  	"strings"
     9  
    10  	"github.com/ActiveState/cli/internal/fileutils"
    11  	"github.com/ActiveState/cli/internal/logging"
    12  	"github.com/ActiveState/cli/internal/multilog"
    13  	"github.com/gorilla/websocket"
    14  	"github.com/posener/wstest"
    15  )
    16  
    17  type WsMock struct {
    18  	upgrader      websocket.Upgrader
    19  	responders    map[string]string
    20  	responsePath  string
    21  	responseQueue []string
    22  	done          chan bool
    23  }
    24  
    25  func Init() *WsMock {
    26  	mock := &WsMock{
    27  		responders: map[string]string{},
    28  	}
    29  	mock.upgrader.CheckOrigin = func(r *http.Request) bool { return true }
    30  	return mock
    31  }
    32  
    33  func (s *WsMock) Dialer() *websocket.Dialer {
    34  	return wstest.NewDialer(s)
    35  }
    36  
    37  func (s *WsMock) RegisterWithResponse(requestContains string, responseFile string) {
    38  	s.responders[requestContains] = responseFile
    39  }
    40  
    41  func (s *WsMock) QueueResponse(responseFile string) {
    42  	responseFile = s.getResponseFile(responseFile)
    43  	response := string(fileutils.ReadFileUnsafe(responseFile))
    44  	s.responseQueue = append(s.responseQueue, response)
    45  }
    46  
    47  func (s *WsMock) Close() {
    48  	logging.Debug("Close called")
    49  	s.done <- true
    50  }
    51  
    52  func (s *WsMock) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    53  	s.done = make(chan bool)
    54  
    55  	conn, err := s.upgrader.Upgrade(w, r, nil)
    56  	if err != nil {
    57  		panic(fmt.Sprintf("Could not upgrade connection to websocket: %v", err))
    58  	}
    59  
    60  	for s.done != nil {
    61  		logging.Debug("Loop")
    62  		_, msgBytes, err := conn.ReadMessage()
    63  		if err != nil {
    64  			if strings.HasPrefix(err.Error(), "websocket: close") {
    65  				logging.Debug("websocket close encountered")
    66  				s.Close()
    67  				return
    68  			}
    69  			multilog.Error("Reading Message failed: %v", err)
    70  			return
    71  		}
    72  
    73  		msg := string(msgBytes[:])
    74  		logging.Debug("Message received: %v", msg)
    75  
    76  		for requestContains, responseFile := range s.responders {
    77  			if strings.Contains(msg, requestContains) {
    78  				responseFile = s.getResponseFile(responseFile)
    79  				response := string(fileutils.ReadFileUnsafe(responseFile))
    80  				if err := conn.WriteMessage(websocket.TextMessage, []byte(response)); err != nil {
    81  					panic(fmt.Sprintf("Could not write response to websocket: %v", err))
    82  				}
    83  				break
    84  			}
    85  		}
    86  
    87  		for _, response := range s.responseQueue {
    88  			if err := conn.WriteMessage(websocket.TextMessage, []byte(response)); err != nil {
    89  				panic(fmt.Sprintf("Could not write response to websocket: %v", err))
    90  			}
    91  		}
    92  		s.responseQueue = []string{}
    93  	}
    94  }
    95  
    96  func (s *WsMock) getResponseFile(responseFile string) string {
    97  	return filepath.Join(s.getResponsePath(), responseFile) + ".json"
    98  }
    99  
   100  func (s *WsMock) getResponsePath() string {
   101  	if s.responsePath == "" {
   102  		_, currentFile, _, _ := runtime.Caller(0)
   103  		file := currentFile
   104  		ok := true
   105  		iter := 2
   106  
   107  		for file == currentFile && ok {
   108  			_, file, _, ok = runtime.Caller(iter)
   109  			iter = iter + 1
   110  		}
   111  
   112  		if file == "" || file == currentFile {
   113  			panic("Could not get caller")
   114  		}
   115  		s.responsePath = filepath.Join(filepath.Dir(file), "testdata", "wsresponse")
   116  	}
   117  
   118  	return s.responsePath
   119  }