github.com/bitfinexcom/bitfinex-api-go@v0.0.0-20210608095005-9e0b26f200fb/tests/integration/v2/test_ws_service.go (about) 1 package tests 2 3 import ( 4 "bytes" 5 "fmt" 6 "log" 7 "net" 8 "net/http" 9 "sync" 10 "time" 11 12 "github.com/gorilla/websocket" 13 ) 14 15 var upgrader = websocket.Upgrader{ 16 ReadBufferSize: 1024, 17 WriteBufferSize: 1024, 18 } 19 20 type client struct { 21 parent *TestWsService 22 *websocket.Conn 23 send chan []byte 24 received []string 25 lock sync.Mutex 26 } 27 28 func (c *client) writePump() { 29 for msg := range c.send { 30 err := c.Conn.WriteMessage(websocket.TextMessage, msg) 31 if err != nil { 32 log.Printf("could not send message (%s) to client: %s", string(msg), err.Error()) 33 continue 34 } 35 } 36 } 37 38 func (c *client) readPump() { 39 defer func() { 40 c.parent.unregister <- c 41 c.Conn.Close() 42 }() 43 for { 44 _, message, err := c.Conn.ReadMessage() 45 if err != nil { 46 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { 47 log.Printf("error: %v", err) 48 } 49 log.Printf("test ws service drop client: %s", err.Error()) 50 return 51 } 52 message = bytes.TrimSpace(bytes.Replace(message, []byte("\n"), []byte(" "), -1)) 53 c.lock.Lock() 54 log.Printf("[DEBUG] WsClient -> WsService: %s", string(message)) 55 c.received = append(c.received, string(message)) 56 c.lock.Unlock() 57 } 58 } 59 60 type TestWsService struct { 61 clients map[*client]bool 62 listener net.Listener 63 port int 64 65 register chan *client 66 unregister chan *client 67 broadcast chan []byte 68 totalClients int 69 lock *sync.RWMutex 70 71 publishOnConnect string 72 } 73 74 func (s *TestWsService) WaitForClientCount(count int) error { 75 loops := 80 76 delay := time.Millisecond * 50 77 for i := 0; i < loops; i++ { 78 s.lock.RLock() 79 if s.totalClients == count { 80 return nil 81 } 82 s.lock.RUnlock() 83 time.Sleep(delay) 84 } 85 return fmt.Errorf("client peer #%d did not connect", count) 86 } 87 88 func (s *TestWsService) TotalClientCount() int { 89 return s.totalClients 90 } 91 92 func (s *TestWsService) PublishOnConnect(msg string) { 93 s.publishOnConnect = msg 94 } 95 96 func NewTestWsService(port int) *TestWsService { 97 return &TestWsService{ 98 port: port, 99 clients: make(map[*client]bool), 100 register: make(chan *client), 101 unregister: make(chan *client), 102 broadcast: make(chan []byte), 103 lock: &sync.RWMutex{}, 104 } 105 } 106 107 // Broadcast sends a message to all connected clients. 108 func (s *TestWsService) Broadcast(msg string) { 109 s.broadcast <- []byte(msg) 110 111 } 112 113 // ReceivedCount starts indexing clients at position 0. 114 func (s *TestWsService) ReceivedCount(clientNum int) int { 115 i := 0 116 for client := range s.clients { 117 if i == clientNum { 118 client.lock.Lock() 119 defer client.lock.Unlock() 120 return len(client.received) 121 } 122 i++ 123 } 124 return 0 125 } 126 127 // Received starts indexing clients and message positions at position 0. 128 func (s *TestWsService) Received(clientNum int, msgNum int) (string, error) { 129 var client *client 130 i := 0 131 for client = range s.clients { 132 if i == clientNum { 133 break 134 } 135 i++ 136 } 137 if client != nil { 138 client.lock.Lock() 139 defer client.lock.Unlock() 140 if len(client.received) > msgNum { 141 return string(client.received[msgNum]), nil 142 } 143 return "", fmt.Errorf("could not find message index %d, %d messages exist", msgNum, len(client.received)) 144 } 145 return "", fmt.Errorf("could not find client %d", clientNum) 146 } 147 148 func (s *TestWsService) WaitForMessage(clientNum int, msgNum int) (string, error) { 149 loops := 80 150 delay := time.Millisecond * 50 151 var msg string 152 var err error 153 for i := 0; i < loops; i++ { 154 msg, err = s.Received(clientNum, msgNum) 155 if err != nil { 156 time.Sleep(delay) 157 } else { 158 return msg, nil 159 } 160 } 161 return "", err 162 } 163 164 func (s *TestWsService) ServeHTTP(w http.ResponseWriter, r *http.Request) { 165 s.serveWs(w, r) 166 } 167 168 func (s *TestWsService) Stop() { 169 //s.lock.RLock() 170 //defer s.lock.RUnlock() 171 s.listener.Close() // stop listening to http 172 for c := range s.clients { 173 c.Close() 174 } 175 } 176 177 //nolint 178 func (s *TestWsService) Start() error { 179 go s.loop() 180 l, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port)) 181 if err != nil { 182 return err 183 } 184 s.listener = l 185 go http.Serve(s.listener, s) 186 return nil 187 } 188 189 //nolint 190 func (s *TestWsService) serveWs(w http.ResponseWriter, r *http.Request) { 191 conn, err := upgrader.Upgrade(w, r, nil) 192 if err != nil { 193 log.Print(err) 194 return 195 } 196 s.totalClients++ 197 client := &client{parent: s, Conn: conn, send: make(chan []byte, 256), received: make([]string, 0)} 198 go client.writePump() 199 go client.readPump() 200 s.clients[client] = true 201 if s.publishOnConnect != "" { 202 s.Broadcast(s.publishOnConnect) 203 } 204 } 205 206 func (s *TestWsService) loop() { 207 for { 208 select { 209 case client := <-s.register: 210 //s.lock.Lock() 211 s.clients[client] = true 212 //s.lock.Unlock() 213 case client := <-s.unregister: 214 if _, ok := s.clients[client]; ok { 215 //s.lock.Lock() 216 delete(s.clients, client) 217 close(client.send) 218 //s.lock.Unlock() 219 } 220 case msg := <-s.broadcast: 221 for client := range s.clients { 222 select { 223 case client.send <- msg: 224 default: // send failure 225 //s.lock.Lock() 226 close(client.send) 227 delete(s.clients, client) 228 //s.lock.Unlock() 229 } 230 } 231 } 232 } 233 }