git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/mmdb/decoder_test.go (about)

     1  package mmdb
     2  
     3  import (
     4  	"encoding/hex"
     5  	"math/big"
     6  	"os"
     7  	"reflect"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestBool(t *testing.T) {
    16  	bools := map[string]any{
    17  		"0007": false,
    18  		"0107": true,
    19  	}
    20  
    21  	validateDecoding(t, bools)
    22  }
    23  
    24  func TestDouble(t *testing.T) {
    25  	doubles := map[string]any{
    26  		"680000000000000000": 0.0,
    27  		"683FE0000000000000": 0.5,
    28  		"68400921FB54442EEA": 3.14159265359,
    29  		"68405EC00000000000": 123.0,
    30  		"6841D000000007F8F4": 1073741824.12457,
    31  		"68BFE0000000000000": -0.5,
    32  		"68C00921FB54442EEA": -3.14159265359,
    33  		"68C1D000000007F8F4": -1073741824.12457,
    34  	}
    35  	validateDecoding(t, doubles)
    36  }
    37  
    38  func TestFloat(t *testing.T) {
    39  	floats := map[string]any{
    40  		"040800000000": float32(0.0),
    41  		"04083F800000": float32(1.0),
    42  		"04083F8CCCCD": float32(1.1),
    43  		"04084048F5C3": float32(3.14),
    44  		"0408461C3FF6": float32(9999.99),
    45  		"0408BF800000": float32(-1.0),
    46  		"0408BF8CCCCD": float32(-1.1),
    47  		"0408C048F5C3": -float32(3.14),
    48  		"0408C61C3FF6": float32(-9999.99),
    49  	}
    50  	validateDecoding(t, floats)
    51  }
    52  
    53  func TestInt32(t *testing.T) {
    54  	int32s := map[string]any{
    55  		"0001":         0,
    56  		"0401ffffffff": -1,
    57  		"0101ff":       255,
    58  		"0401ffffff01": -255,
    59  		"020101f4":     500,
    60  		"0401fffffe0c": -500,
    61  		"0201ffff":     65535,
    62  		"0401ffff0001": -65535,
    63  		"0301ffffff":   16777215,
    64  		"0401ff000001": -16777215,
    65  		"04017fffffff": 2147483647,
    66  		"040180000001": -2147483647,
    67  	}
    68  	validateDecoding(t, int32s)
    69  }
    70  
    71  func TestMap(t *testing.T) {
    72  	maps := map[string]any{
    73  		"e0":                             map[string]any{},
    74  		"e142656e43466f6f":               map[string]any{"en": "Foo"},
    75  		"e242656e43466f6f427a6843e4baba": map[string]any{"en": "Foo", "zh": "人"},
    76  		"e1446e616d65e242656e43466f6f427a6843e4baba": map[string]any{
    77  			"name": map[string]any{"en": "Foo", "zh": "人"},
    78  		},
    79  		"e1496c616e677561676573020442656e427a68": map[string]any{
    80  			"languages": []any{"en", "zh"},
    81  		},
    82  	}
    83  	validateDecoding(t, maps)
    84  }
    85  
    86  func TestSlice(t *testing.T) {
    87  	slice := map[string]any{
    88  		"0004":                 []any{},
    89  		"010443466f6f":         []any{"Foo"},
    90  		"020443466f6f43e4baba": []any{"Foo", "人"},
    91  	}
    92  	validateDecoding(t, slice)
    93  }
    94  
    95  var testStrings = makeTestStrings()
    96  
    97  func makeTestStrings() map[string]any {
    98  	str := map[string]any{
    99  		"40":       "",
   100  		"4131":     "1",
   101  		"43E4BABA": "人",
   102  		"5b313233343536373839303132333435363738393031323334353637":         "123456789012345678901234567",
   103  		"5c31323334353637383930313233343536373839303132333435363738":       "1234567890123456789012345678",
   104  		"5d003132333435363738393031323334353637383930313233343536373839":   "12345678901234567890123456789",
   105  		"5d01313233343536373839303132333435363738393031323334353637383930": "123456789012345678901234567890",
   106  	}
   107  
   108  	for k, v := range map[string]int{"5e00d7": 500, "5e06b3": 2000, "5f001053": 70000} {
   109  		key := k + strings.Repeat("78", v)
   110  		str[key] = strings.Repeat("x", v)
   111  	}
   112  
   113  	return str
   114  }
   115  
   116  func TestString(t *testing.T) {
   117  	validateDecoding(t, testStrings)
   118  }
   119  
   120  func TestByte(t *testing.T) {
   121  	b := make(map[string]any)
   122  	for key, val := range testStrings {
   123  		oldCtrl, err := hex.DecodeString(key[0:2])
   124  		require.NoError(t, err)
   125  		newCtrl := []byte{oldCtrl[0] ^ 0xc0}
   126  		key = strings.Replace(key, hex.EncodeToString(oldCtrl), hex.EncodeToString(newCtrl), 1)
   127  		b[key] = []byte(val.(string))
   128  	}
   129  
   130  	validateDecoding(t, b)
   131  }
   132  
   133  func TestUint16(t *testing.T) {
   134  	uint16s := map[string]any{
   135  		"a0":     uint64(0),
   136  		"a1ff":   uint64(255),
   137  		"a201f4": uint64(500),
   138  		"a22a78": uint64(10872),
   139  		"a2ffff": uint64(65535),
   140  	}
   141  	validateDecoding(t, uint16s)
   142  }
   143  
   144  func TestUint32(t *testing.T) {
   145  	uint32s := map[string]any{
   146  		"c0":         uint64(0),
   147  		"c1ff":       uint64(255),
   148  		"c201f4":     uint64(500),
   149  		"c22a78":     uint64(10872),
   150  		"c2ffff":     uint64(65535),
   151  		"c3ffffff":   uint64(16777215),
   152  		"c4ffffffff": uint64(4294967295),
   153  	}
   154  	validateDecoding(t, uint32s)
   155  }
   156  
   157  func TestUint64(t *testing.T) {
   158  	ctrlByte := "02"
   159  	bits := uint64(64)
   160  
   161  	uints := map[string]any{
   162  		"00" + ctrlByte:          uint64(0),
   163  		"02" + ctrlByte + "01f4": uint64(500),
   164  		"02" + ctrlByte + "2a78": uint64(10872),
   165  	}
   166  	for i := uint64(0); i <= bits/8; i++ {
   167  		expected := uint64((1 << (8 * i)) - 1)
   168  
   169  		input := hex.EncodeToString([]byte{byte(i)}) + ctrlByte + strings.Repeat("ff", int(i))
   170  		uints[input] = expected
   171  	}
   172  
   173  	validateDecoding(t, uints)
   174  }
   175  
   176  // Dedup with above somehow.
   177  func TestUint128(t *testing.T) {
   178  	ctrlByte := "03"
   179  	bits := uint(128)
   180  
   181  	uints := map[string]any{
   182  		"00" + ctrlByte:          big.NewInt(0),
   183  		"02" + ctrlByte + "01f4": big.NewInt(500),
   184  		"02" + ctrlByte + "2a78": big.NewInt(10872),
   185  	}
   186  	for i := uint(1); i <= bits/8; i++ {
   187  		expected := powBigInt(big.NewInt(2), 8*i)
   188  		expected = expected.Sub(expected, big.NewInt(1))
   189  		input := hex.EncodeToString([]byte{byte(i)}) + ctrlByte + strings.Repeat("ff", int(i))
   190  
   191  		uints[input] = expected
   192  	}
   193  
   194  	validateDecoding(t, uints)
   195  }
   196  
   197  // No pow or bit shifting for big int, apparently :-(
   198  // This is _not_ meant to be a comprehensive power function.
   199  func powBigInt(bi *big.Int, pow uint) *big.Int {
   200  	newInt := big.NewInt(1)
   201  	for i := uint(0); i < pow; i++ {
   202  		newInt.Mul(newInt, bi)
   203  	}
   204  	return newInt
   205  }
   206  
   207  func validateDecoding(t *testing.T, tests map[string]any) {
   208  	for inputStr, expected := range tests {
   209  		inputBytes, err := hex.DecodeString(inputStr)
   210  		require.NoError(t, err)
   211  		d := decoder{inputBytes}
   212  
   213  		var result any
   214  		_, err = d.decode(0, reflect.ValueOf(&result), 0)
   215  		assert.NoError(t, err)
   216  
   217  		if !reflect.DeepEqual(result, expected) {
   218  			// A big case statement would produce nicer errors
   219  			t.Errorf("Output was incorrect: %s  %s", inputStr, expected)
   220  		}
   221  	}
   222  }
   223  
   224  func TestPointers(t *testing.T) {
   225  	bytes, err := os.ReadFile(testFile("maps-with-pointers.raw"))
   226  	require.NoError(t, err)
   227  	d := decoder{bytes}
   228  
   229  	expected := map[uint]map[string]string{
   230  		0:  {"long_key": "long_value1"},
   231  		22: {"long_key": "long_value2"},
   232  		37: {"long_key2": "long_value1"},
   233  		50: {"long_key2": "long_value2"},
   234  		55: {"long_key": "long_value1"},
   235  		57: {"long_key2": "long_value2"},
   236  	}
   237  
   238  	for offset, expectedValue := range expected {
   239  		var actual map[string]string
   240  		_, err := d.decode(offset, reflect.ValueOf(&actual), 0)
   241  		assert.NoError(t, err)
   242  		if !reflect.DeepEqual(actual, expectedValue) {
   243  			t.Errorf("Decode for pointer at %d failed", offset)
   244  		}
   245  	}
   246  }