github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/bit_ops_test.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  )
    25  
    26  func TestBitAnd(t *testing.T) {
    27  	var testCases = []struct {
    28  		name                string
    29  		left, right         interface{}
    30  		leftType, rightType sql.Type
    31  		expected            uint64
    32  	}{
    33  		{"1 & 1", 1, 1, types.Uint64, types.Uint64, 1},
    34  		{"8 & 1", 8, 1, types.Uint64, types.Uint64, 0},
    35  		{"3 & 1", 3, 1, types.Uint64, types.Uint64, 1},
    36  		{"1024 & 0", 1024, 0, types.Uint64, types.Uint64, 0},
    37  		{"0 & 1024", 0, 1024, types.Uint64, types.Uint64, 0},
    38  		{"-1 & -12", -1, -12, types.Int64, types.Int64, 18446744073709551604},
    39  		{"0.6 & 10.24", 0.6, 10.24, types.Float64, types.Float64, 0},
    40  		{"0.6 & -10.24", 0.6, -10.24, types.Float64, types.Float64, 0},
    41  		{"-0.6 & -10.24", -0.6, -10.24, types.Float64, types.Float64, 18446744073709551606},
    42  	}
    43  
    44  	for _, tt := range testCases {
    45  		t.Run(tt.name, func(t *testing.T) {
    46  			require := require.New(t)
    47  			result, err := NewBitAnd(
    48  				NewLiteral(tt.left, tt.leftType),
    49  				NewLiteral(tt.right, tt.rightType),
    50  			).Eval(sql.NewEmptyContext(), sql.NewRow())
    51  			require.NoError(err)
    52  			require.Equal(tt.expected, result)
    53  		})
    54  	}
    55  }
    56  
    57  func TestBitOr(t *testing.T) {
    58  	var testCases = []struct {
    59  		name                string
    60  		left, right         interface{}
    61  		leftType, rightType sql.Type
    62  		expected            uint64
    63  	}{
    64  		{"1 | 1", 1, 1, types.Uint64, types.Uint64, 1},
    65  		{"8 | 1", 8, 1, types.Uint64, types.Uint64, 9},
    66  		{"3 | 1", 3, 1, types.Uint64, types.Uint64, 3},
    67  		{"1024 | 0", 1024, 0, types.Uint64, types.Uint64, 1024},
    68  		{"0 | 1024", 0, 1024, types.Uint64, types.Uint64, 1024},
    69  		{"-1 | -12", -1, -12, types.Int64, types.Int64, 18446744073709551615},
    70  		{"0.6 | 10.24", 0.6, 10.24, types.Float64, types.Float64, 11},
    71  		{"0.6 | -10.24", 0.6, -10.24, types.Float64, types.Float64, 18446744073709551607},
    72  		{"-0.6 | -10.24", -0.6, -10.24, types.Float64, types.Float64, 18446744073709551615},
    73  	}
    74  
    75  	for _, tt := range testCases {
    76  		t.Run(tt.name, func(t *testing.T) {
    77  			require := require.New(t)
    78  			result, err := NewBitOr(
    79  				NewLiteral(tt.left, tt.leftType),
    80  				NewLiteral(tt.right, tt.rightType),
    81  			).Eval(sql.NewEmptyContext(), sql.NewRow())
    82  			require.NoError(err)
    83  			require.Equal(tt.expected, result)
    84  		})
    85  	}
    86  }
    87  
    88  func TestBitXor(t *testing.T) {
    89  	var testCases = []struct {
    90  		name                string
    91  		left, right         interface{}
    92  		leftType, rightType sql.Type
    93  		expected            uint64
    94  	}{
    95  		{"1 ^ 1", 1, 1, types.Uint64, types.Uint64, 0},
    96  		{"8 ^ 1", 8, 1, types.Uint64, types.Uint64, 9},
    97  		{"3 ^ 1", 3, 1, types.Uint64, types.Uint64, 2},
    98  		{"1024 ^ 0", 1024, 0, types.Uint64, types.Uint64, 1024},
    99  		{"0 ^ -1024", 0, -1024, types.Int64, types.Int64, 18446744073709550592},
   100  		{"-1 ^ -12", -1, -12, types.Int64, types.Int64, 11},
   101  		{"0.6 ^ 10.24", 0.6, 10.24, types.Float64, types.Float64, 11},
   102  		{"0.6 ^ -10.24", 0.6, -10.24, types.Float64, types.Float64, 18446744073709551607},
   103  		{"-0.6 ^ -10.24", -0.6, -10.24, types.Float64, types.Float64, 9},
   104  	}
   105  
   106  	for _, tt := range testCases {
   107  		t.Run(tt.name, func(t *testing.T) {
   108  			require := require.New(t)
   109  			result, err := NewBitXor(
   110  				NewLiteral(tt.left, tt.leftType),
   111  				NewLiteral(tt.right, tt.rightType),
   112  			).Eval(sql.NewEmptyContext(), sql.NewRow())
   113  			require.NoError(err)
   114  			require.Equal(tt.expected, result)
   115  		})
   116  	}
   117  }
   118  
   119  func TestShiftLeft(t *testing.T) {
   120  	var testCases = []struct {
   121  		name        string
   122  		left, right uint64
   123  		expected    uint64
   124  	}{
   125  		{"1 << 1", 1, 1, 2},
   126  		{"1 << 3", 1, 3, 8},
   127  		{"1024 << 0", 1024, 0, 1024},
   128  		{"0 << 1024", 0, 1024, 0},
   129  	}
   130  
   131  	for _, tt := range testCases {
   132  		t.Run(tt.name, func(t *testing.T) {
   133  			require := require.New(t)
   134  			result, err := NewShiftLeft(
   135  				NewLiteral(tt.left, types.Uint64),
   136  				NewLiteral(tt.right, types.Uint64),
   137  			).Eval(sql.NewEmptyContext(), sql.NewRow())
   138  			require.NoError(err)
   139  			require.Equal(tt.expected, result)
   140  		})
   141  	}
   142  }
   143  
   144  func TestShiftRight(t *testing.T) {
   145  	var testCases = []struct {
   146  		name        string
   147  		left, right uint64
   148  		expected    uint64
   149  	}{
   150  		{"1 >> 1", 1, 1, 0},
   151  		{"8 >> 1", 8, 1, 4},
   152  		{"3 >> 1", 3, 1, 1},
   153  		{"1024 >> 0", 1024, 0, 1024},
   154  		{"0 >> 1024", 0, 1024, 0},
   155  	}
   156  
   157  	for _, tt := range testCases {
   158  		t.Run(tt.name, func(t *testing.T) {
   159  			require := require.New(t)
   160  			result, err := NewShiftRight(
   161  				NewLiteral(tt.left, types.Uint64),
   162  				NewLiteral(tt.right, types.Uint64),
   163  			).Eval(sql.NewEmptyContext(), sql.NewRow())
   164  			require.NoError(err)
   165  			require.Equal(tt.expected, result)
   166  		})
   167  	}
   168  }
   169  
   170  func TestAllUint64(t *testing.T) {
   171  	var testCases = []struct {
   172  		op        string
   173  		value     interface{}
   174  		valueType sql.Type
   175  		expected  uint64
   176  	}{
   177  		{"|", 1, types.Uint64, 1},
   178  		{"&", 3.4, types.Float64, 1},
   179  		{"^", -1024, types.Int64, 18446744073709550593},
   180  		{"<<", 50, types.Uint64, 17294948469009547264},
   181  		{">>", 50, types.Uint64, 15361},
   182  	}
   183  
   184  	// (((((0 | 1) & 3.4) ^ -1024) << 50) >> 50) == 15361
   185  	lval := NewLiteral(int64(0), types.Uint64)
   186  	for _, tt := range testCases {
   187  		t.Run(tt.op, func(t *testing.T) {
   188  			require := require.New(t)
   189  			result, err := NewBitOp(lval,
   190  				NewLiteral(tt.value, tt.valueType), tt.op,
   191  			).Eval(sql.NewEmptyContext(), sql.NewRow())
   192  			require.NoError(err)
   193  			require.Equal(tt.expected, result)
   194  
   195  			lval = NewLiteral(result, types.Uint64)
   196  		})
   197  	}
   198  }