github.com/metaworking/channeld@v0.7.3/pkg/channeld/data.go (about)

     1  package channeld
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  
     7  	"github.com/indiest/fmutils"
     8  	"github.com/metaworking/channeld/pkg/channeldpb"
     9  	"github.com/metaworking/channeld/pkg/common"
    10  	"go.uber.org/zap"
    11  	"google.golang.org/protobuf/proto"
    12  
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  	"google.golang.org/protobuf/types/known/anypb"
    15  )
    16  
// ChannelData holds a channel's data message together with the state used to
// fan out updates of that message to the channel's subscribers.
type ChannelData struct {
	// Options passed to mergeWithOptions when an update is merged into msg.
	mergeOptions *channeldpb.ChannelDataMergeOptions
	// The channel's current data message; nil until it is initialized or the
	// first update arrives.
	msg common.ChannelDataMessage
	// Scratch message that tickData resets and fills with the merged updates
	// for a single connection's fan-out.
	accumulatedUpdateMsg common.ChannelDataMessage
	// Received updates as *updateMsgBufferElement values, oldest at the front.
	updateMsgBuffer *list.List
	// Window, in milliseconds, that OnUpdate adds to an update's arrival time
	// to decide whether it may be trimmed from updateMsgBuffer.
	// NOTE(review): not assigned in this file — confirm where it is set.
	maxFanOutIntervalMs uint32
	// Sequence number of the latest update; incremented by OnUpdate.
	msgIndex uint64
}
    26  
// ChannelDataInitializer is implemented by channel data messages that need to be
// populated with default values. InitData calls Init on the channel's data
// message when the message implements this interface.
type ChannelDataInitializer interface {
	common.Message
	Init() error
}
    32  
// RemovableMapField is implemented by map value messages that can be flagged as
// removed. When ChannelDataMergeOptions.ShouldCheckRemovableMapField is set,
// ReflectMerge deletes map entries whose value reports GetRemoved() == true.
type RemovableMapField interface {
	GetRemoved() bool
}
    36  
// fanOutConnection is an entry of the channel's fan-out queue. It tracks when a
// subscribed connection last received channel data.
type fanOutConnection struct {
	// The subscribed connection. tickData drops the entry from the queue once
	// this is nil or closing.
	conn ConnectionInChannel
	// Whether the connection has already been sent the full data message.
	hadFirstFanOut bool
	// Channel time recorded for the connection's most recent fan-out.
	lastFanOutTime ChannelTime
	// ChannelData.msgIndex of the last update included in a fan-out to this
	// connection.
	lastMessageIndex uint64
}
    43  
// updateMsgBufferElement is one entry of ChannelData.updateMsgBuffer.
type updateMsgBufferElement struct {
	// The update message passed to ChannelData.OnUpdate.
	updateMsg common.ChannelDataMessage
	// Channel time at which the update was received.
	arrivalTime ChannelTime
	// Connection that sent the update.
	senderConnId ConnectionId
	// Value of ChannelData.msgIndex assigned to this update.
	messageIndex uint64
}
    50  
    51  const (
    52  	MaxUpdateMsgBufferSize = 512
    53  )
    54  
    55  var channelDataTypeRegistery = make(map[channeldpb.ChannelType]proto.Message)
    56  
    57  // Register a Protobuf message template as the channel data of a specific channel type.
    58  // This is needed when channeld doesn't know the package of the message is in,
    59  // as well as creating a ChannelData using ReflectChannelData()
    60  func RegisterChannelDataType(channelType channeldpb.ChannelType, msgTemplate proto.Message) {
    61  	msg, exists := channelDataTypeRegistery[channelType]
    62  
    63  	if exists {
    64  		if rootLogger != nil {
    65  			rootLogger.Warn("channel data type already exists, won't be registered",
    66  				zap.String("channelType", channelType.String()),
    67  				zap.String("curMsgName", string(msg.ProtoReflect().Descriptor().FullName())),
    68  				zap.String("newMsgName", string(msgTemplate.ProtoReflect().Descriptor().FullName())),
    69  			)
    70  		}
    71  	} else {
    72  		channelDataTypeRegistery[channelType] = msgTemplate
    73  
    74  		if rootLogger != nil {
    75  			rootLogger.Info("registered channel data type",
    76  				zap.String("channelType", channelType.String()),
    77  				zap.String("msgFullName", string(msgTemplate.ProtoReflect().Descriptor().FullName())),
    78  			)
    79  		}
    80  	}
    81  
    82  }
    83  
    84  func ReflectChannelDataMessage(channelType channeldpb.ChannelType) (common.ChannelDataMessage, error) {
    85  	/*
    86  		channelTypeName := channelType.String()
    87  		dataTypeName := fmt.Sprintf("channeld.%sChannelDataMessage",
    88  			strcase.ToCamel(strings.ToLower(channelTypeName)))
    89  		dataType, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(dataTypeName))
    90  		if err != nil {
    91  			return nil, fmt.Errorf("failed to create data for channel type %s: %w", channelTypeName, err)
    92  		}
    93  	*/
    94  	dataType, exists := channelDataTypeRegistery[channelType]
    95  	if !exists {
    96  		return nil, fmt.Errorf("no channel data type registered for channel type %s", channelType.String())
    97  	}
    98  
    99  	return dataType.ProtoReflect().New().Interface(), nil
   100  }
   101  
   102  func (ch *Channel) InitData(dataMsg common.ChannelDataMessage, mergeOptions *channeldpb.ChannelDataMergeOptions) {
   103  	ch.data = &ChannelData{
   104  		msg:             dataMsg,
   105  		updateMsgBuffer: list.New(),
   106  		mergeOptions:    mergeOptions,
   107  	}
   108  
   109  	if dataMsg == nil {
   110  		var err error
   111  		ch.data.msg, err = ReflectChannelDataMessage(ch.channelType)
   112  		if err != nil {
   113  			ch.logger.Info("unable to create default channel data message; will use the first received message to set", zap.String("chType", ch.channelType.String()), zap.Error(err))
   114  			return
   115  		}
   116  		ch.data.accumulatedUpdateMsg = proto.Clone(ch.data.msg)
   117  	}
   118  
   119  	initializer, ok := ch.data.msg.(ChannelDataInitializer)
   120  	if ok {
   121  		if err := initializer.Init(); err != nil {
   122  			ch.logger.Error("failed to initialize channel data message", zap.Error(err))
   123  			return
   124  		}
   125  	}
   126  }
   127  
   128  func (ch *Channel) Data() *ChannelData {
   129  	return ch.data
   130  }
   131  
   132  // CAUTION: this function is not goroutine-safe. Read/write to the channel data message should be done in the the channel's goroutine.
   133  func (ch *Channel) GetDataMessage() common.ChannelDataMessage {
   134  	if ch.data == nil {
   135  		return nil
   136  	}
   137  	return ch.data.msg
   138  }
   139  
   140  func (ch *Channel) SetDataUpdateConnId(connId ConnectionId) {
   141  	ch.latestDataUpdateConnId = connId
   142  }
   143  
   144  func (d *ChannelData) OnUpdate(updateMsg common.ChannelDataMessage, t ChannelTime, senderConnId ConnectionId, spatialNotifier common.SpatialInfoChangedNotifier) {
   145  	if d.msg == nil {
   146  		d.msg = updateMsg
   147  		rootLogger.Info("initialized channel data with update message",
   148  			zap.Uint32("senderConnId", uint32(senderConnId)),
   149  			zap.String("msgName", string(updateMsg.ProtoReflect().Descriptor().FullName())),
   150  		)
   151  	} else {
   152  		mergeWithOptions(d.msg, updateMsg, d.mergeOptions, spatialNotifier)
   153  	}
   154  	d.msgIndex = d.msgIndex + 1
   155  	d.updateMsgBuffer.PushBack(&updateMsgBufferElement{
   156  		updateMsg:    updateMsg,
   157  		arrivalTime:  t,
   158  		senderConnId: senderConnId,
   159  		messageIndex: d.msgIndex,
   160  	})
   161  	if d.updateMsgBuffer.Len() > MaxUpdateMsgBufferSize {
   162  		oldest := d.updateMsgBuffer.Front()
   163  		// Remove the oldest update message if it should has been fanned-out
   164  		if oldest.Value.(*updateMsgBufferElement).arrivalTime.AddMs(d.maxFanOutIntervalMs) < t {
   165  			d.updateMsgBuffer.Remove(oldest)
   166  		}
   167  	}
   168  }
   169  
// tickData performs the channel data fan-out for channel time t.
//
// It walks ch.fanOutQueue from the front:
//   - Entries whose connection is nil or closing are removed from the queue.
//   - Entries with no subscription, or whose subscription has NO_ACCESS to the
//     data, are skipped.
//   - Every other entry is processed once lastFanOutTime + FanOutIntervalMs
//     (nextFanOutTime) has been reached.
//
// A processed entry is handled in one of two ways:
//   - First fan-out: the whole channel data message is sent and lastFanOutTime
//     becomes t.
//   - Later fan-outs: the buffered updates whose arrival time falls within
//     [lastFanOutTime, nextFanOutTime] are merged into
//     ch.data.accumulatedUpdateMsg and sent as one update. The connection's own
//     updates are skipped when SkipSelfUpdateFanOut is set, and nothing is sent
//     if no update qualifies. lastFanOutTime becomes nextFanOutTime.
//
// A processed entry is then moved right after the back-most queue entry whose
// lastFanOutTime is not greater than its own, which keeps the queue ordered by
// lastFanOutTime (assuming it already was). Iteration resumes at the element
// that originally followed it, or at the entry itself if it was not moved.
// The entry is therefore visited again in the same call, so a connection that
// lags by several intervals is processed several times in one tick.
//
// NOTE(review): if FanOutIntervalMs is 0, a non-first fan-out leaves
// lastFanOutTime unchanged and t >= nextFanOutTime stays true. This looks like
// it could loop forever; confirm the option is validated to be > 0 elsewhere.
//
// NOTE(review): the arrival-time window is inclusive at both ends, so an update
// arriving exactly at lastFanOutTime may be sent twice. foc.lastMessageIndex is
// recorded but not used to deduplicate.
//
// NOTE(review): the subscription lookup is unlocked (the RLock is commented
// out); presumably this runs on the channel's goroutine — confirm.
func (ch *Channel) tickData(t ChannelTime) {
	if ch.data == nil || ch.data.msg == nil {
		return
	}

	focp := ch.fanOutQueue.Front()

	for focp != nil {
		foc := focp.Value.(*fanOutConnection)
		conn := foc.conn
		if conn == nil || conn.IsClosing() {
			// Grab the successor before unlinking, since Remove clears focp's links.
			tmp := focp.Next()
			ch.fanOutQueue.Remove(focp)
			focp = tmp
			continue
		}
		// ch.subLock.RLock()
		cs := ch.subscribedConnections[conn]
		// ch.subLock.RUnlock()
		if cs == nil || *cs.options.DataAccess == channeldpb.ChannelDataAccess_NO_ACCESS {
			focp = focp.Next()
			continue
		}

		/*
			----------------------------------------------------
			   ^                       ^                    ^
			   |------FanOutDelay------|---FanOutInterval---|
			   subTime                 firstFanOutTime      secondFanOutTime
		*/
		nextFanOutTime := foc.lastFanOutTime.AddMs(*cs.options.FanOutIntervalMs)
		// latestFanoutTime := foc.lastFanOutTime
		if t >= nextFanOutTime {
			// The time recorded as this fan-out's lastFanOutTime; advances by
			// exactly one interval unless this is the first fan-out.
			latestFanoutTime := nextFanOutTime
			var lastUpdateTime ChannelTime
			bufp := ch.data.updateMsgBuffer.Front()
			// Reuse the shared scratch message; it is emptied for every connection.
			if ch.data.accumulatedUpdateMsg == nil {
				ch.data.accumulatedUpdateMsg = ch.data.msg.ProtoReflect().New().Interface()
			} else {
				proto.Reset(ch.data.accumulatedUpdateMsg)
			}
			hasEverMerged := false

			//if foc.lastFanOutTime <= cs.subTime {
			if !foc.hadFirstFanOut {
				// Send the whole data for the first time
				ch.fanOutDataUpdate(conn, cs, ch.data.msg)
				foc.hadFirstFanOut = true
				foc.lastMessageIndex = ch.data.msgIndex
				latestFanoutTime = t
			} else if bufp != nil {
				// Start the window at the last fan-out time (lastUpdateTime is
				// still its zero value here).
				if foc.lastFanOutTime >= lastUpdateTime {
					lastUpdateTime = foc.lastFanOutTime
				}

				for bufi := 0; bufi < ch.data.updateMsgBuffer.Len(); bufi++ {
					be := bufp.Value.(*updateMsgBufferElement)
					/*
						ch.Logger().Trace("going through updateMsgBuffer",
							zap.Int("bufi", bufi),
							zap.Int64("lastUpdateTime", int64(lastUpdateTime)/1000),
							zap.Int64("arrivalTime", int64(be.arrivalTime)/1000),
							zap.Int64("nextFanOutTime", int64(nextFanOutTime)/1000),
							zap.Uint32("senderConnId", uint32(be.senderConnId)),
						)
					*/

					if be.senderConnId == conn.Id() && *cs.options.SkipSelfUpdateFanOut {
						bufp = bufp.Next()
						continue
					}

					if be.arrivalTime >= lastUpdateTime && be.arrivalTime <= nextFanOutTime {
						// The accumulator was just reset, so the first update is
						// copied in with a plain proto.Merge; later ones go through
						// mergeWithOptions with nil options (custom Merge or
						// ReflectMerge).
						if !hasEverMerged {
							proto.Merge(ch.data.accumulatedUpdateMsg, be.updateMsg)
						} else {
							mergeWithOptions(ch.data.accumulatedUpdateMsg, be.updateMsg, ch.data.mergeOptions, nil)
						}
						hasEverMerged = true
						// Advance the window's lower bound to this update's arrival.
						lastUpdateTime = be.arrivalTime
						foc.lastMessageIndex = be.messageIndex
					}

					/* TODO: remove the out-dated buffer element to decrease the iteration time
					if be.arrivalTime.AddMs(ch.data.maxFanOutIntervalMs*2) < t {
						ch.data.updateMsgBuffer.Remove(bufp)
					}
					*/

					bufp = bufp.Next()
				}

				if hasEverMerged {
					ch.fanOutDataUpdate(conn, cs, ch.data.accumulatedUpdateMsg)
				}
			}
			foc.lastFanOutTime = latestFanoutTime

			// Remember the predecessor so iteration can resume at the element
			// that followed focp, wherever focp ends up.
			temp := focp.Prev()
			// Move the fanned-out connection to the back of the queue
			for be := ch.fanOutQueue.Back(); be != nil; be = be.Prev() {
				if be.Value.(*fanOutConnection).lastFanOutTime <= foc.lastFanOutTime {
					// MoveAfter is a no-op when be == focp.
					ch.fanOutQueue.MoveAfter(focp, be)
					if temp != nil {
						focp = temp.Next()
					} else {
						focp = ch.fanOutQueue.Front()
					}

					break
				}
			}
		} else {
			focp = focp.Next()
		}
	}
}
   287  
   288  func (ch *Channel) fanOutDataUpdate(conn ConnectionInChannel, cs *ChannelSubscription, updateMsg common.ChannelDataMessage) {
   289  	fmutils.Filter(updateMsg, cs.options.DataFieldMasks)
   290  	any, err := anypb.New(updateMsg)
   291  	if err != nil {
   292  		ch.Logger().Error("failed to marshal channel update data", zap.Error(err))
   293  		return
   294  	}
   295  	conn.Send(MessageContext{
   296  		MsgType:    channeldpb.MessageType_CHANNEL_DATA_UPDATE,
   297  		Msg:        &channeldpb.ChannelDataUpdateMessage{Data: any},
   298  		Connection: nil,
   299  		Channel:    ch,
   300  		Broadcast:  0,
   301  		StubId:     0,
   302  		ChannelId:  uint32(ch.id),
   303  	})
   304  	/*
   305  		conn.Logger().Trace("fan out",
   306  			zap.Int64("channelTime", int64(ch.GetTime())),
   307  			zap.Int64("lastFanOutTime", int64(cs.fanOutElement.Value.(*fanOutConnection).lastFanOutTime)),
   308  			zap.Stringer("updateMsg", updateMsg.(fmt.Stringer)),
   309  		)
   310  	*/
   311  	// cs.lastFanOutTime = time.Now()
   312  	// cs.fanOutDataMsg = nil
   313  }
   314  
// MergeableChannelData is implemented by channel data messages that merge
// updates themselves instead of relying on ReflectMerge. In most cases a
// hand-written Merge can be MUCH more efficient than the default
// reflection-based merge.
//
// mergeWithOptions calls Merge on the destination message with the update as
// src. It passes default options when none are configured and forwards the
// spatial notifier it was given, which may be nil. Errors returned by Merge are
// logged by mergeWithOptions.
type MergeableChannelData interface {
	common.Message
	Merge(src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions, spatialNotifier common.SpatialInfoChangedNotifier) error
}
   320  
   321  func mergeWithOptions(dst common.ChannelDataMessage, src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions, spatialNotifier common.SpatialInfoChangedNotifier) {
   322  	mergeable, ok := dst.(MergeableChannelData)
   323  	if ok {
   324  		if options == nil {
   325  			options = &channeldpb.ChannelDataMergeOptions{
   326  				ShouldReplaceList:            false,
   327  				ListSizeLimit:                0,
   328  				TruncateTop:                  false,
   329  				ShouldCheckRemovableMapField: true,
   330  			}
   331  		}
   332  		if err := mergeable.Merge(src, options, spatialNotifier); err != nil {
   333  			rootLogger.Error("custom merge error", zap.Error(err),
   334  				zap.String("dstType", string(dst.ProtoReflect().Descriptor().FullName().Name())),
   335  				zap.String("srcType", string(src.ProtoReflect().Descriptor().FullName().Name())),
   336  			)
   337  		}
   338  	} else {
   339  		ReflectMerge(dst, src, options)
   340  	}
   341  }
   342  
   343  // Use protoreflect to merge. No need to write custom merge code but less efficient.
   344  func ReflectMerge(dst common.ChannelDataMessage, src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions) {
   345  	proto.Merge(dst, src)
   346  
   347  	if options != nil {
   348  		//logger.Debug("merged with options", zap.Any("src", src), zap.Any("dst", dst))
   349  		defer func() {
   350  			recover()
   351  		}()
   352  
   353  		dst.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   354  			if fd.IsList() {
   355  				if options.ShouldReplaceList {
   356  					dst.ProtoReflect().Set(fd, src.ProtoReflect().Get(fd))
   357  				}
   358  				list := v.List()
   359  				offset := list.Len() - int(options.ListSizeLimit)
   360  				if options.ListSizeLimit > 0 && offset > 0 {
   361  					if options.TruncateTop {
   362  						for i := 0; i < int(options.ListSizeLimit); i++ {
   363  							list.Set(i, list.Get(i+offset))
   364  						}
   365  					}
   366  					list.Truncate(int(options.ListSizeLimit))
   367  				}
   368  			} else if fd.IsMap() {
   369  				if options.ShouldCheckRemovableMapField {
   370  					dstMap := v.Map()
   371  					dstMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool {
   372  						removable, ok := mv.Message().Interface().(RemovableMapField)
   373  						if ok && removable.GetRemoved() {
   374  							dstMap.Clear(mk)
   375  						}
   376  						return true
   377  					})
   378  				}
   379  			}
   380  			return true
   381  		})
   382  	}
   383  }