github.com/Mrs4s/MiraiGo@v0.0.0-20240226124653-54bdd873e3fe/binary/jce/reader_test.go (about)

     1  package jce
     2  
     3  import (
     4  	"crypto/rand"
     5  	"reflect"
     6  	"strconv"
     7  	"sync"
     8  	"testing"
     9  	"unsafe"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  func TestJceReader_ReadSlice(t *testing.T) {
    15  	s := make([][]byte, 50)
    16  	for i := range s {
    17  		b := make([]byte, 64)
    18  		_, _ = rand.Read(b)
    19  		s[i] = b
    20  	}
    21  	w := NewJceWriter()
    22  	w.WriteBytesSlice(s, 1)
    23  	r := NewJceReader(w.Bytes())
    24  	result := r.ReadByteArrArr(1)
    25  	assert.Equal(t, s, result)
    26  }
    27  
    28  var test []*BigDataIPInfo
    29  
    30  func BenchmarkJceReader_ReadSlice(b *testing.B) {
    31  	for i := 0; i <= 500; i++ {
    32  		test = append(test, &BigDataIPInfo{
    33  			Type:   1,
    34  			Server: "test1",
    35  			Port:   8080,
    36  		})
    37  	}
    38  	w := NewJceWriter()
    39  	w.WriteObject(test, 1)
    40  	src := w.Bytes()
    41  	b.SetBytes(int64(len(src)))
    42  	b.StartTimer()
    43  	for i := 0; i < b.N; i++ {
    44  		r := NewJceReader(src)
    45  		_ = r.ReadBigDataIPInfos(1)
    46  	}
    47  }
    48  
    49  var req = RequestDataVersion2{
    50  	Map: map[string]map[string][]byte{
    51  		"1": {
    52  			"123": []byte(`123`),
    53  		},
    54  		"2": {
    55  			"123": []byte(`123`),
    56  		},
    57  		"3": {
    58  			"123": []byte(`123`),
    59  		},
    60  		"4": {
    61  			"123": []byte(`123`),
    62  		},
    63  		"5": {
    64  			"123": []byte(`123`),
    65  		},
    66  	},
    67  }
    68  
    69  func TestRequestDataVersion2_ReadFrom(t *testing.T) {
    70  	// todo(wdv): fuzz test
    71  	w := NewJceWriter()
    72  	w.writeMapStrMapStrBytes(req.Map, 0)
    73  	src := w.Bytes()
    74  	result := RequestDataVersion2{}
    75  	result.ReadFrom(NewJceReader(src))
    76  	assert.Equal(t, req, result)
    77  }
    78  
    79  func BenchmarkRequestDataVersion2_ReadFrom(b *testing.B) {
    80  	w := NewJceWriter()
    81  	w.writeMapStrMapStrBytes(req.Map, 0)
    82  	src := w.Bytes()
    83  	b.SetBytes(int64(len(src)))
    84  	result := &RequestDataVersion2{}
    85  	for i := 0; i < b.N; i++ {
    86  		result.ReadFrom(NewJceReader(src))
    87  	}
    88  }
    89  
    90  func TestJceReader_ReadBytes(t *testing.T) {
    91  	b := make([]byte, 1024)
    92  	rand.Read(b)
    93  
    94  	w := NewJceWriter()
    95  	w.WriteBytes(b, 0)
    96  	r := NewJceReader(w.Bytes())
    97  	rb := r.ReadBytes(0)
    98  
    99  	assert.Equal(t, b, rb)
   100  }
   101  
   102  func (w *JceWriter) WriteObject(i any, tag byte) {
   103  	t := reflect.TypeOf(i)
   104  	if t.Kind() == reflect.Map {
   105  		w.WriteMap(i, tag)
   106  		return
   107  	}
   108  	if t.Kind() == reflect.Slice {
   109  		if b, ok := i.([]byte); ok {
   110  			w.WriteBytes(b, tag)
   111  			return
   112  		}
   113  		w.WriteSlice(i, tag)
   114  		return
   115  	}
   116  	switch o := i.(type) {
   117  	case byte:
   118  		w.WriteByte(o, tag)
   119  	case bool:
   120  		w.WriteBool(o, tag)
   121  	case int16:
   122  		w.WriteInt16(o, tag)
   123  	case int32:
   124  		w.WriteInt32(o, tag)
   125  	case int64:
   126  		w.WriteInt64(o, tag)
   127  	case float32:
   128  		w.WriteFloat32(o, tag)
   129  	case float64:
   130  		w.WriteFloat64(o, tag)
   131  	case string:
   132  		w.WriteString(o, tag)
   133  	case IJceStruct:
   134  		w.WriteJceStruct(o, tag)
   135  	}
   136  }
   137  
   138  func (w *JceWriter) writeObject(v reflect.Value, tag byte) {
   139  	k := v.Kind()
   140  	if k == reflect.Map {
   141  		switch o := v.Interface().(type) {
   142  		case map[string]string:
   143  			w.writeMapStrStr(o, tag)
   144  		case map[string][]byte:
   145  			w.writeMapStrBytes(o, tag)
   146  		case map[string]map[string][]byte:
   147  			w.writeMapStrMapStrBytes(o, tag)
   148  		default:
   149  			w.writeMap(v, tag)
   150  		}
   151  		return
   152  	}
   153  	if k == reflect.Slice {
   154  		switch o := v.Interface().(type) {
   155  		case []byte:
   156  			w.WriteBytes(o, tag)
   157  		case []IJceStruct:
   158  			w.WriteJceStructSlice(o, tag)
   159  		default:
   160  			w.writeSlice(v, tag)
   161  		}
   162  		return
   163  	}
   164  	switch k {
   165  	case reflect.Uint8, reflect.Int8:
   166  		w.WriteByte(*(*byte)(pointerOf(v)), tag)
   167  	case reflect.Uint16, reflect.Int16:
   168  		w.WriteInt16(*(*int16)(pointerOf(v)), tag)
   169  	case reflect.Uint32, reflect.Int32:
   170  		w.WriteInt32(*(*int32)(pointerOf(v)), tag)
   171  	case reflect.Uint64, reflect.Int64:
   172  		w.WriteInt64(*(*int64)(pointerOf(v)), tag)
   173  	case reflect.String:
   174  		w.WriteString(v.String(), tag)
   175  	default:
   176  		switch o := v.Interface().(type) {
   177  		case IJceStruct:
   178  			w.WriteJceStruct(o, tag)
   179  		case float32:
   180  			w.WriteFloat32(o, tag)
   181  		case float64:
   182  			w.WriteFloat64(o, tag)
   183  		}
   184  	}
   185  }
   186  
   187  type decoder struct {
   188  	index int
   189  	id    int
   190  }
   191  
   192  var decoderCache = sync.Map{}
   193  
   194  // WriteJceStructRaw 写入 Jce 结构体
   195  func (w *JceWriter) WriteJceStructRaw(s any) {
   196  	t := reflect.TypeOf(s)
   197  	if t.Kind() != reflect.Ptr {
   198  		return
   199  	}
   200  	t = t.Elem()
   201  	v := reflect.ValueOf(s).Elem()
   202  	var jceDec []decoder
   203  	dec, ok := decoderCache.Load(t)
   204  	if ok { // 从缓存中加载
   205  		jceDec = dec.([]decoder)
   206  	} else { // 初次反射
   207  		jceDec = make([]decoder, 0, t.NumField())
   208  		for i := 0; i < t.NumField(); i++ {
   209  			field := t.Field(i)
   210  			strId := field.Tag.Get("jceId")
   211  			if strId == "" {
   212  				continue
   213  			}
   214  			id, err := strconv.Atoi(strId)
   215  			if err != nil {
   216  				continue
   217  			}
   218  			jceDec = append(jceDec, decoder{
   219  				index: i,
   220  				id:    id,
   221  			})
   222  		}
   223  		decoderCache.Store(t, jceDec) // 存入缓存
   224  	}
   225  	for _, dec := range jceDec {
   226  		obj := v.Field(dec.index)
   227  		w.writeObject(obj, byte(dec.id))
   228  	}
   229  }
   230  
   231  func (w *JceWriter) WriteJceStruct(s IJceStruct, tag byte) {
   232  	w.writeHead(10, tag)
   233  	w.WriteJceStructRaw(s)
   234  	w.writeHead(11, 0)
   235  }
   236  
   237  func (w *JceWriter) WriteSlice(i any, tag byte) {
   238  	va := reflect.ValueOf(i)
   239  	if va.Kind() != reflect.Slice {
   240  		panic("JceWriter.WriteSlice: not a slice")
   241  	}
   242  	w.writeSlice(va, tag)
   243  }
   244  
   245  func (w *JceWriter) writeSlice(slice reflect.Value, tag byte) {
   246  	if slice.Kind() != reflect.Slice {
   247  		return
   248  	}
   249  	w.writeHead(9, tag)
   250  	if slice.Len() == 0 {
   251  		w.writeHead(12, 0) // w.WriteInt32(0, 0)
   252  		return
   253  	}
   254  	w.WriteInt32(int32(slice.Len()), 0)
   255  	for i := 0; i < slice.Len(); i++ {
   256  		v := slice.Index(i)
   257  		w.writeObject(v, 0)
   258  	}
   259  }
   260  
   261  func (w *JceWriter) WriteJceStructSlice(l []IJceStruct, tag byte) {
   262  	w.writeHead(9, tag)
   263  	if len(l) == 0 {
   264  		w.writeHead(12, 0) // w.WriteInt32(0, 0)
   265  		return
   266  	}
   267  	w.WriteInt32(int32(len(l)), 0)
   268  	for _, v := range l {
   269  		w.WriteJceStruct(v, 0)
   270  	}
   271  }
   272  
   273  func (w *JceWriter) WriteMap(m any, tag byte) {
   274  	va := reflect.ValueOf(m)
   275  	if va.Kind() != reflect.Map {
   276  		panic("JceWriter.WriteMap: not a map")
   277  	}
   278  	w.writeMap(va, tag)
   279  }
   280  
   281  func (w *JceWriter) writeMap(m reflect.Value, tag byte) {
   282  	if m.IsNil() {
   283  		w.writeHead(8, tag)
   284  		w.writeHead(12, 0) // w.WriteInt32(0, 0)
   285  		return
   286  	}
   287  	if m.Kind() != reflect.Map {
   288  		return
   289  	}
   290  	w.writeHead(8, tag)
   291  	w.WriteInt32(int32(m.Len()), 0)
   292  	iter := m.MapRange()
   293  	for iter.Next() {
   294  		w.writeObject(iter.Key(), 0)
   295  		w.writeObject(iter.Value(), 1)
   296  	}
   297  }
   298  
// value is assumed to mirror the in-memory layout of reflect.Value: a type
// pointer, a data pointer, and a flag word. pointerOf relies on this to read
// the otherwise unexported data pointer.
// NOTE(review): reflect.Value's private layout is not covered by the Go 1
// compatibility promise; re-check this struct whenever the Go version changes.
type value struct {
	typ  unsafe.Pointer
	data unsafe.Pointer
	flag uintptr
}

// pointerOf reinterprets v as a value and returns its data pointer. It does not
// inspect reflect's internal flags, so it does not check whether that pointer
// addresses the underlying data or holds it directly.
// NOTE(review): confirm the pointer is indirect for every kind it is used with.
func pointerOf(v reflect.Value) unsafe.Pointer {
	return (*value)(unsafe.Pointer(&v)).data
}