github.com/kubewharf/katalyst-core@v0.5.3/pkg/util/bitmask/bitmask_test.go (about)

     1  /*
     2  Copyright 2022 The Katalyst Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package bitmask
    18  
    19  import (
    20  	"reflect"
    21  	"testing"
    22  )
    23  
    24  func TestNewEmptyiBitMask(t *testing.T) {
    25  	t.Parallel()
    26  
    27  	tcases := []struct {
    28  		name         string
    29  		expectedMask string
    30  	}{
    31  		{
    32  			name:         "New empty BitMask",
    33  			expectedMask: "00",
    34  		},
    35  	}
    36  	for _, tc := range tcases {
    37  		bm := NewEmptyBitMask()
    38  		if bm.String() != tc.expectedMask {
    39  			t.Errorf("Expected mask to be %v, got %v", tc.expectedMask, bm)
    40  		}
    41  	}
    42  }
    43  
    44  func TestNewBitMask(t *testing.T) {
    45  	t.Parallel()
    46  
    47  	tcases := []struct {
    48  		name         string
    49  		bits         []int
    50  		expectedMask string
    51  	}{
    52  		{
    53  			name:         "New BitMask with bit 0 set",
    54  			bits:         []int{0},
    55  			expectedMask: "01",
    56  		},
    57  		{
    58  			name:         "New BitMask with bit 1 set",
    59  			bits:         []int{1},
    60  			expectedMask: "10",
    61  		},
    62  		{
    63  			name:         "New BitMask with bit 0 and bit 1 set",
    64  			bits:         []int{0, 1},
    65  			expectedMask: "11",
    66  		},
    67  	}
    68  	for _, tc := range tcases {
    69  		mask, _ := NewBitMask(tc.bits...)
    70  		if mask.String() != tc.expectedMask {
    71  			t.Errorf("Expected mask to be %v, got %v", tc.expectedMask, mask)
    72  		}
    73  	}
    74  }
    75  
    76  func TestAdd(t *testing.T) {
    77  	t.Parallel()
    78  
    79  	tcases := []struct {
    80  		name         string
    81  		bits         []int
    82  		expectedMask string
    83  	}{
    84  		{
    85  			name:         "Add BitMask with bit 0 set",
    86  			bits:         []int{0},
    87  			expectedMask: "01",
    88  		},
    89  		{
    90  			name:         "Add BitMask with bit 1 set",
    91  			bits:         []int{1},
    92  			expectedMask: "10",
    93  		},
    94  		{
    95  			name:         "Add BitMask with bits 0 and 1 set",
    96  			bits:         []int{0, 1},
    97  			expectedMask: "11",
    98  		},
    99  		{
   100  			name:         "Add BitMask with bits outside range 0-63",
   101  			bits:         []int{-1, 64},
   102  			expectedMask: "00",
   103  		},
   104  	}
   105  	for _, tc := range tcases {
   106  		mask, _ := NewBitMask()
   107  		mask.Add(tc.bits...)
   108  		if mask.String() != tc.expectedMask {
   109  			t.Errorf("Expected mask to be %v, got %v", tc.expectedMask, mask)
   110  		}
   111  	}
   112  }
   113  
   114  func TestRemove(t *testing.T) {
   115  	t.Parallel()
   116  
   117  	tcases := []struct {
   118  		name         string
   119  		bitsSet      []int
   120  		bitsRemove   []int
   121  		expectedMask string
   122  	}{
   123  		{
   124  			name:         "Set bit 0. Remove bit 0",
   125  			bitsSet:      []int{0},
   126  			bitsRemove:   []int{0},
   127  			expectedMask: "00",
   128  		},
   129  		{
   130  			name:         "Set bits 0 and 1. Remove bit 1",
   131  			bitsSet:      []int{0, 1},
   132  			bitsRemove:   []int{1},
   133  			expectedMask: "01",
   134  		},
   135  		{
   136  			name:         "Set bits 0 and 1. Remove bits 0 and 1",
   137  			bitsSet:      []int{0, 1},
   138  			bitsRemove:   []int{0, 1},
   139  			expectedMask: "00",
   140  		},
   141  		{
   142  			name:         "Set bit 0. Attempt to remove bits outside range 0-63",
   143  			bitsSet:      []int{0},
   144  			bitsRemove:   []int{-1, 64},
   145  			expectedMask: "01",
   146  		},
   147  	}
   148  	for _, tc := range tcases {
   149  		mask, _ := NewBitMask(tc.bitsSet...)
   150  		mask.Remove(tc.bitsRemove...)
   151  		if mask.String() != tc.expectedMask {
   152  			t.Errorf("Expected mask to be %v, got %v", tc.expectedMask, mask)
   153  		}
   154  	}
   155  }
   156  
   157  func TestAnd(t *testing.T) {
   158  	t.Parallel()
   159  
   160  	tcases := []struct {
   161  		name    string
   162  		masks   [][]int
   163  		andMask string
   164  	}{
   165  		{
   166  			name:    "Mask 11 AND mask 11",
   167  			masks:   [][]int{{0, 1}, {0, 1}},
   168  			andMask: "11",
   169  		},
   170  		{
   171  			name:    "Mask 11 AND mask 10",
   172  			masks:   [][]int{{0, 1}, {1}},
   173  			andMask: "10",
   174  		},
   175  		{
   176  			name:    "Mask 01 AND mask 11",
   177  			masks:   [][]int{{0}, {0, 1}},
   178  			andMask: "01",
   179  		},
   180  		{
   181  			name:    "Mask 11 AND mask 11 AND mask 10",
   182  			masks:   [][]int{{0, 1}, {0, 1}, {1}},
   183  			andMask: "10",
   184  		},
   185  		{
   186  			name:    "Mask 01 AND mask 01 AND mask 10 AND mask 11",
   187  			masks:   [][]int{{0}, {0}, {1}, {0, 1}},
   188  			andMask: "00",
   189  		},
   190  		{
   191  			name:    "Mask 1111 AND mask 1110 AND mask 1100 AND mask 1000",
   192  			masks:   [][]int{{0, 1, 2, 3}, {1, 2, 3}, {2, 3}, {3}},
   193  			andMask: "1000",
   194  		},
   195  	}
   196  	for _, tc := range tcases {
   197  		var bitMasks []BitMask
   198  		for i := range tc.masks {
   199  			bitMask, _ := NewBitMask(tc.masks[i]...)
   200  			bitMasks = append(bitMasks, bitMask)
   201  		}
   202  		resultMask := And(bitMasks[0], bitMasks...)
   203  		if resultMask.String() != tc.andMask {
   204  			t.Errorf("Expected mask to be %v, got %v", tc.andMask, resultMask)
   205  		}
   206  
   207  	}
   208  }
   209  
   210  func TestOr(t *testing.T) {
   211  	t.Parallel()
   212  
   213  	tcases := []struct {
   214  		name   string
   215  		masks  [][]int
   216  		orMask string
   217  	}{
   218  		{
   219  			name:   "Mask 01 OR mask 00",
   220  			masks:  [][]int{{0}, {}},
   221  			orMask: "01",
   222  		},
   223  		{
   224  			name:   "Mask 10 OR mask 10",
   225  			masks:  [][]int{{1}, {1}},
   226  			orMask: "10",
   227  		},
   228  		{
   229  			name:   "Mask 01 OR mask 10",
   230  			masks:  [][]int{{0}, {1}},
   231  			orMask: "11",
   232  		},
   233  		{
   234  			name:   "Mask 11 OR mask 11",
   235  			masks:  [][]int{{0, 1}, {0, 1}},
   236  			orMask: "11",
   237  		},
   238  		{
   239  			name:   "Mask 01 OR mask 10 OR mask 11",
   240  			masks:  [][]int{{0}, {1}, {0, 1}},
   241  			orMask: "11",
   242  		},
   243  		{
   244  			name:   "Mask 1000 OR mask 0100 OR mask 0010 OR mask 0001",
   245  			masks:  [][]int{{3}, {2}, {1}, {0}},
   246  			orMask: "1111",
   247  		},
   248  	}
   249  	for _, tc := range tcases {
   250  		var bitMasks []BitMask
   251  		for i := range tc.masks {
   252  			bitMask, _ := NewBitMask(tc.masks[i]...)
   253  			bitMasks = append(bitMasks, bitMask)
   254  		}
   255  		resultMask := Or(bitMasks[0], bitMasks...)
   256  		if resultMask.String() != tc.orMask {
   257  			t.Errorf("Expected mask to be %v, got %v", tc.orMask, resultMask)
   258  		}
   259  	}
   260  }
   261  
   262  func TestClear(t *testing.T) {
   263  	t.Parallel()
   264  
   265  	tcases := []struct {
   266  		name        string
   267  		mask        []int
   268  		clearedMask string
   269  	}{
   270  		{
   271  			name:        "Clear mask 01",
   272  			mask:        []int{0},
   273  			clearedMask: "00",
   274  		},
   275  		{
   276  			name:        "Clear mask 10",
   277  			mask:        []int{1},
   278  			clearedMask: "00",
   279  		},
   280  		{
   281  			name:        "Clear mask 11",
   282  			mask:        []int{0, 1},
   283  			clearedMask: "00",
   284  		},
   285  	}
   286  	for _, tc := range tcases {
   287  		mask, _ := NewBitMask(tc.mask...)
   288  		mask.Clear()
   289  		if mask.String() != tc.clearedMask {
   290  			t.Errorf("Expected mask to be %v, got %v", tc.clearedMask, mask)
   291  		}
   292  	}
   293  }
   294  
   295  func TestFill(t *testing.T) {
   296  	t.Parallel()
   297  
   298  	tcases := []struct {
   299  		name       string
   300  		mask       []int
   301  		filledMask string
   302  	}{
   303  		{
   304  			name:       "Fill empty mask",
   305  			mask:       nil,
   306  			filledMask: "1111111111111111111111111111111111111111111111111111111111111111",
   307  		},
   308  		{
   309  			name:       "Fill mask 10",
   310  			mask:       []int{0},
   311  			filledMask: "1111111111111111111111111111111111111111111111111111111111111111",
   312  		},
   313  		{
   314  			name:       "Fill mask 11",
   315  			mask:       []int{0, 1},
   316  			filledMask: "1111111111111111111111111111111111111111111111111111111111111111",
   317  		},
   318  	}
   319  	for _, tc := range tcases {
   320  		mask, _ := NewBitMask(tc.mask...)
   321  		mask.Fill()
   322  		if mask.String() != tc.filledMask {
   323  			t.Errorf("Expected mask to be %v, got %v", tc.filledMask, mask)
   324  		}
   325  	}
   326  }
   327  
   328  func TestIsEmpty(t *testing.T) {
   329  	t.Parallel()
   330  
   331  	tcases := []struct {
   332  		name          string
   333  		mask          []int
   334  		expectedEmpty bool
   335  	}{
   336  		{
   337  			name:          "Check if mask 00 is empty",
   338  			mask:          nil,
   339  			expectedEmpty: true,
   340  		},
   341  		{
   342  			name:          "Check if mask 01 is empty",
   343  			mask:          []int{0},
   344  			expectedEmpty: false,
   345  		},
   346  		{
   347  			name:          "Check if mask 11 is empty",
   348  			mask:          []int{0, 1},
   349  			expectedEmpty: false,
   350  		},
   351  	}
   352  	for _, tc := range tcases {
   353  		mask, _ := NewBitMask(tc.mask...)
   354  		empty := mask.IsEmpty()
   355  		if empty != tc.expectedEmpty {
   356  			t.Errorf("Expected value to be %v, got %v", tc.expectedEmpty, empty)
   357  		}
   358  	}
   359  }
   360  
   361  func TestIsSet(t *testing.T) {
   362  	t.Parallel()
   363  
   364  	tcases := []struct {
   365  		name        string
   366  		mask        []int
   367  		checkBit    int
   368  		expectedSet bool
   369  	}{
   370  		{
   371  			name:        "Check if bit 0 in mask 00 is set",
   372  			mask:        nil,
   373  			checkBit:    0,
   374  			expectedSet: false,
   375  		},
   376  		{
   377  			name:        "Check if bit 0 in mask 01 is set",
   378  			mask:        []int{0},
   379  			checkBit:    0,
   380  			expectedSet: true,
   381  		},
   382  		{
   383  			name:        "Check if bit 1 in mask 11 is set",
   384  			mask:        []int{0, 1},
   385  			checkBit:    1,
   386  			expectedSet: true,
   387  		},
   388  		{
   389  			name:        "Check if bit outside range 0-63 is set",
   390  			mask:        []int{0, 1},
   391  			checkBit:    64,
   392  			expectedSet: false,
   393  		},
   394  	}
   395  	for _, tc := range tcases {
   396  		mask, _ := NewBitMask(tc.mask...)
   397  		set := mask.IsSet(tc.checkBit)
   398  		if set != tc.expectedSet {
   399  			t.Errorf("Expected value to be %v, got %v", tc.expectedSet, set)
   400  		}
   401  	}
   402  }
   403  
   404  func TestAnySet(t *testing.T) {
   405  	t.Parallel()
   406  
   407  	tcases := []struct {
   408  		name        string
   409  		mask        []int
   410  		checkBits   []int
   411  		expectedSet bool
   412  	}{
   413  		{
   414  			name:        "Check if any bits from 11 in mask 00 is set",
   415  			mask:        nil,
   416  			checkBits:   []int{0, 1},
   417  			expectedSet: false,
   418  		},
   419  		{
   420  			name:        "Check if any bits from 11 in mask 01 is set",
   421  			mask:        []int{0},
   422  			checkBits:   []int{0, 1},
   423  			expectedSet: true,
   424  		},
   425  		{
   426  			name:        "Check if any bits from 11 in mask 11 is set",
   427  			mask:        []int{0, 1},
   428  			checkBits:   []int{0, 1},
   429  			expectedSet: true,
   430  		},
   431  		{
   432  			name:        "Check if any bit outside range 0-63 is set",
   433  			mask:        []int{0, 1},
   434  			checkBits:   []int{64, 65},
   435  			expectedSet: false,
   436  		},
   437  		{
   438  			name:        "Check if any bits from 1001 in mask 0110 is set",
   439  			mask:        []int{1, 2},
   440  			checkBits:   []int{0, 3},
   441  			expectedSet: false,
   442  		},
   443  	}
   444  	for _, tc := range tcases {
   445  		mask, _ := NewBitMask(tc.mask...)
   446  		set := mask.AnySet(tc.checkBits)
   447  		if set != tc.expectedSet {
   448  			t.Errorf("Expected value to be %v, got %v", tc.expectedSet, set)
   449  		}
   450  	}
   451  }
   452  
   453  func TestIsEqual(t *testing.T) {
   454  	t.Parallel()
   455  
   456  	tcases := []struct {
   457  		name          string
   458  		firstMask     []int
   459  		secondMask    []int
   460  		expectedEqual bool
   461  	}{
   462  		{
   463  			name:          "Check if mask 00 equals mask 00",
   464  			firstMask:     nil,
   465  			secondMask:    nil,
   466  			expectedEqual: true,
   467  		},
   468  		{
   469  			name:          "Check if mask 00 equals mask 01",
   470  			firstMask:     nil,
   471  			secondMask:    []int{0},
   472  			expectedEqual: false,
   473  		},
   474  		{
   475  			name:          "Check if mask 01 equals mask 01",
   476  			firstMask:     []int{0},
   477  			secondMask:    []int{0},
   478  			expectedEqual: true,
   479  		},
   480  		{
   481  			name:          "Check if mask 01 equals mask 10",
   482  			firstMask:     []int{0},
   483  			secondMask:    []int{1},
   484  			expectedEqual: false,
   485  		},
   486  		{
   487  			name:          "Check if mask 11 equals mask 11",
   488  			firstMask:     []int{0, 1},
   489  			secondMask:    []int{0, 1},
   490  			expectedEqual: true,
   491  		},
   492  	}
   493  	for _, tc := range tcases {
   494  		firstMask, _ := NewBitMask(tc.firstMask...)
   495  		secondMask, _ := NewBitMask(tc.secondMask...)
   496  		isEqual := firstMask.IsEqual(secondMask)
   497  		if isEqual != tc.expectedEqual {
   498  			t.Errorf("Expected mask to be %v, got %v", tc.expectedEqual, isEqual)
   499  		}
   500  	}
   501  }
   502  
   503  func TestCount(t *testing.T) {
   504  	t.Parallel()
   505  
   506  	tcases := []struct {
   507  		name          string
   508  		bits          []int
   509  		expectedCount int
   510  	}{
   511  		{
   512  			name:          "Count number of bits set in mask 00",
   513  			bits:          nil,
   514  			expectedCount: 0,
   515  		},
   516  		{
   517  			name:          "Count number of bits set in mask 01",
   518  			bits:          []int{0},
   519  			expectedCount: 1,
   520  		},
   521  		{
   522  			name:          "Count number of bits set in mask 11",
   523  			bits:          []int{0, 1},
   524  			expectedCount: 2,
   525  		},
   526  	}
   527  	for _, tc := range tcases {
   528  		mask, _ := NewBitMask(tc.bits...)
   529  		count := mask.Count()
   530  		if count != tc.expectedCount {
   531  			t.Errorf("Expected value to be %v, got %v", tc.expectedCount, count)
   532  		}
   533  	}
   534  }
   535  
   536  func TestGetBits(t *testing.T) {
   537  	t.Parallel()
   538  
   539  	tcases := []struct {
   540  		name         string
   541  		bits         []int
   542  		expectedBits []int
   543  	}{
   544  		{
   545  			name:         "Get bits of mask 00",
   546  			bits:         nil,
   547  			expectedBits: nil,
   548  		},
   549  		{
   550  			name:         "Get bits of mask 01",
   551  			bits:         []int{0},
   552  			expectedBits: []int{0},
   553  		},
   554  		{
   555  			name:         "Get bits of mask 11",
   556  			bits:         []int{0, 1},
   557  			expectedBits: []int{0, 1},
   558  		},
   559  	}
   560  	for _, tc := range tcases {
   561  		mask, _ := NewBitMask(tc.bits...)
   562  		bits := mask.GetBits()
   563  		if !reflect.DeepEqual(bits, tc.expectedBits) {
   564  			t.Errorf("Expected value to be %v, got %v", tc.expectedBits, bits)
   565  		}
   566  	}
   567  }
   568  
   569  func TestIsNarrowerThan(t *testing.T) {
   570  	t.Parallel()
   571  
   572  	tcases := []struct {
   573  		name                  string
   574  		firstMask             []int
   575  		secondMask            []int
   576  		expectedFirstNarrower bool
   577  	}{
   578  		{
   579  			name:                  "Check narrowness of masks with unequal bits set 1/2",
   580  			firstMask:             []int{0},
   581  			secondMask:            []int{0, 1},
   582  			expectedFirstNarrower: true,
   583  		},
   584  		{
   585  			name:                  "Check narrowness of masks with unequal bits set 2/2",
   586  			firstMask:             []int{0, 1},
   587  			secondMask:            []int{0},
   588  			expectedFirstNarrower: false,
   589  		},
   590  		{
   591  			name:                  "Check narrowness of masks with equal bits set 1/2",
   592  			firstMask:             []int{0},
   593  			secondMask:            []int{1},
   594  			expectedFirstNarrower: true,
   595  		},
   596  		{
   597  			name:                  "Check narrowness of masks with equal bits set 2/2",
   598  			firstMask:             []int{1},
   599  			secondMask:            []int{0},
   600  			expectedFirstNarrower: false,
   601  		},
   602  	}
   603  	for _, tc := range tcases {
   604  		firstMask, _ := NewBitMask(tc.firstMask...)
   605  		secondMask, _ := NewBitMask(tc.secondMask...)
   606  		expectedFirstNarrower := firstMask.IsNarrowerThan(secondMask)
   607  		if expectedFirstNarrower != tc.expectedFirstNarrower {
   608  			t.Errorf("Expected value to be %v, got %v", tc.expectedFirstNarrower, expectedFirstNarrower)
   609  		}
   610  	}
   611  }
   612  
   613  func TestIterateBitMasks(t *testing.T) {
   614  	t.Parallel()
   615  
   616  	tcases := []struct {
   617  		name    string
   618  		numbits int
   619  	}{
   620  		{
   621  			name:    "1 bit",
   622  			numbits: 1,
   623  		},
   624  		{
   625  			name:    "2 bits",
   626  			numbits: 2,
   627  		},
   628  		{
   629  			name:    "4 bits",
   630  			numbits: 4,
   631  		},
   632  		{
   633  			name:    "8 bits",
   634  			numbits: 8,
   635  		},
   636  		{
   637  			name:    "16 bits",
   638  			numbits: 16,
   639  		},
   640  	}
   641  	for _, tc := range tcases {
   642  		// Generate a list of bits from tc.numbits.
   643  		var bits []int
   644  		for i := 0; i < tc.numbits; i++ {
   645  			bits = append(bits, i)
   646  		}
   647  
   648  		// Calculate the expected number of masks. Since we always have masks
   649  		// with bits from 0..n, this is just (2^n - 1) since we want 1 mask
   650  		// represented by each integer between 1 and 2^n-1.
   651  		expectedNumMasks := (1 << uint(tc.numbits)) - 1
   652  
   653  		// Iterate all masks and count them.
   654  		numMasks := 0
   655  		IterateBitMasks(bits, func(BitMask) {
   656  			numMasks++
   657  		})
   658  
   659  		// Compare the number of masks generated to the expected amount.
   660  		if expectedNumMasks != numMasks {
   661  			t.Errorf("Expected to iterate %v masks, got %v", expectedNumMasks, numMasks)
   662  		}
   663  	}
   664  }