github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/rows/convert_assign_test.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rows
    16  
    17  import (
    18  	"database/sql"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestConvertNullable(t *testing.T) {
    27  	testcases := []struct {
    28  		name    string
    29  		src     any
    30  		dest    any
    31  		wantVal any
    32  		hasErr  bool
    33  	}{
    34  		{
    35  			name:    "sql.NUllbool",
    36  			src:     sql.NullBool{Valid: true, Bool: true},
    37  			dest:    &sql.NullBool{Valid: false, Bool: false},
    38  			wantVal: &sql.NullBool{Valid: true, Bool: true},
    39  		},
    40  		{
    41  			name:    "sql.NUllbool的valid为false",
    42  			src:     sql.NullBool{Valid: false, Bool: true},
    43  			dest:    &sql.NullBool{Valid: false, Bool: false},
    44  			wantVal: &sql.NullBool{Valid: false, Bool: false},
    45  		},
    46  		{
    47  			name:    "sql.NUllString",
    48  			src:     sql.NullString{Valid: true, String: "xx"},
    49  			dest:    &sql.NullString{Valid: false, String: ""},
    50  			wantVal: &sql.NullString{Valid: true, String: "xx"},
    51  		},
    52  		{
    53  			name:    "sql.NUllString的valid为false",
    54  			src:     sql.NullString{Valid: false, String: "xx"},
    55  			dest:    &sql.NullString{Valid: false, String: ""},
    56  			wantVal: &sql.NullString{Valid: false, String: ""},
    57  		},
    58  		{
    59  			name:    "sql.NUllByte",
    60  			src:     sql.NullByte{Valid: true, Byte: 'a'},
    61  			dest:    &sql.NullByte{Valid: false, Byte: ' '},
    62  			wantVal: &sql.NullByte{Valid: true, Byte: 'a'},
    63  		},
    64  		{
    65  			name:    "sql.NUllByte的valid的false",
    66  			src:     sql.NullByte{Valid: false, Byte: 'a'},
    67  			dest:    &sql.NullByte{Valid: false, Byte: 0},
    68  			wantVal: &sql.NullByte{Valid: false, Byte: 0},
    69  		},
    70  		{
    71  			name:    "sql.NUllInt32",
    72  			src:     sql.NullInt32{Valid: true, Int32: 5},
    73  			dest:    &sql.NullInt32{Valid: false, Int32: 0},
    74  			wantVal: &sql.NullInt32{Valid: true, Int32: 5},
    75  		},
    76  		{
    77  			name:    "sql.NUllInt32的valid的false",
    78  			src:     sql.NullInt32{Valid: false, Int32: 0},
    79  			dest:    &sql.NullInt32{Valid: false, Int32: 0},
    80  			wantVal: &sql.NullInt32{Valid: false, Int32: 0},
    81  		},
    82  		{
    83  			name:    "sql.NUllInt64",
    84  			src:     sql.NullInt64{Valid: true, Int64: 5},
    85  			dest:    &sql.NullInt64{Valid: false, Int64: 0},
    86  			wantVal: &sql.NullInt64{Valid: true, Int64: 5},
    87  		},
    88  		{
    89  			name:    "sql.NUllInt64的valid的false",
    90  			src:     sql.NullInt64{Valid: false, Int64: 0},
    91  			dest:    &sql.NullInt64{Valid: false, Int64: 0},
    92  			wantVal: &sql.NullInt64{Valid: false, Int64: 0},
    93  		},
    94  		{
    95  			name:    "sql.NUllInt16",
    96  			src:     sql.NullInt16{Valid: true, Int16: 5},
    97  			dest:    &sql.NullInt16{Valid: false, Int16: 0},
    98  			wantVal: &sql.NullInt16{Valid: true, Int16: 5},
    99  		},
   100  		{
   101  			name:    "sql.NUllInt16的valid的false",
   102  			src:     sql.NullInt16{Valid: false, Int16: 0},
   103  			dest:    &sql.NullInt16{Valid: false, Int16: 0},
   104  			wantVal: &sql.NullInt16{Valid: false, Int16: 0},
   105  		},
   106  		{
   107  			name:    "sql.NUllFloat64",
   108  			src:     sql.NullFloat64{Valid: true, Float64: 5},
   109  			dest:    &sql.NullFloat64{Valid: false, Float64: 0},
   110  			wantVal: &sql.NullFloat64{Valid: true, Float64: 5},
   111  		},
   112  		{
   113  			name:    "sql.NUllfloat64的valid的false",
   114  			src:     sql.NullFloat64{Valid: false, Float64: 0},
   115  			dest:    &sql.NullFloat64{Valid: false, Float64: 0},
   116  			wantVal: &sql.NullFloat64{Valid: false, Float64: 0},
   117  		},
   118  		{
   119  			name: "sql.NUllTime",
   120  			src: sql.NullTime{Valid: true, Time: func() time.Time {
   121  				val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local)
   122  				require.NoError(t, err)
   123  				return val
   124  			}()},
   125  			dest: &sql.NullTime{Valid: false, Time: time.Time{}},
   126  			wantVal: &sql.NullTime{Valid: true, Time: func() time.Time {
   127  				val, err := time.ParseInLocation("2006-01-02 15:04:05", "2022-01-01 12:00:00", time.Local)
   128  				require.NoError(t, err)
   129  				return val
   130  			}()},
   131  		},
   132  		{
   133  			name:    "sql.NUllTime的valid的false",
   134  			src:     sql.NullTime{Valid: false, Time: time.Time{}},
   135  			dest:    &sql.NullTime{Valid: false, Time: time.Time{}},
   136  			wantVal: &sql.NullTime{Valid: false, Time: time.Time{}},
   137  		},
   138  		{
   139  			name:    "使用sql.NullInt32接收sql.NullInt64",
   140  			src:     sql.NullInt64{Valid: true, Int64: 5},
   141  			dest:    &sql.NullInt32{Valid: false, Int32: 0},
   142  			wantVal: &sql.NullInt32{Valid: true, Int32: 5},
   143  		},
   144  		{
   145  			name:    "使用sql.NullInt16接收sql.NullInt64",
   146  			src:     sql.NullInt64{Valid: true, Int64: 5},
   147  			dest:    &sql.NullInt16{Valid: false, Int16: 0},
   148  			wantVal: &sql.NullInt16{Valid: true, Int16: 5},
   149  		},
   150  		{
   151  			name:    "使用sql.NullInt32接收sql.NullInt64,Valid为false",
   152  			src:     sql.NullInt64{Valid: false, Int64: 0},
   153  			dest:    &sql.NullInt32{Valid: false, Int32: 0},
   154  			wantVal: &sql.NullInt32{Valid: false, Int32: 0},
   155  		},
   156  		{
   157  			name:    "使用sql.NullInt16接收sql.NullInt64,Valid为false",
   158  			src:     sql.NullInt64{Valid: false, Int64: 0},
   159  			dest:    &sql.NullInt16{Valid: false, Int16: 0},
   160  			wantVal: &sql.NullInt16{Valid: false, Int16: 0},
   161  		},
   162  		{
   163  			name: "使用int32接收sql.NullInt64",
   164  			src:  sql.NullInt64{Valid: true, Int64: 5},
   165  			dest: func() *int32 {
   166  				var val int32
   167  				return &val
   168  			}(),
   169  			wantVal: func() *int32 {
   170  				val := int32(5)
   171  				return &val
   172  			}(),
   173  		},
   174  		{
   175  			name: "使用int16接收sql.NullInt64",
   176  			src:  sql.NullInt64{Valid: true, Int64: 5},
   177  			dest: func() *int16 {
   178  				var val int16
   179  				return &val
   180  			}(),
   181  			wantVal: func() *int16 {
   182  				val := int16(5)
   183  				return &val
   184  			}(),
   185  		},
   186  		{
   187  			name: "使用float32接收sql.Nullfloat64",
   188  			src:  sql.NullFloat64{Valid: true, Float64: 5},
   189  			dest: func() *float32 {
   190  				var val float32
   191  				return &val
   192  			}(),
   193  			wantVal: func() *float32 {
   194  				val := float32(5)
   195  				return &val
   196  			}(),
   197  		},
   198  		{
   199  			name: "使用int32接收sql.NullInt64,Valid为false",
   200  			src:  sql.NullInt64{Valid: false, Int64: 0},
   201  			dest: func() *int32 {
   202  				var val int32
   203  				return &val
   204  			}(),
   205  			hasErr: true,
   206  		},
   207  		{
   208  			name: "使用int16接收sql.NullInt64,valid为false",
   209  			src:  sql.NullInt64{Valid: false, Int64: 0},
   210  			dest: func() *int16 {
   211  				var val int16
   212  				return &val
   213  			}(),
   214  			hasErr: true,
   215  		},
   216  		{
   217  			name: "使用float32接收sql.Nullfloat64",
   218  			src:  sql.NullFloat64{Valid: false, Float64: 0},
   219  			dest: func() *float32 {
   220  				var val float32
   221  				return &val
   222  			}(),
   223  			hasErr: true,
   224  		},
   225  	}
   226  	for _, tc := range testcases {
   227  		t.Run(tc.name, func(t *testing.T) {
   228  			err := ConvertAssign(tc.dest, tc.src)
   229  			if tc.hasErr {
   230  				require.Error(t, err)
   231  				return
   232  			}
   233  			require.NoError(t, err)
   234  			assert.Equal(t, tc.dest, tc.wantVal)
   235  		})
   236  	}
   237  }