github.com/metaworking/channeld@v0.7.3/pkg/client/client.go (about)

     1  package client
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/golang/snappy"
    12  	"github.com/gorilla/websocket"
    13  	"github.com/metaworking/channeld/pkg/channeld"
    14  	"github.com/metaworking/channeld/pkg/channeldpb"
    15  	"google.golang.org/protobuf/proto"
    16  )
    17  
    18  type Message = proto.Message
    19  type MessageHandlerFunc func(client *ChanneldClient, channelId uint32, m Message)
    20  type messageMapEntry struct {
    21  	msg      Message
    22  	handlers []MessageHandlerFunc
    23  }
    24  type messageQueueEntry struct {
    25  	msg       Message
    26  	channelId uint32
    27  	stubId    uint32
    28  	handlers  []MessageHandlerFunc
    29  }
    30  
// ChanneldClient is a Go client for writing game clients or servers that
// interact with channeld.
//
// NOTE(review): apart from writes to Conn (serialized by writeMutex in
// writePacket), no field is lock-protected; concurrent use of the maps from
// several goroutines presumably needs external coordination — confirm.
type ChanneldClient struct {
	// Id is the connection id from the first successful AuthResultMessage (0 before that).
	Id                 uint32
	// CompressionType is taken from the auth result and applied to outgoing packets.
	CompressionType    channeldpb.CompressionType
	// SubscribedChannels is updated by the sub/unsub/remove-channel handlers.
	SubscribedChannels map[uint32]struct{}
	// CreatedChannels records channels whose create-channel result was received.
	CreatedChannels    map[uint32]struct{}
	// ListedChannels is replaced by each list-channel result.
	ListedChannels     map[uint32]struct{}
	// Conn is the TCP connection or the WebSocket adapter (wsConn).
	Conn               net.Conn
	// readBuffer accumulates raw bytes read from Conn; readPos is how many are buffered.
	readBuffer         []byte
	readPos            int
	// connected is reported by IsConnected.
	connected          bool
	// incomingQueue is filled by Receive and drained by Tick.
	incomingQueue      chan messageQueueEntry
	// outgoingQueue is filled by Send/SendRaw and drained by Tick.
	outgoingQueue      chan *channeldpb.MessagePack
	// messageMap maps a message type to its template and handlers.
	messageMap         map[uint32]*messageMapEntry
	// stubCallbacks maps a stub id to the callback registered by Send/SendRaw; id 0 is reserved.
	stubCallbacks      map[uint32]MessageHandlerFunc
	// writeMutex serializes writes to Conn.
	writeMutex         sync.Mutex
}
    48  
    49  func NewClient(addr string) (*ChanneldClient, error) {
    50  	var conn net.Conn
    51  	if strings.HasPrefix(addr, "ws") {
    52  		c, _, err := websocket.DefaultDialer.Dial(addr, nil)
    53  		if err != nil {
    54  			return nil, err
    55  		}
    56  
    57  		conn = &wsConn{conn: c}
    58  	} else {
    59  		var err error
    60  		conn, err = net.Dial("tcp", addr)
    61  		if err != nil {
    62  			return nil, err
    63  		}
    64  	}
    65  	c := &ChanneldClient{
    66  		CompressionType:    channeldpb.CompressionType_NO_COMPRESSION,
    67  		SubscribedChannels: make(map[uint32]struct{}),
    68  		CreatedChannels:    make(map[uint32]struct{}),
    69  		ListedChannels:     make(map[uint32]struct{}),
    70  		Conn:               conn,
    71  		readBuffer:         make([]byte, channeld.MaxPacketSize),
    72  		readPos:            0,
    73  		connected:          true,
    74  		incomingQueue:      make(chan messageQueueEntry, 128),
    75  		outgoingQueue:      make(chan *channeldpb.MessagePack, 32),
    76  		messageMap:         make(map[uint32]*messageMapEntry),
    77  		stubCallbacks: map[uint32]MessageHandlerFunc{
    78  			// 0 is Reserved
    79  			0: func(_ *ChanneldClient, _ uint32, _ Message) {},
    80  		},
    81  	}
    82  
    83  	c.SetMessageEntry(uint32(channeldpb.MessageType_AUTH), &channeldpb.AuthResultMessage{}, handleAuth)
    84  	c.SetMessageEntry(uint32(channeldpb.MessageType_CREATE_CHANNEL), &channeldpb.CreateChannelResultMessage{}, handleCreateChannel)
    85  	c.SetMessageEntry(uint32(channeldpb.MessageType_REMOVE_CHANNEL), &channeldpb.RemoveChannelMessage{}, handleRemoveChannel)
    86  	c.SetMessageEntry(uint32(channeldpb.MessageType_SUB_TO_CHANNEL), &channeldpb.SubscribedToChannelResultMessage{}, handleSubToChannel)
    87  	c.SetMessageEntry(uint32(channeldpb.MessageType_UNSUB_FROM_CHANNEL), &channeldpb.UnsubscribedFromChannelResultMessage{}, handleUnsubToChannel)
    88  	c.SetMessageEntry(uint32(channeldpb.MessageType_LIST_CHANNEL), &channeldpb.ListChannelResultMessage{}, handleListChannel)
    89  	c.SetMessageEntry(uint32(channeldpb.MessageType_CHANNEL_DATA_UPDATE), &channeldpb.ChannelDataUpdateMessage{}, defaultMessageHandler)
    90  
    91  	return c, nil
    92  }
    93  
    94  func (client *ChanneldClient) Disconnect() error {
    95  	return client.Conn.Close()
    96  }
    97  
    98  func (client *ChanneldClient) SetMessageEntry(msgType uint32, msgTemplate Message, handlers ...MessageHandlerFunc) {
    99  	client.messageMap[msgType] = &messageMapEntry{
   100  		msg:      msgTemplate,
   101  		handlers: handlers,
   102  	}
   103  }
   104  
   105  func (client *ChanneldClient) AddMessageHandler(msgType uint32, handlers ...MessageHandlerFunc) error {
   106  	entry := client.messageMap[msgType]
   107  	if entry != nil {
   108  		entry.handlers = append(entry.handlers, handlers...)
   109  		return nil
   110  	} else {
   111  		return fmt.Errorf("failed to add handler as the message entry not found, msgType: %d", msgType)
   112  	}
   113  }
   114  
   115  func (client *ChanneldClient) Auth(lt string, pit string) {
   116  	//result := make(chan *channeldpb.AuthResultMessage)
   117  	client.Send(0, channeldpb.BroadcastType_NO_BROADCAST, uint32(channeldpb.MessageType_AUTH), &channeldpb.AuthMessage{
   118  		LoginToken:            lt,
   119  		PlayerIdentifierToken: pit,
   120  	}, nil)
   121  	//return result
   122  }
   123  
   124  func handleAuth(client *ChanneldClient, channelId uint32, m Message) {
   125  	msg := m.(*channeldpb.AuthResultMessage)
   126  
   127  	if msg.Result == channeldpb.AuthResultMessage_SUCCESSFUL {
   128  		if client.Id == 0 {
   129  			client.Id = msg.ConnId
   130  			client.CompressionType = msg.CompressionType
   131  		}
   132  
   133  		// client.Send(0, channeldpb.BroadcastType_NO_BROADCAST, uint32(channeldpb.MessageType_SUB_TO_CHANNEL), &channeldpb.SubscribedToChannelMessage{
   134  		// 	ConnId: client.Id,
   135  		// }, nil)
   136  	}
   137  }
   138  
   139  func handleCreateChannel(c *ChanneldClient, channelId uint32, m Message) {
   140  	c.CreatedChannels[channelId] = struct{}{}
   141  }
   142  
   143  func handleRemoveChannel(client *ChanneldClient, channelId uint32, m Message) {
   144  	msg := m.(*channeldpb.RemoveChannelMessage)
   145  	delete(client.SubscribedChannels, msg.ChannelId)
   146  	delete(client.CreatedChannels, msg.ChannelId)
   147  	delete(client.ListedChannels, msg.ChannelId)
   148  }
   149  
   150  func handleSubToChannel(client *ChanneldClient, channelId uint32, m Message) {
   151  	client.SubscribedChannels[channelId] = struct{}{}
   152  }
   153  
   154  func handleUnsubToChannel(c *ChanneldClient, channelId uint32, m Message) {
   155  	delete(c.SubscribedChannels, channelId)
   156  }
   157  
   158  func handleListChannel(c *ChanneldClient, channelId uint32, m Message) {
   159  	c.ListedChannels = map[uint32]struct{}{}
   160  	for _, info := range m.(*channeldpb.ListChannelResultMessage).Channels {
   161  		c.ListedChannels[info.ChannelId] = struct{}{}
   162  	}
   163  }
   164  
   165  func defaultMessageHandler(client *ChanneldClient, channelId uint32, m Message) {
   166  	//log.Printf("Client(%d) received message from channel %d: %s", client.Id, channelId, m)
   167  }
   168  
   169  func (client *ChanneldClient) IsConnected() bool {
   170  	return client.connected
   171  }
   172  
   173  func (client *ChanneldClient) Receive() error {
   174  	readPtr := client.readBuffer[client.readPos:]
   175  	bytesRead, err := client.Conn.Read(readPtr)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	client.readPos += bytesRead
   181  	if client.readPos < 5 {
   182  		// Unfinished header
   183  		return nil
   184  	}
   185  
   186  	tag := client.readBuffer[:5]
   187  	if tag[0] != 67 {
   188  		return fmt.Errorf("invalid tag: %s, the packet will be dropped: %w", tag, err)
   189  	}
   190  
   191  	packetSize := int(tag[3])
   192  	if tag[1] != 72 {
   193  		packetSize = packetSize | int(tag[1])<<16 | int(tag[2])<<8
   194  	} else if tag[2] != 78 {
   195  		packetSize = packetSize | int(tag[2])<<8
   196  	}
   197  
   198  	fullSize := 5 + packetSize
   199  	if client.readPos < fullSize {
   200  		// Unfinished packet
   201  		return nil
   202  	}
   203  
   204  	bytes := client.readBuffer[5:fullSize]
   205  
   206  	// Apply the decompression from the 5th byte in the header
   207  	// Apply the decompression from the 5th byte in the header
   208  	if tag[4] == byte(channeldpb.CompressionType_SNAPPY) {
   209  		len, err := snappy.DecodedLen(bytes)
   210  		if err != nil {
   211  			return fmt.Errorf("snappy.DecodedLen: %w", err)
   212  		}
   213  		dst := make([]byte, len)
   214  		bytes, err = snappy.Decode(dst, bytes)
   215  		if err != nil {
   216  			return fmt.Errorf("snappy.Decode: %w", err)
   217  		}
   218  	}
   219  
   220  	var p channeldpb.Packet
   221  	if err := proto.Unmarshal(bytes, &p); err != nil {
   222  		return fmt.Errorf("error unmarshalling packet: %w", err)
   223  	}
   224  
   225  	for _, mp := range p.Messages {
   226  		entry := client.messageMap[mp.MsgType]
   227  		if entry == nil {
   228  			return fmt.Errorf("no message type registered: %d", mp.MsgType)
   229  		}
   230  
   231  		// Always make a clone!
   232  		msg := proto.Clone(entry.msg)
   233  		err = proto.Unmarshal(mp.MsgBody, msg)
   234  		if err != nil {
   235  			return fmt.Errorf("failed to unmarshal message: %w", err)
   236  		}
   237  
   238  		client.incomingQueue <- messageQueueEntry{msg, mp.ChannelId, mp.StubId, entry.handlers}
   239  	}
   240  
   241  	client.readPos = 0
   242  
   243  	return nil
   244  }
   245  
   246  func (client *ChanneldClient) Tick() error {
   247  	for len(client.incomingQueue) > 0 {
   248  		entry := <-client.incomingQueue
   249  
   250  		for _, handler := range entry.handlers {
   251  			handler(client, entry.channelId, entry.msg)
   252  		}
   253  
   254  		if entry.stubId > 0 {
   255  			callback := client.stubCallbacks[entry.stubId]
   256  			if callback != nil {
   257  				callback(client, entry.channelId, entry.msg)
   258  			}
   259  		}
   260  	}
   261  
   262  	if len(client.outgoingQueue) == 0 {
   263  		return nil
   264  	}
   265  
   266  	p := channeldpb.Packet{Messages: make([]*channeldpb.MessagePack, 0, len(client.outgoingQueue))}
   267  	size := 0
   268  	for len(client.outgoingQueue) > 0 {
   269  		mp := <-client.outgoingQueue
   270  		if size+proto.Size(mp) >= 0xfffff0 {
   271  			break
   272  		}
   273  		p.Messages = append(p.Messages, mp)
   274  	}
   275  	return client.writePacket(&p)
   276  }
   277  
   278  func (client *ChanneldClient) Send(channelId uint32, broadcast channeldpb.BroadcastType, msgType uint32, msg Message, callback MessageHandlerFunc) error {
   279  	var stubId uint32 = 0
   280  	if callback != nil {
   281  		for client.stubCallbacks[stubId] != nil {
   282  			stubId++
   283  		}
   284  		client.stubCallbacks[stubId] = callback
   285  	}
   286  
   287  	msgBody, err := proto.Marshal(msg)
   288  	if err != nil {
   289  		return fmt.Errorf("failed to marshal message %d: %s. Error: %w", msgType, msg, err)
   290  	}
   291  
   292  	client.outgoingQueue <- &channeldpb.MessagePack{
   293  		ChannelId: channelId,
   294  		Broadcast: uint32(broadcast),
   295  		StubId:    stubId,
   296  		MsgType:   msgType,
   297  		MsgBody:   msgBody,
   298  	}
   299  	return nil
   300  }
   301  
   302  func (client *ChanneldClient) SendRaw(channelId uint32, broadcast channeldpb.BroadcastType, msgType uint32, msgBody *[]byte, callback MessageHandlerFunc) error {
   303  	var stubId uint32 = 0
   304  	if callback != nil {
   305  		for client.stubCallbacks[stubId] != nil {
   306  			stubId++
   307  		}
   308  		client.stubCallbacks[stubId] = callback
   309  	}
   310  
   311  	client.outgoingQueue <- &channeldpb.MessagePack{
   312  		ChannelId: channelId,
   313  		Broadcast: uint32(broadcast),
   314  		StubId:    stubId,
   315  		MsgType:   msgType,
   316  		MsgBody:   *msgBody,
   317  	}
   318  	return nil
   319  }
   320  
   321  func (client *ChanneldClient) writePacket(p *channeldpb.Packet) error {
   322  	bytes, err := proto.Marshal(p)
   323  	if err != nil {
   324  		return fmt.Errorf("error marshalling packet: %w", err)
   325  	}
   326  
   327  	// Apply the compression
   328  	if client.CompressionType == channeldpb.CompressionType_SNAPPY {
   329  		dst := make([]byte, snappy.MaxEncodedLen(len(bytes)))
   330  		bytes = snappy.Encode(dst, bytes)
   331  	}
   332  
   333  	// 'CHNL' in ASCII
   334  	tag := []byte{67, 72, 78, 76, byte(client.CompressionType)}
   335  	len := len(bytes)
   336  	tag[3] = byte(len & 0xff)
   337  	if len > 0xff {
   338  		tag[2] = byte((len >> 8) & 0xff)
   339  	}
   340  	if len > 0xffff {
   341  		tag[1] = byte((len >> 16) & 0xff)
   342  	}
   343  
   344  	client.writeMutex.Lock()
   345  	defer client.writeMutex.Unlock()
   346  	/* With WebSocket, every Write() sends a message.
   347  	client.conn.Write(tag)
   348  	client.conn.Write(bytes)
   349  	*/
   350  	client.Conn.Write(append(tag, bytes...))
   351  	return nil
   352  }
   353  
   354  type wsConn struct {
   355  	conn    *websocket.Conn
   356  	readBuf []byte
   357  	readIdx int
   358  }
   359  
   360  func (c *wsConn) Read(b []byte) (n int, err error) {
   361  	//c.SetReadDeadline(time.Now().Add(30 * time.Second))
   362  	if c.readBuf == nil || c.readIdx >= len(c.readBuf) {
   363  		defer func() {
   364  			if recover() != nil {
   365  				err = errors.New("read on failed connection")
   366  			}
   367  		}()
   368  		_, c.readBuf, err = c.conn.ReadMessage()
   369  		if err != nil {
   370  			return 0, err
   371  		}
   372  		c.readIdx = 0
   373  	}
   374  	n = copy(b, c.readBuf[c.readIdx:])
   375  	c.readIdx += n
   376  	return n, err
   377  }
   378  
   379  func (c *wsConn) Write(b []byte) (n int, err error) {
   380  	return len(b), c.conn.WriteMessage(websocket.BinaryMessage, b)
   381  }
   382  
   383  func (c *wsConn) Close() error {
   384  	return c.conn.Close()
   385  }
   386  
   387  func (c *wsConn) LocalAddr() net.Addr {
   388  	return c.conn.LocalAddr()
   389  }
   390  
   391  func (c *wsConn) RemoteAddr() net.Addr {
   392  	return c.conn.RemoteAddr()
   393  }
   394  
   395  func (c *wsConn) SetDeadline(t time.Time) error {
   396  	return c.conn.UnderlyingConn().SetDeadline(t)
   397  }
   398  
   399  func (c *wsConn) SetReadDeadline(t time.Time) error {
   400  	return c.conn.SetReadDeadline(t)
   401  }
   402  
   403  func (c *wsConn) SetWriteDeadline(t time.Time) error {
   404  	return c.conn.SetWriteDeadline(t)
   405  }