github.com/metaworking/channeld@v0.7.3/pkg/client/client.go (about) 1 package client 2 3 import ( 4 "errors" 5 "fmt" 6 "net" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/golang/snappy" 12 "github.com/gorilla/websocket" 13 "github.com/metaworking/channeld/pkg/channeld" 14 "github.com/metaworking/channeld/pkg/channeldpb" 15 "google.golang.org/protobuf/proto" 16 ) 17 18 type Message = proto.Message 19 type MessageHandlerFunc func(client *ChanneldClient, channelId uint32, m Message) 20 type messageMapEntry struct { 21 msg Message 22 handlers []MessageHandlerFunc 23 } 24 type messageQueueEntry struct { 25 msg Message 26 channelId uint32 27 stubId uint32 28 handlers []MessageHandlerFunc 29 } 30 31 // Go library for writing game client/server that interations with channeld. 32 type ChanneldClient struct { 33 Id uint32 34 CompressionType channeldpb.CompressionType 35 SubscribedChannels map[uint32]struct{} 36 CreatedChannels map[uint32]struct{} 37 ListedChannels map[uint32]struct{} 38 Conn net.Conn 39 readBuffer []byte 40 readPos int 41 connected bool 42 incomingQueue chan messageQueueEntry 43 outgoingQueue chan *channeldpb.MessagePack 44 messageMap map[uint32]*messageMapEntry 45 stubCallbacks map[uint32]MessageHandlerFunc 46 writeMutex sync.Mutex 47 } 48 49 func NewClient(addr string) (*ChanneldClient, error) { 50 var conn net.Conn 51 if strings.HasPrefix(addr, "ws") { 52 c, _, err := websocket.DefaultDialer.Dial(addr, nil) 53 if err != nil { 54 return nil, err 55 } 56 57 conn = &wsConn{conn: c} 58 } else { 59 var err error 60 conn, err = net.Dial("tcp", addr) 61 if err != nil { 62 return nil, err 63 } 64 } 65 c := &ChanneldClient{ 66 CompressionType: channeldpb.CompressionType_NO_COMPRESSION, 67 SubscribedChannels: make(map[uint32]struct{}), 68 CreatedChannels: make(map[uint32]struct{}), 69 ListedChannels: make(map[uint32]struct{}), 70 Conn: conn, 71 readBuffer: make([]byte, channeld.MaxPacketSize), 72 readPos: 0, 73 connected: true, 74 incomingQueue: make(chan messageQueueEntry, 128), 75 outgoingQueue: make(chan *channeldpb.MessagePack, 32), 76 messageMap: make(map[uint32]*messageMapEntry), 77 stubCallbacks: map[uint32]MessageHandlerFunc{ 78 // 0 is Reserved 79 0: func(_ *ChanneldClient, _ uint32, _ Message) {}, 80 }, 81 } 82 83 c.SetMessageEntry(uint32(channeldpb.MessageType_AUTH), &channeldpb.AuthResultMessage{}, handleAuth) 84 c.SetMessageEntry(uint32(channeldpb.MessageType_CREATE_CHANNEL), &channeldpb.CreateChannelResultMessage{}, handleCreateChannel) 85 c.SetMessageEntry(uint32(channeldpb.MessageType_REMOVE_CHANNEL), &channeldpb.RemoveChannelMessage{}, handleRemoveChannel) 86 c.SetMessageEntry(uint32(channeldpb.MessageType_SUB_TO_CHANNEL), &channeldpb.SubscribedToChannelResultMessage{}, handleSubToChannel) 87 c.SetMessageEntry(uint32(channeldpb.MessageType_UNSUB_FROM_CHANNEL), &channeldpb.UnsubscribedFromChannelResultMessage{}, handleUnsubToChannel) 88 c.SetMessageEntry(uint32(channeldpb.MessageType_LIST_CHANNEL), &channeldpb.ListChannelResultMessage{}, handleListChannel) 89 c.SetMessageEntry(uint32(channeldpb.MessageType_CHANNEL_DATA_UPDATE), &channeldpb.ChannelDataUpdateMessage{}, defaultMessageHandler) 90 91 return c, nil 92 } 93 94 func (client *ChanneldClient) Disconnect() error { 95 return client.Conn.Close() 96 } 97 98 func (client *ChanneldClient) SetMessageEntry(msgType uint32, msgTemplate Message, handlers ...MessageHandlerFunc) { 99 client.messageMap[msgType] = &messageMapEntry{ 100 msg: msgTemplate, 101 handlers: handlers, 102 } 103 } 104 105 func (client *ChanneldClient) AddMessageHandler(msgType uint32, handlers ...MessageHandlerFunc) error { 106 entry := client.messageMap[msgType] 107 if entry != nil { 108 entry.handlers = append(entry.handlers, handlers...) 109 return nil 110 } else { 111 return fmt.Errorf("failed to add handler as the message entry not found, msgType: %d", msgType) 112 } 113 } 114 115 func (client *ChanneldClient) Auth(lt string, pit string) { 116 //result := make(chan *channeldpb.AuthResultMessage) 117 client.Send(0, channeldpb.BroadcastType_NO_BROADCAST, uint32(channeldpb.MessageType_AUTH), &channeldpb.AuthMessage{ 118 LoginToken: lt, 119 PlayerIdentifierToken: pit, 120 }, nil) 121 //return result 122 } 123 124 func handleAuth(client *ChanneldClient, channelId uint32, m Message) { 125 msg := m.(*channeldpb.AuthResultMessage) 126 127 if msg.Result == channeldpb.AuthResultMessage_SUCCESSFUL { 128 if client.Id == 0 { 129 client.Id = msg.ConnId 130 client.CompressionType = msg.CompressionType 131 } 132 133 // client.Send(0, channeldpb.BroadcastType_NO_BROADCAST, uint32(channeldpb.MessageType_SUB_TO_CHANNEL), &channeldpb.SubscribedToChannelMessage{ 134 // ConnId: client.Id, 135 // }, nil) 136 } 137 } 138 139 func handleCreateChannel(c *ChanneldClient, channelId uint32, m Message) { 140 c.CreatedChannels[channelId] = struct{}{} 141 } 142 143 func handleRemoveChannel(client *ChanneldClient, channelId uint32, m Message) { 144 msg := m.(*channeldpb.RemoveChannelMessage) 145 delete(client.SubscribedChannels, msg.ChannelId) 146 delete(client.CreatedChannels, msg.ChannelId) 147 delete(client.ListedChannels, msg.ChannelId) 148 } 149 150 func handleSubToChannel(client *ChanneldClient, channelId uint32, m Message) { 151 client.SubscribedChannels[channelId] = struct{}{} 152 } 153 154 func handleUnsubToChannel(c *ChanneldClient, channelId uint32, m Message) { 155 delete(c.SubscribedChannels, channelId) 156 } 157 158 func handleListChannel(c *ChanneldClient, channelId uint32, m Message) { 159 c.ListedChannels = map[uint32]struct{}{} 160 for _, info := range m.(*channeldpb.ListChannelResultMessage).Channels { 161 c.ListedChannels[info.ChannelId] = struct{}{} 162 } 163 } 164 165 func defaultMessageHandler(client *ChanneldClient, channelId uint32, m Message) { 166 //log.Printf("Client(%d) received message from channel %d: %s", client.Id, channelId, m) 167 } 168 169 func (client *ChanneldClient) IsConnected() bool { 170 return client.connected 171 } 172 173 func (client *ChanneldClient) Receive() error { 174 readPtr := client.readBuffer[client.readPos:] 175 bytesRead, err := client.Conn.Read(readPtr) 176 if err != nil { 177 return err 178 } 179 180 client.readPos += bytesRead 181 if client.readPos < 5 { 182 // Unfinished header 183 return nil 184 } 185 186 tag := client.readBuffer[:5] 187 if tag[0] != 67 { 188 return fmt.Errorf("invalid tag: %s, the packet will be dropped: %w", tag, err) 189 } 190 191 packetSize := int(tag[3]) 192 if tag[1] != 72 { 193 packetSize = packetSize | int(tag[1])<<16 | int(tag[2])<<8 194 } else if tag[2] != 78 { 195 packetSize = packetSize | int(tag[2])<<8 196 } 197 198 fullSize := 5 + packetSize 199 if client.readPos < fullSize { 200 // Unfinished packet 201 return nil 202 } 203 204 bytes := client.readBuffer[5:fullSize] 205 206 // Apply the decompression from the 5th byte in the header 207 // Apply the decompression from the 5th byte in the header 208 if tag[4] == byte(channeldpb.CompressionType_SNAPPY) { 209 len, err := snappy.DecodedLen(bytes) 210 if err != nil { 211 return fmt.Errorf("snappy.DecodedLen: %w", err) 212 } 213 dst := make([]byte, len) 214 bytes, err = snappy.Decode(dst, bytes) 215 if err != nil { 216 return fmt.Errorf("snappy.Decode: %w", err) 217 } 218 } 219 220 var p channeldpb.Packet 221 if err := proto.Unmarshal(bytes, &p); err != nil { 222 return fmt.Errorf("error unmarshalling packet: %w", err) 223 } 224 225 for _, mp := range p.Messages { 226 entry := client.messageMap[mp.MsgType] 227 if entry == nil { 228 return fmt.Errorf("no message type registered: %d", mp.MsgType) 229 } 230 231 // Always make a clone! 232 msg := proto.Clone(entry.msg) 233 err = proto.Unmarshal(mp.MsgBody, msg) 234 if err != nil { 235 return fmt.Errorf("failed to unmarshal message: %w", err) 236 } 237 238 client.incomingQueue <- messageQueueEntry{msg, mp.ChannelId, mp.StubId, entry.handlers} 239 } 240 241 client.readPos = 0 242 243 return nil 244 } 245 246 func (client *ChanneldClient) Tick() error { 247 for len(client.incomingQueue) > 0 { 248 entry := <-client.incomingQueue 249 250 for _, handler := range entry.handlers { 251 handler(client, entry.channelId, entry.msg) 252 } 253 254 if entry.stubId > 0 { 255 callback := client.stubCallbacks[entry.stubId] 256 if callback != nil { 257 callback(client, entry.channelId, entry.msg) 258 } 259 } 260 } 261 262 if len(client.outgoingQueue) == 0 { 263 return nil 264 } 265 266 p := channeldpb.Packet{Messages: make([]*channeldpb.MessagePack, 0, len(client.outgoingQueue))} 267 size := 0 268 for len(client.outgoingQueue) > 0 { 269 mp := <-client.outgoingQueue 270 if size+proto.Size(mp) >= 0xfffff0 { 271 break 272 } 273 p.Messages = append(p.Messages, mp) 274 } 275 return client.writePacket(&p) 276 } 277 278 func (client *ChanneldClient) Send(channelId uint32, broadcast channeldpb.BroadcastType, msgType uint32, msg Message, callback MessageHandlerFunc) error { 279 var stubId uint32 = 0 280 if callback != nil { 281 for client.stubCallbacks[stubId] != nil { 282 stubId++ 283 } 284 client.stubCallbacks[stubId] = callback 285 } 286 287 msgBody, err := proto.Marshal(msg) 288 if err != nil { 289 return fmt.Errorf("failed to marshal message %d: %s. Error: %w", msgType, msg, err) 290 } 291 292 client.outgoingQueue <- &channeldpb.MessagePack{ 293 ChannelId: channelId, 294 Broadcast: uint32(broadcast), 295 StubId: stubId, 296 MsgType: msgType, 297 MsgBody: msgBody, 298 } 299 return nil 300 } 301 302 func (client *ChanneldClient) SendRaw(channelId uint32, broadcast channeldpb.BroadcastType, msgType uint32, msgBody *[]byte, callback MessageHandlerFunc) error { 303 var stubId uint32 = 0 304 if callback != nil { 305 for client.stubCallbacks[stubId] != nil { 306 stubId++ 307 } 308 client.stubCallbacks[stubId] = callback 309 } 310 311 client.outgoingQueue <- &channeldpb.MessagePack{ 312 ChannelId: channelId, 313 Broadcast: uint32(broadcast), 314 StubId: stubId, 315 MsgType: msgType, 316 MsgBody: *msgBody, 317 } 318 return nil 319 } 320 321 func (client *ChanneldClient) writePacket(p *channeldpb.Packet) error { 322 bytes, err := proto.Marshal(p) 323 if err != nil { 324 return fmt.Errorf("error marshalling packet: %w", err) 325 } 326 327 // Apply the compression 328 if client.CompressionType == channeldpb.CompressionType_SNAPPY { 329 dst := make([]byte, snappy.MaxEncodedLen(len(bytes))) 330 bytes = snappy.Encode(dst, bytes) 331 } 332 333 // 'CHNL' in ASCII 334 tag := []byte{67, 72, 78, 76, byte(client.CompressionType)} 335 len := len(bytes) 336 tag[3] = byte(len & 0xff) 337 if len > 0xff { 338 tag[2] = byte((len >> 8) & 0xff) 339 } 340 if len > 0xffff { 341 tag[1] = byte((len >> 16) & 0xff) 342 } 343 344 client.writeMutex.Lock() 345 defer client.writeMutex.Unlock() 346 /* With WebSocket, every Write() sends a message. 347 client.conn.Write(tag) 348 client.conn.Write(bytes) 349 */ 350 client.Conn.Write(append(tag, bytes...)) 351 return nil 352 } 353 354 type wsConn struct { 355 conn *websocket.Conn 356 readBuf []byte 357 readIdx int 358 } 359 360 func (c *wsConn) Read(b []byte) (n int, err error) { 361 //c.SetReadDeadline(time.Now().Add(30 * time.Second)) 362 if c.readBuf == nil || c.readIdx >= len(c.readBuf) { 363 defer func() { 364 if recover() != nil { 365 err = errors.New("read on failed connection") 366 } 367 }() 368 _, c.readBuf, err = c.conn.ReadMessage() 369 if err != nil { 370 return 0, err 371 } 372 c.readIdx = 0 373 } 374 n = copy(b, c.readBuf[c.readIdx:]) 375 c.readIdx += n 376 return n, err 377 } 378 379 func (c *wsConn) Write(b []byte) (n int, err error) { 380 return len(b), c.conn.WriteMessage(websocket.BinaryMessage, b) 381 } 382 383 func (c *wsConn) Close() error { 384 return c.conn.Close() 385 } 386 387 func (c *wsConn) LocalAddr() net.Addr { 388 return c.conn.LocalAddr() 389 } 390 391 func (c *wsConn) RemoteAddr() net.Addr { 392 return c.conn.RemoteAddr() 393 } 394 395 func (c *wsConn) SetDeadline(t time.Time) error { 396 return c.conn.UnderlyingConn().SetDeadline(t) 397 } 398 399 func (c *wsConn) SetReadDeadline(t time.Time) error { 400 return c.conn.SetReadDeadline(t) 401 } 402 403 func (c *wsConn) SetWriteDeadline(t time.Time) error { 404 return c.conn.SetWriteDeadline(t) 405 }