github.com/metaworking/channeld@v0.7.3/pkg/channeld/data.go (about) 1 package channeld 2 3 import ( 4 "container/list" 5 "fmt" 6 7 "github.com/indiest/fmutils" 8 "github.com/metaworking/channeld/pkg/channeldpb" 9 "github.com/metaworking/channeld/pkg/common" 10 "go.uber.org/zap" 11 "google.golang.org/protobuf/proto" 12 13 "google.golang.org/protobuf/reflect/protoreflect" 14 "google.golang.org/protobuf/types/known/anypb" 15 ) 16 17 type ChannelData struct { 18 mergeOptions *channeldpb.ChannelDataMergeOptions 19 msg common.ChannelDataMessage 20 //updateMsg common.ChannelDataMessage 21 accumulatedUpdateMsg common.ChannelDataMessage 22 updateMsgBuffer *list.List 23 maxFanOutIntervalMs uint32 24 msgIndex uint64 25 } 26 27 // Indicate that the channel data message should be initialized with default values. 28 type ChannelDataInitializer interface { 29 common.Message 30 Init() error 31 } 32 33 type RemovableMapField interface { 34 GetRemoved() bool 35 } 36 37 type fanOutConnection struct { 38 conn ConnectionInChannel 39 hadFirstFanOut bool 40 lastFanOutTime ChannelTime 41 lastMessageIndex uint64 42 } 43 44 type updateMsgBufferElement struct { 45 updateMsg common.ChannelDataMessage 46 arrivalTime ChannelTime 47 senderConnId ConnectionId 48 messageIndex uint64 49 } 50 51 const ( 52 MaxUpdateMsgBufferSize = 512 53 ) 54 55 var channelDataTypeRegistery = make(map[channeldpb.ChannelType]proto.Message) 56 57 // Register a Protobuf message template as the channel data of a specific channel type. 58 // This is needed when channeld doesn't know the package of the message is in, 59 // as well as creating a ChannelData using ReflectChannelData() 60 func RegisterChannelDataType(channelType channeldpb.ChannelType, msgTemplate proto.Message) { 61 msg, exists := channelDataTypeRegistery[channelType] 62 63 if exists { 64 if rootLogger != nil { 65 rootLogger.Warn("channel data type already exists, won't be registered", 66 zap.String("channelType", channelType.String()), 67 zap.String("curMsgName", string(msg.ProtoReflect().Descriptor().FullName())), 68 zap.String("newMsgName", string(msgTemplate.ProtoReflect().Descriptor().FullName())), 69 ) 70 } 71 } else { 72 channelDataTypeRegistery[channelType] = msgTemplate 73 74 if rootLogger != nil { 75 rootLogger.Info("registered channel data type", 76 zap.String("channelType", channelType.String()), 77 zap.String("msgFullName", string(msgTemplate.ProtoReflect().Descriptor().FullName())), 78 ) 79 } 80 } 81 82 } 83 84 func ReflectChannelDataMessage(channelType channeldpb.ChannelType) (common.ChannelDataMessage, error) { 85 /* 86 channelTypeName := channelType.String() 87 dataTypeName := fmt.Sprintf("channeld.%sChannelDataMessage", 88 strcase.ToCamel(strings.ToLower(channelTypeName))) 89 dataType, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(dataTypeName)) 90 if err != nil { 91 return nil, fmt.Errorf("failed to create data for channel type %s: %w", channelTypeName, err) 92 } 93 */ 94 dataType, exists := channelDataTypeRegistery[channelType] 95 if !exists { 96 return nil, fmt.Errorf("no channel data type registered for channel type %s", channelType.String()) 97 } 98 99 return dataType.ProtoReflect().New().Interface(), nil 100 } 101 102 func (ch *Channel) InitData(dataMsg common.ChannelDataMessage, mergeOptions *channeldpb.ChannelDataMergeOptions) { 103 ch.data = &ChannelData{ 104 msg: dataMsg, 105 updateMsgBuffer: list.New(), 106 mergeOptions: mergeOptions, 107 } 108 109 if dataMsg == nil { 110 var err error 111 ch.data.msg, err = ReflectChannelDataMessage(ch.channelType) 112 if err != nil { 113 ch.logger.Info("unable to create default channel data message; will use the first received message to set", zap.String("chType", ch.channelType.String()), zap.Error(err)) 114 return 115 } 116 ch.data.accumulatedUpdateMsg = proto.Clone(ch.data.msg) 117 } 118 119 initializer, ok := ch.data.msg.(ChannelDataInitializer) 120 if ok { 121 if err := initializer.Init(); err != nil { 122 ch.logger.Error("failed to initialize channel data message", zap.Error(err)) 123 return 124 } 125 } 126 } 127 128 func (ch *Channel) Data() *ChannelData { 129 return ch.data 130 } 131 132 // CAUTION: this function is not goroutine-safe. Read/write to the channel data message should be done in the the channel's goroutine. 133 func (ch *Channel) GetDataMessage() common.ChannelDataMessage { 134 if ch.data == nil { 135 return nil 136 } 137 return ch.data.msg 138 } 139 140 func (ch *Channel) SetDataUpdateConnId(connId ConnectionId) { 141 ch.latestDataUpdateConnId = connId 142 } 143 144 func (d *ChannelData) OnUpdate(updateMsg common.ChannelDataMessage, t ChannelTime, senderConnId ConnectionId, spatialNotifier common.SpatialInfoChangedNotifier) { 145 if d.msg == nil { 146 d.msg = updateMsg 147 rootLogger.Info("initialized channel data with update message", 148 zap.Uint32("senderConnId", uint32(senderConnId)), 149 zap.String("msgName", string(updateMsg.ProtoReflect().Descriptor().FullName())), 150 ) 151 } else { 152 mergeWithOptions(d.msg, updateMsg, d.mergeOptions, spatialNotifier) 153 } 154 d.msgIndex = d.msgIndex + 1 155 d.updateMsgBuffer.PushBack(&updateMsgBufferElement{ 156 updateMsg: updateMsg, 157 arrivalTime: t, 158 senderConnId: senderConnId, 159 messageIndex: d.msgIndex, 160 }) 161 if d.updateMsgBuffer.Len() > MaxUpdateMsgBufferSize { 162 oldest := d.updateMsgBuffer.Front() 163 // Remove the oldest update message if it should has been fanned-out 164 if oldest.Value.(*updateMsgBufferElement).arrivalTime.AddMs(d.maxFanOutIntervalMs) < t { 165 d.updateMsgBuffer.Remove(oldest) 166 } 167 } 168 } 169 170 func (ch *Channel) tickData(t ChannelTime) { 171 if ch.data == nil || ch.data.msg == nil { 172 return 173 } 174 175 focp := ch.fanOutQueue.Front() 176 177 for focp != nil { 178 foc := focp.Value.(*fanOutConnection) 179 conn := foc.conn 180 if conn == nil || conn.IsClosing() { 181 tmp := focp.Next() 182 ch.fanOutQueue.Remove(focp) 183 focp = tmp 184 continue 185 } 186 // ch.subLock.RLock() 187 cs := ch.subscribedConnections[conn] 188 // ch.subLock.RUnlock() 189 if cs == nil || *cs.options.DataAccess == channeldpb.ChannelDataAccess_NO_ACCESS { 190 focp = focp.Next() 191 continue 192 } 193 194 /* 195 ---------------------------------------------------- 196 ^ ^ ^ 197 |------FanOutDelay------|---FanOutInterval---| 198 subTime firstFanOutTime secondFanOutTime 199 */ 200 nextFanOutTime := foc.lastFanOutTime.AddMs(*cs.options.FanOutIntervalMs) 201 // latestFanoutTime := foc.lastFanOutTime 202 if t >= nextFanOutTime { 203 latestFanoutTime := nextFanOutTime 204 var lastUpdateTime ChannelTime 205 bufp := ch.data.updateMsgBuffer.Front() 206 if ch.data.accumulatedUpdateMsg == nil { 207 ch.data.accumulatedUpdateMsg = ch.data.msg.ProtoReflect().New().Interface() 208 } else { 209 proto.Reset(ch.data.accumulatedUpdateMsg) 210 } 211 hasEverMerged := false 212 213 //if foc.lastFanOutTime <= cs.subTime { 214 if !foc.hadFirstFanOut { 215 // Send the whole data for the first time 216 ch.fanOutDataUpdate(conn, cs, ch.data.msg) 217 foc.hadFirstFanOut = true 218 foc.lastMessageIndex = ch.data.msgIndex 219 latestFanoutTime = t 220 } else if bufp != nil { 221 if foc.lastFanOutTime >= lastUpdateTime { 222 lastUpdateTime = foc.lastFanOutTime 223 } 224 225 for bufi := 0; bufi < ch.data.updateMsgBuffer.Len(); bufi++ { 226 be := bufp.Value.(*updateMsgBufferElement) 227 /* 228 ch.Logger().Trace("going through updateMsgBuffer", 229 zap.Int("bufi", bufi), 230 zap.Int64("lastUpdateTime", int64(lastUpdateTime)/1000), 231 zap.Int64("arrivalTime", int64(be.arrivalTime)/1000), 232 zap.Int64("nextFanOutTime", int64(nextFanOutTime)/1000), 233 zap.Uint32("senderConnId", uint32(be.senderConnId)), 234 ) 235 */ 236 237 if be.senderConnId == conn.Id() && *cs.options.SkipSelfUpdateFanOut { 238 bufp = bufp.Next() 239 continue 240 } 241 242 if be.arrivalTime >= lastUpdateTime && be.arrivalTime <= nextFanOutTime { 243 if !hasEverMerged { 244 proto.Merge(ch.data.accumulatedUpdateMsg, be.updateMsg) 245 } else { 246 mergeWithOptions(ch.data.accumulatedUpdateMsg, be.updateMsg, ch.data.mergeOptions, nil) 247 } 248 hasEverMerged = true 249 lastUpdateTime = be.arrivalTime 250 foc.lastMessageIndex = be.messageIndex 251 } 252 253 /* TODO: remove the out-dated buffer element to decrease the iteration time 254 if be.arrivalTime.AddMs(ch.data.maxFanOutIntervalMs*2) < t { 255 ch.data.updateMsgBuffer.Remove(bufp) 256 } 257 */ 258 259 bufp = bufp.Next() 260 } 261 262 if hasEverMerged { 263 ch.fanOutDataUpdate(conn, cs, ch.data.accumulatedUpdateMsg) 264 } 265 } 266 foc.lastFanOutTime = latestFanoutTime 267 268 temp := focp.Prev() 269 // Move the fanned-out connection to the back of the queue 270 for be := ch.fanOutQueue.Back(); be != nil; be = be.Prev() { 271 if be.Value.(*fanOutConnection).lastFanOutTime <= foc.lastFanOutTime { 272 ch.fanOutQueue.MoveAfter(focp, be) 273 if temp != nil { 274 focp = temp.Next() 275 } else { 276 focp = ch.fanOutQueue.Front() 277 } 278 279 break 280 } 281 } 282 } else { 283 focp = focp.Next() 284 } 285 } 286 } 287 288 func (ch *Channel) fanOutDataUpdate(conn ConnectionInChannel, cs *ChannelSubscription, updateMsg common.ChannelDataMessage) { 289 fmutils.Filter(updateMsg, cs.options.DataFieldMasks) 290 any, err := anypb.New(updateMsg) 291 if err != nil { 292 ch.Logger().Error("failed to marshal channel update data", zap.Error(err)) 293 return 294 } 295 conn.Send(MessageContext{ 296 MsgType: channeldpb.MessageType_CHANNEL_DATA_UPDATE, 297 Msg: &channeldpb.ChannelDataUpdateMessage{Data: any}, 298 Connection: nil, 299 Channel: ch, 300 Broadcast: 0, 301 StubId: 0, 302 ChannelId: uint32(ch.id), 303 }) 304 /* 305 conn.Logger().Trace("fan out", 306 zap.Int64("channelTime", int64(ch.GetTime())), 307 zap.Int64("lastFanOutTime", int64(cs.fanOutElement.Value.(*fanOutConnection).lastFanOutTime)), 308 zap.Stringer("updateMsg", updateMsg.(fmt.Stringer)), 309 ) 310 */ 311 // cs.lastFanOutTime = time.Now() 312 // cs.fanOutDataMsg = nil 313 } 314 315 // Implement this interface to manually merge the channel data. In most cases it can be MUCH more efficient than the default reflection-based merge. 316 type MergeableChannelData interface { 317 common.Message 318 Merge(src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions, spatialNotifier common.SpatialInfoChangedNotifier) error 319 } 320 321 func mergeWithOptions(dst common.ChannelDataMessage, src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions, spatialNotifier common.SpatialInfoChangedNotifier) { 322 mergeable, ok := dst.(MergeableChannelData) 323 if ok { 324 if options == nil { 325 options = &channeldpb.ChannelDataMergeOptions{ 326 ShouldReplaceList: false, 327 ListSizeLimit: 0, 328 TruncateTop: false, 329 ShouldCheckRemovableMapField: true, 330 } 331 } 332 if err := mergeable.Merge(src, options, spatialNotifier); err != nil { 333 rootLogger.Error("custom merge error", zap.Error(err), 334 zap.String("dstType", string(dst.ProtoReflect().Descriptor().FullName().Name())), 335 zap.String("srcType", string(src.ProtoReflect().Descriptor().FullName().Name())), 336 ) 337 } 338 } else { 339 ReflectMerge(dst, src, options) 340 } 341 } 342 343 // Use protoreflect to merge. No need to write custom merge code but less efficient. 344 func ReflectMerge(dst common.ChannelDataMessage, src common.ChannelDataMessage, options *channeldpb.ChannelDataMergeOptions) { 345 proto.Merge(dst, src) 346 347 if options != nil { 348 //logger.Debug("merged with options", zap.Any("src", src), zap.Any("dst", dst)) 349 defer func() { 350 recover() 351 }() 352 353 dst.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 354 if fd.IsList() { 355 if options.ShouldReplaceList { 356 dst.ProtoReflect().Set(fd, src.ProtoReflect().Get(fd)) 357 } 358 list := v.List() 359 offset := list.Len() - int(options.ListSizeLimit) 360 if options.ListSizeLimit > 0 && offset > 0 { 361 if options.TruncateTop { 362 for i := 0; i < int(options.ListSizeLimit); i++ { 363 list.Set(i, list.Get(i+offset)) 364 } 365 } 366 list.Truncate(int(options.ListSizeLimit)) 367 } 368 } else if fd.IsMap() { 369 if options.ShouldCheckRemovableMapField { 370 dstMap := v.Map() 371 dstMap.Range(func(mk protoreflect.MapKey, mv protoreflect.Value) bool { 372 removable, ok := mv.Message().Interface().(RemovableMapField) 373 if ok && removable.GetRemoved() { 374 dstMap.Clear(mk) 375 } 376 return true 377 }) 378 } 379 } 380 return true 381 }) 382 } 383 }