github.com/metaworking/channeld@v0.7.3/pkg/replay/replay.go (about)

     1  package replay
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"log"
     9  	"os"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/metaworking/channeld/pkg/channeldpb"
    15  	"github.com/metaworking/channeld/pkg/client"
    16  	"github.com/metaworking/channeld/pkg/replaypb"
    17  	"google.golang.org/protobuf/proto"
    18  )
    19  
// Duration is a time.Duration that can be decoded from JSON either as a number
// (interpreted as nanoseconds) or as a string accepted by time.ParseDuration,
// e.g. "1.5s" or "200ms". See Duration.UnmarshalJSON.
type Duration time.Duration
    21  
// CaseConfig is the top-level JSON configuration of a replay case.
type CaseConfig struct {
	// Address every replay connection dials via client.NewClient.
	ChanneldAddr string `json:"channeldAddr"`
	// Groups of connections; ReplayClient.Run replays them one group after another.
	ConnectionGroups []ConnectionGroupConfig `json:"connectionGroups"`
}
    26  
// ConnectionGroupConfig describes a group of connections that all replay the
// same recorded session file concurrently.
type ConnectionGroupConfig struct {
	CprFilePath              string   `json:"cprFilePath"`              // path of the recorded session, decoded by ReadReplaySessionFile
	ConnectionNumber         int      `json:"connectionNumber"`         // number of connections started for this group
	ConnectInterval          Duration `json:"connectInterval"`          // delay between starting consecutive connections
	RunningTime              Duration `json:"runningTime"`              // how long each connection keeps replaying before it disconnects
	SleepEndOfSession        Duration `json:"sleepEndOfSession"`        // extra delay before looping back to the first packet of the session
	MaxTickInterval          Duration `json:"maxTickInterval"`          // each replay-loop iteration (send + client Tick) is padded with sleep to last at least this long
	ActionIntervalMultiplier float64  `json:"actionIntervalMultiplier"` // scales each packet's recorded offset (ActionIntervalMultiplier * ReplayPacket.OffsetTime)
	WaitAuthSuccess          bool     `json:"waitAuthSuccess"`          // if true, no further packets are sent after an AUTH message until its message callback fires
	AuthOnlyOnce             bool     `json:"authOnlyOnce"`             // if true, only the first AUTH message of each connection is sent, even when the session loops
}
    38  
// DefaultConnGroupConfig holds suggested default values for a ConnectionGroupConfig:
// one connection, unscaled packet offsets, and a single AUTH whose callback is awaited.
//
// NOTE(review): LoadCaseConfig in this file does not apply these defaults; fields
// omitted from the JSON config keep their zero values (e.g. connectionNumber 0,
// actionIntervalMultiplier 0). Confirm whether callers elsewhere rely on it.
var DefaultConnGroupConfig = ConnectionGroupConfig{
	ConnectionNumber:         1,
	ActionIntervalMultiplier: 1,
	WaitAuthSuccess:          true,
	AuthOnlyOnce:             true,
}
    45  
// AlterChannelIdBeforeSendHandlerFunc is called for every replayed message before it
// is sent. It returns the channel id to send the message on instead of the recorded
// channelId; returning needToSend == false skips (does not send) the message.
type AlterChannelIdBeforeSendHandlerFunc func(channelId uint32, msgType channeldpb.MessageType, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) (chId uint32, needToSend bool)
    47  
// BeforeSendMessageHandlerFunc is called before a replayed message of a registered type
// is sent. msg is a freshly decoded copy of the recorded body that the handler may
// modify; the modified msg is what gets sent. Returning false skips the message.
type BeforeSendMessageHandlerFunc func(msg proto.Message, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) (needToSend bool)

// beforeSendMessageMapEntry pairs the decoding template of one message type with its
// before-send handler.
type beforeSendMessageMapEntry struct {
	msgTemp           proto.Message // cloned with proto.Clone for every message, then filled by proto.Unmarshal
	beforeSendHandler BeforeSendMessageHandlerFunc
}
    53  
// MessageHandlerFunc is a message handler registered on every replay connection's
// client; it is presumably invoked by the client for messages of the registered type
// (see the client package to confirm when it fires).
type MessageHandlerFunc func(c *client.ChanneldClient, channelId uint32, m proto.Message)

// messageMapEntry holds what is registered on each connection's client for one
// message type.
type messageMapEntry struct {
	msg      proto.Message // template passed to the client's SetMessageEntry; nil when only AddMessageHandler was used
	handlers []MessageHandlerFunc
}
    59  
// NeedWaitMessageCallbackHandlerFunc reports whether, after this message is sent, the
// replay must pause sending further packets until the message's callback fires.
// AUTH messages also wait when ConnectionGroupConfig.WaitAuthSuccess is set,
// independently of this handler.
type NeedWaitMessageCallbackHandlerFunc func(msgType channeldpb.MessageType, msgPack *channeldpb.MessagePack, c *client.ChanneldClient) bool
    61  
// ReplayClient replays recorded channeld sessions as described by a CaseConfig.
// The handler maps are initialized by CreateReplayClientByConfigFile; register
// hooks with the Set*/Add* methods, then call Run.
//
// NOTE(review): Run reads the hooks and maps below from one goroutine per
// connection without locking, so register everything before calling Run.
type ReplayClient struct {
	CaseConfig       CaseConfig
	ConnectionGroups []ConnectionGroup // one per CaseConfig.ConnectionGroups entry, with its decoded session

	// Optional; nil means messages keep their recorded channel id.
	alterChannelIdBeforeSendHandler AlterChannelIdBeforeSendHandlerFunc
	// Per message type decode template + handler; types without an entry are sent as raw bytes.
	beforeSendMessageMap map[channeldpb.MessageType]*beforeSendMessageMapEntry
	// Optional; nil means only AUTH (with WaitAuthSuccess) pauses the replay.
	needWaitMessageCallbackHandler NeedWaitMessageCallbackHandlerFunc
	// Handlers registered on every replay connection's client.
	messageMap map[channeldpb.MessageType]*messageMapEntry
}
    70  
// ConnectionGroup is one loaded ConnectionGroupConfig together with its decoded
// replay session.
type ConnectionGroup struct {
	config  ConnectionGroupConfig
	session *replaypb.ReplaySession // decoded from config.CprFilePath by LoadCaseConfig
}
    75  
    76  func CreateReplayClientByConfigFile(configPath string) (*ReplayClient, error) {
    77  	rc := &ReplayClient{}
    78  	err := rc.LoadCaseConfig(configPath)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	rc.beforeSendMessageMap = make(map[channeldpb.MessageType]*beforeSendMessageMapEntry)
    83  	rc.messageMap = make(map[channeldpb.MessageType]*messageMapEntry)
    84  	return rc, nil
    85  }
    86  
    87  func (d *Duration) UnmarshalJSON(b []byte) error {
    88  	var v interface{}
    89  	if err := json.Unmarshal(b, &v); err != nil {
    90  		return err
    91  	}
    92  	switch value := v.(type) {
    93  	case float64:
    94  		*d = Duration(time.Duration(value))
    95  		return nil
    96  	case string:
    97  		tmp, err := time.ParseDuration(value)
    98  		if err != nil {
    99  			return err
   100  		}
   101  		*d = Duration(tmp)
   102  		return nil
   103  	default:
   104  		return errors.New("invalid duration")
   105  	}
   106  }
   107  
   108  func (rc *ReplayClient) LoadCaseConfig(path string) error {
   109  
   110  	config, err := ioutil.ReadFile(path)
   111  	if err == nil {
   112  		if err := json.Unmarshal(config, &rc.CaseConfig); err != nil {
   113  			return fmt.Errorf("failed to unmarshall case config: %v", err)
   114  		}
   115  	} else {
   116  		return fmt.Errorf("failed to load case config: %v", err)
   117  	}
   118  
   119  	for _, c := range rc.CaseConfig.ConnectionGroups {
   120  		session, err := ReadReplaySessionFile(c.CprFilePath)
   121  		if err != nil {
   122  			return fmt.Errorf("failed to read replay session file: %v", err)
   123  		}
   124  		rc.ConnectionGroups = append(rc.ConnectionGroups, ConnectionGroup{
   125  			config:  c,
   126  			session: session,
   127  		})
   128  	}
   129  	return nil
   130  }
   131  
   132  func (rc *ReplayClient) SetAlterChannelIdBeforeSendHandler(handler AlterChannelIdBeforeSendHandlerFunc) {
   133  	rc.alterChannelIdBeforeSendHandler = handler
   134  }
   135  
   136  func (rc *ReplayClient) SetBeforeSendMessageEntry(msgType channeldpb.MessageType, msgTemp proto.Message, handler BeforeSendMessageHandlerFunc) {
   137  	rc.beforeSendMessageMap[msgType] = &beforeSendMessageMapEntry{
   138  		msgTemp:           msgTemp,
   139  		beforeSendHandler: handler,
   140  	}
   141  }
   142  
   143  func (rc *ReplayClient) AddMessageHandler(msgType channeldpb.MessageType, handlers ...MessageHandlerFunc) {
   144  	entry := rc.messageMap[msgType]
   145  	if entry != nil {
   146  		entry.handlers = append(entry.handlers, handlers...)
   147  	} else {
   148  		rc.messageMap[msgType] = &messageMapEntry{
   149  			handlers: handlers,
   150  		}
   151  	}
   152  }
   153  
   154  func (rc *ReplayClient) SetMessageEntry(msgType uint32, msgTemp proto.Message, handlers ...MessageHandlerFunc) {
   155  	rc.messageMap[channeldpb.MessageType(msgType)] = &messageMapEntry{
   156  		msg:      msgTemp,
   157  		handlers: handlers,
   158  	}
   159  }
   160  
   161  func (rc *ReplayClient) SetNeedWaitMessageCallback(handler NeedWaitMessageCallbackHandlerFunc) {
   162  	rc.needWaitMessageCallbackHandler = handler
   163  }
   164  
   165  func ReadReplaySessionFile(cprPath string) (*replaypb.ReplaySession, error) {
   166  	data, err := os.ReadFile(cprPath)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	var rs replaypb.ReplaySession
   172  	if err = proto.Unmarshal(data, &rs); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	return &rs, nil
   177  }
   178  
   179  func (rc *ReplayClient) Run() {
   180  	for _, cg := range rc.ConnectionGroups {
   181  		cg.run(rc)
   182  	}
   183  
   184  }
   185  
   186  func (cg *ConnectionGroup) run(rc *ReplayClient) {
   187  	connNumber := cg.config.ConnectionNumber
   188  	maxTickInterval := time.Duration(cg.config.MaxTickInterval)
   189  	connInterval := time.Duration(cg.config.ConnectInterval)
   190  	runningTime := time.Duration(cg.config.RunningTime)
   191  	session := cg.session
   192  	authOnlyOnce := cg.config.AuthOnlyOnce
   193  	waitAuthSuccess := cg.config.WaitAuthSuccess
   194  	sleepEndOfSession := time.Duration(cg.config.SleepEndOfSession)
   195  	actionIntervalMultiplier := cg.config.ActionIntervalMultiplier
   196  
   197  	sessionPacketNum := len(session.Packets)
   198  
   199  	var wg sync.WaitGroup // wait all connections in the group to timeout
   200  	wg.Add(connNumber)
   201  
   202  	for ci := 0; ci < connNumber; ci++ {
   203  		go func() {
   204  			c, err := client.NewClient(rc.CaseConfig.ChanneldAddr)
   205  			if err != nil {
   206  				log.Println(err)
   207  				return
   208  			}
   209  
   210  			// Register msg handlers
   211  			for msgType, entry := range rc.messageMap {
   212  				handlers := make([]client.MessageHandlerFunc, 0, len(entry.handlers))
   213  				for _, handler := range entry.handlers {
   214  					handlers = append(handlers, func(client *client.ChanneldClient, channelId uint32, m client.Message) {
   215  						handler(client, channelId, m)
   216  					})
   217  				}
   218  				err := c.AddMessageHandler(uint32(msgType), handlers...)
   219  				if err != nil {
   220  					c.SetMessageEntry(uint32(msgType), entry.msg, handlers...)
   221  				}
   222  			}
   223  
   224  			go func() {
   225  				for {
   226  					if err := c.Receive(); err != nil {
   227  						log.Println(err)
   228  						return
   229  					}
   230  				}
   231  			}()
   232  
   233  			nextPacketIndex := 0        // replay packet index in session
   234  			prePacketTime := time.Now() // used to determine whether the next packet has arrived at the sending time
   235  			isFirstAuth := true
   236  
   237  			var waitMessageCallback int32 // counter for messages that has not been callback
   238  
   239  			for t := time.Now(); time.Since(t) < runningTime && c.IsConnected(); {
   240  				tickStartTime := time.Now()
   241  
   242  				if nextPacketIndex >= sessionPacketNum {
   243  					// If end of the session, delay sleepEndOfSession and replay packet that start of the session
   244  					nextPacketIndex = nextPacketIndex % sessionPacketNum
   245  					prePacketTime = prePacketTime.Add(sleepEndOfSession)
   246  				}
   247  
   248  				replayPacket := session.Packets[nextPacketIndex]
   249  				offsetTime := time.Duration(float64(replayPacket.OffsetTime) * actionIntervalMultiplier)
   250  
   251  				// If no messages that has not been callback and the next packet has arrived at the sending time
   252  				// try to send the packet
   253  				if waitMessageCallback == 0 && time.Since(prePacketTime) >= offsetTime {
   254  
   255  					prePacketTime = prePacketTime.Add(offsetTime)
   256  					nextPacketIndex++
   257  
   258  					// Try to send messages in packet
   259  					for _, msgPack := range replayPacket.Packet.Messages {
   260  
   261  						msgType := channeldpb.MessageType(msgPack.MsgType)
   262  
   263  						if msgType == channeldpb.MessageType_AUTH {
   264  							if isFirstAuth {
   265  								isFirstAuth = false
   266  							} else if authOnlyOnce {
   267  								continue
   268  							}
   269  						}
   270  
   271  						channelId := msgPack.ChannelId
   272  						// Alter channelId if user set the alterChannelIdBeforeSendHandler
   273  						if rc.alterChannelIdBeforeSendHandler != nil {
   274  							newChId, needToSend := rc.alterChannelIdBeforeSendHandler(channelId, msgType, msgPack, c)
   275  							if !needToSend {
   276  								log.Printf("Connection(%v) pass message: { %v }", c.Id, msgPack.String())
   277  								continue
   278  							}
   279  							channelId = newChId
   280  						}
   281  
   282  						var messageCallback func(client *client.ChanneldClient, channelId uint32, m client.Message) = nil
   283  						needWaitMessageCallback := (rc.needWaitMessageCallbackHandler != nil && rc.needWaitMessageCallbackHandler(msgType, msgPack, c)) || (msgType == channeldpb.MessageType_AUTH && waitAuthSuccess)
   284  						if needWaitMessageCallback {
   285  							messageCallback = func(client *client.ChanneldClient, channelId uint32, m client.Message) {
   286  								atomic.AddInt32(&waitMessageCallback, -1)
   287  							}
   288  						}
   289  
   290  						// The handler for alter or abandon message before send message
   291  						entry, ok := rc.beforeSendMessageMap[msgType]
   292  
   293  						if !ok && entry == nil {
   294  							// Send bytes directly
   295  							log.Printf("Connection(%v) send message: { %v }", c.Id, msgPack.String())
   296  							if needWaitMessageCallback {
   297  								atomic.AddInt32(&waitMessageCallback, 1)
   298  							}
   299  							c.SendRaw(channelId, channeldpb.BroadcastType(msgPack.Broadcast), msgPack.MsgType, &msgPack.MsgBody, messageCallback)
   300  						} else {
   301  							// Copy message for uesr modify
   302  							msg := proto.Clone(entry.msgTemp)
   303  							err := proto.Unmarshal(msgPack.MsgBody, msg)
   304  							if err != nil {
   305  								log.Println(err)
   306  								return
   307  							}
   308  							needToSend := entry.beforeSendHandler(msg, msgPack, c)
   309  							if needToSend {
   310  								log.Printf("Connection(%v) send message: { %v }", c.Id, msgPack.String())
   311  								if needWaitMessageCallback {
   312  									atomic.AddInt32(&waitMessageCallback, 1)
   313  								}
   314  								c.Send(channelId, channeldpb.BroadcastType(msgPack.Broadcast), msgPack.MsgType, msg, messageCallback)
   315  							} else {
   316  								log.Printf("Connection(%v) pass message: { %v }", c.Id, msgPack.String())
   317  								continue
   318  							}
   319  						}
   320  					}
   321  				}
   322  				c.Tick()
   323  				time.Sleep(maxTickInterval - time.Since(tickStartTime))
   324  			}
   325  			c.Disconnect()
   326  			time.Sleep(time.Millisecond * 100) // wait totally disconnect
   327  			wg.Done()
   328  		}()
   329  		// Run next connection after connectionInterval
   330  		time.Sleep(connInterval)
   331  	}
   332  	wg.Wait()
   333  }