github.com/metaworking/channeld@v0.7.3/pkg/replay/replay.go (about) 1 package replay 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "io/ioutil" 8 "log" 9 "os" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 "github.com/metaworking/channeld/pkg/channeldpb" 15 "github.com/metaworking/channeld/pkg/client" 16 "github.com/metaworking/channeld/pkg/replaypb" 17 "google.golang.org/protobuf/proto" 18 ) 19 20 type Duration time.Duration 21 22 type CaseConfig struct { 23 ChanneldAddr string `json:"channeldAddr"` 24 ConnectionGroups []ConnectionGroupConfig `json:"connectionGroups"` 25 } 26 27 type ConnectionGroupConfig struct { 28 CprFilePath string `json:"cprFilePath"` 29 ConnectionNumber int `json:"connectionNumber"` // Number of connections 30 ConnectInterval Duration `json:"connectInterval"` // start connections interval 31 RunningTime Duration `json:"runningTime"` // replay session running time 32 SleepEndOfSession Duration `json:"sleepEndOfSession"` // sleep each end of session 33 MaxTickInterval Duration `json:"maxTickInterval"` // channeld connection max tick time 34 ActionIntervalMultiplier float64 `json:"actionIntervalMultiplier"` // used to adjust the packet offsettime (ActionIntervalMultiplier * ReplayPacket.Offsettime) 35 WaitAuthSuccess bool `json:"waitAuthSuccess"` // if true, replay loop will wait for auth success 36 AuthOnlyOnce bool `json:"authOnlyOnce"` // if true, only send auth message once in the entire replay 37 } 38 39 var DefaultConnGroupConfig = ConnectionGroupConfig{ 40 ConnectionNumber: 1, 41 ActionIntervalMultiplier: 1, 42 WaitAuthSuccess: true, 43 AuthOnlyOnce: true, 44 } 45 46 type AlterChannelIdBeforeSendHandlerFunc func(channelId uint32, msgType channeldpb.MessageType, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) (chId uint32, needToSend bool) 47 48 type BeforeSendMessageHandlerFunc func(msg proto.Message, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) (needToSend bool) 49 type beforeSendMessageMapEntry struct { 50 msgTemp proto.Message 51 beforeSendHandler BeforeSendMessageHandlerFunc 52 } 53 54 type MessageHandlerFunc func(c *client.ChanneldClient, channelId uint32, m proto.Message) 55 type messageMapEntry struct { 56 msg proto.Message 57 handlers []MessageHandlerFunc 58 } 59 60 type NeedWaitMessageCallbackHandlerFunc func(msgType channeldpb.MessageType, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) bool 61 62 type ReplayClient struct { 63 CaseConfig CaseConfig 64 ConnectionGroups []ConnectionGroup 65 alterChannelIdBeforeSendHandler AlterChannelIdBeforeSendHandlerFunc 66 beforeSendMessageMap map[channeldpb.MessageType]*beforeSendMessageMapEntry 67 needWaitMessageCallbackHandler NeedWaitMessageCallbackHandlerFunc 68 messageMap map[channeldpb.MessageType]*messageMapEntry 69 } 70 71 type ConnectionGroup struct { 72 config ConnectionGroupConfig 73 session *replaypb.ReplaySession 74 } 75 76 func CreateReplayClientByConfigFile(configPath string) (*ReplayClient, error) { 77 rc := &ReplayClient{} 78 err := rc.LoadCaseConfig(configPath) 79 if err != nil { 80 return nil, err 81 } 82 rc.beforeSendMessageMap = make(map[channeldpb.MessageType]*beforeSendMessageMapEntry) 83 rc.messageMap = make(map[channeldpb.MessageType]*messageMapEntry) 84 return rc, nil 85 } 86 87 func (d *Duration) UnmarshalJSON(b []byte) error { 88 var v interface{} 89 if err := json.Unmarshal(b, &v); err != nil { 90 return err 91 } 92 switch value := v.(type) { 93 case float64: 94 *d = Duration(time.Duration(value)) 95 return nil 96 case string: 97 tmp, err := time.ParseDuration(value) 98 if err != nil { 99 return err 100 } 101 *d = Duration(tmp) 102 return nil 103 default: 104 return errors.New("invalid duration") 105 } 106 } 107 108 func (rc *ReplayClient) LoadCaseConfig(path string) error { 109 110 config, err := ioutil.ReadFile(path) 111 if err == nil { 112 if err := json.Unmarshal(config, &rc.CaseConfig); err != nil { 113 return fmt.Errorf("failed to unmarshall case config: %v", err) 114 } 115 } else { 116 return fmt.Errorf("failed to load case config: %v", err) 117 } 118 119 for _, c := range rc.CaseConfig.ConnectionGroups { 120 session, err := ReadReplaySessionFile(c.CprFilePath) 121 if err != nil { 122 return fmt.Errorf("failed to read replay session file: %v", err) 123 } 124 rc.ConnectionGroups = append(rc.ConnectionGroups, ConnectionGroup{ 125 config: c, 126 session: session, 127 }) 128 } 129 return nil 130 } 131 132 func (rc *ReplayClient) SetAlterChannelIdBeforeSendHandler(handler AlterChannelIdBeforeSendHandlerFunc) { 133 rc.alterChannelIdBeforeSendHandler = handler 134 } 135 136 func (rc *ReplayClient) SetBeforeSendMessageEntry(msgType channeldpb.MessageType, msgTemp proto.Message, handler BeforeSendMessageHandlerFunc) { 137 rc.beforeSendMessageMap[msgType] = &beforeSendMessageMapEntry{ 138 msgTemp: msgTemp, 139 beforeSendHandler: handler, 140 } 141 } 142 143 func (rc *ReplayClient) AddMessageHandler(msgType channeldpb.MessageType, handlers ...MessageHandlerFunc) { 144 entry := rc.messageMap[msgType] 145 if entry != nil { 146 entry.handlers = append(entry.handlers, handlers...) 147 } else { 148 rc.messageMap[msgType] = &messageMapEntry{ 149 handlers: handlers, 150 } 151 } 152 } 153 154 func (rc *ReplayClient) SetMessageEntry(msgType uint32, msgTemp proto.Message, handlers ...MessageHandlerFunc) { 155 rc.messageMap[channeldpb.MessageType(msgType)] = &messageMapEntry{ 156 msg: msgTemp, 157 handlers: handlers, 158 } 159 } 160 161 func (rc *ReplayClient) SetNeedWaitMessageCallback(handler NeedWaitMessageCallbackHandlerFunc) { 162 rc.needWaitMessageCallbackHandler = handler 163 } 164 165 func ReadReplaySessionFile(cprPath string) (*replaypb.ReplaySession, error) { 166 data, err := os.ReadFile(cprPath) 167 if err != nil { 168 return nil, err 169 } 170 171 var rs replaypb.ReplaySession 172 if err = proto.Unmarshal(data, &rs); err != nil { 173 return nil, err 174 } 175 176 return &rs, nil 177 } 178 179 func (rc *ReplayClient) Run() { 180 for _, cg := range rc.ConnectionGroups { 181 cg.run(rc) 182 } 183 184 } 185 186 func (cg *ConnectionGroup) run(rc *ReplayClient) { 187 connNumber := cg.config.ConnectionNumber 188 maxTickInterval := time.Duration(cg.config.MaxTickInterval) 189 connInterval := time.Duration(cg.config.ConnectInterval) 190 runningTime := time.Duration(cg.config.RunningTime) 191 session := cg.session 192 authOnlyOnce := cg.config.AuthOnlyOnce 193 waitAuthSuccess := cg.config.WaitAuthSuccess 194 sleepEndOfSession := time.Duration(cg.config.SleepEndOfSession) 195 actionIntervalMultiplier := cg.config.ActionIntervalMultiplier 196 197 sessionPacketNum := len(session.Packets) 198 199 var wg sync.WaitGroup // wait all connections in the group to timeout 200 wg.Add(connNumber) 201 202 for ci := 0; ci < connNumber; ci++ { 203 go func() { 204 c, err := client.NewClient(rc.CaseConfig.ChanneldAddr) 205 if err != nil { 206 log.Println(err) 207 return 208 } 209 210 // Register msg handlers 211 for msgType, entry := range rc.messageMap { 212 handlers := make([]client.MessageHandlerFunc, 0, len(entry.handlers)) 213 for _, handler := range entry.handlers { 214 handlers = append(handlers, func(client *client.ChanneldClient, channelId uint32, m client.Message) { 215 handler(client, channelId, m) 216 }) 217 } 218 err := c.AddMessageHandler(uint32(msgType), handlers...) 219 if err != nil { 220 c.SetMessageEntry(uint32(msgType), entry.msg, handlers...) 221 } 222 } 223 224 go func() { 225 for { 226 if err := c.Receive(); err != nil { 227 log.Println(err) 228 return 229 } 230 } 231 }() 232 233 nextPacketIndex := 0 // replay packet index in session 234 prePacketTime := time.Now() // used to determine whether the next packet has arrived at the sending time 235 isFirstAuth := true 236 237 var waitMessageCallback int32 // counter for messages that has not been callback 238 239 for t := time.Now(); time.Since(t) < runningTime && c.IsConnected(); { 240 tickStartTime := time.Now() 241 242 if nextPacketIndex >= sessionPacketNum { 243 // If end of the session, delay sleepEndOfSession and replay packet that start of the session 244 nextPacketIndex = nextPacketIndex % sessionPacketNum 245 prePacketTime = prePacketTime.Add(sleepEndOfSession) 246 } 247 248 replayPacket := session.Packets[nextPacketIndex] 249 offsetTime := time.Duration(float64(replayPacket.OffsetTime) * actionIntervalMultiplier) 250 251 // If no messages that has not been callback and the next packet has arrived at the sending time 252 // try to send the packet 253 if waitMessageCallback == 0 && time.Since(prePacketTime) >= offsetTime { 254 255 prePacketTime = prePacketTime.Add(offsetTime) 256 nextPacketIndex++ 257 258 // Try to send messages in packet 259 for _, msgPack := range replayPacket.Packet.Messages { 260 261 msgType := channeldpb.MessageType(msgPack.MsgType) 262 263 if msgType == channeldpb.MessageType_AUTH { 264 if isFirstAuth { 265 isFirstAuth = false 266 } else if authOnlyOnce { 267 continue 268 } 269 } 270 271 channelId := msgPack.ChannelId 272 // Alter channelId if user set the alterChannelIdBeforeSendHandler 273 if rc.alterChannelIdBeforeSendHandler != nil { 274 newChId, needToSend := rc.alterChannelIdBeforeSendHandler(channelId, msgType, msgPack, c) 275 if !needToSend { 276 log.Printf("Connection(%v) pass message: { %v }", c.Id, msgPack.String()) 277 continue 278 } 279 channelId = newChId 280 } 281 282 var messageCallback func(client *client.ChanneldClient, channelId uint32, m client.Message) = nil 283 needWaitMessageCallback := (rc.needWaitMessageCallbackHandler != nil && rc.needWaitMessageCallbackHandler(msgType, msgPack, c)) || (msgType == channeldpb.MessageType_AUTH && waitAuthSuccess) 284 if needWaitMessageCallback { 285 messageCallback = func(client *client.ChanneldClient, channelId uint32, m client.Message) { 286 atomic.AddInt32(&waitMessageCallback, -1) 287 } 288 } 289 290 // The handler for alter or abandon message before send message 291 entry, ok := rc.beforeSendMessageMap[msgType] 292 293 if !ok && entry == nil { 294 // Send bytes directly 295 log.Printf("Connection(%v) send message: { %v }", c.Id, msgPack.String()) 296 if needWaitMessageCallback { 297 atomic.AddInt32(&waitMessageCallback, 1) 298 } 299 c.SendRaw(channelId, channeldpb.BroadcastType(msgPack.Broadcast), msgPack.MsgType, &msgPack.MsgBody, messageCallback) 300 } else { 301 // Copy message for uesr modify 302 msg := proto.Clone(entry.msgTemp) 303 err := proto.Unmarshal(msgPack.MsgBody, msg) 304 if err != nil { 305 log.Println(err) 306 return 307 } 308 needToSend := entry.beforeSendHandler(msg, msgPack, c) 309 if needToSend { 310 log.Printf("Connection(%v) send message: { %v }", c.Id, msgPack.String()) 311 if needWaitMessageCallback { 312 atomic.AddInt32(&waitMessageCallback, 1) 313 } 314 c.Send(channelId, channeldpb.BroadcastType(msgPack.Broadcast), msgPack.MsgType, msg, messageCallback) 315 } else { 316 log.Printf("Connection(%v) pass message: { %v }", c.Id, msgPack.String()) 317 continue 318 } 319 } 320 } 321 } 322 c.Tick() 323 time.Sleep(maxTickInterval - time.Since(tickStartTime)) 324 } 325 c.Disconnect() 326 time.Sleep(time.Millisecond * 100) // wait totally disconnect 327 wg.Done() 328 }() 329 // Run next connection after connectionInterval 330 time.Sleep(connInterval) 331 } 332 wg.Wait() 333 }