gitlab.com/CoiaPrant/sqlite3@v1.19.1/functest/func_test.go (about)

     1  // Copyright 2022 The Sqlite Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package functest // gitlab.com/CoiaPrant/sqlite3/functest
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/md5"
    10  	"database/sql"
    11  	"database/sql/driver"
    12  	"encoding/hex"
    13  	"errors"
    14  	"fmt"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	sqlite3 "gitlab.com/CoiaPrant/sqlite3"
    20  )
    21  
    22  func init() {
    23  	sqlite3.MustRegisterDeterministicScalarFunction(
    24  		"test_int64",
    25  		0,
    26  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    27  			return int64(42), nil
    28  		},
    29  	)
    30  
    31  	sqlite3.MustRegisterDeterministicScalarFunction(
    32  		"test_float64",
    33  		0,
    34  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    35  			return float64(1e-2), nil
    36  		},
    37  	)
    38  
    39  	sqlite3.MustRegisterDeterministicScalarFunction(
    40  		"test_null",
    41  		0,
    42  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    43  			return nil, nil
    44  		},
    45  	)
    46  
    47  	sqlite3.MustRegisterDeterministicScalarFunction(
    48  		"test_error",
    49  		0,
    50  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    51  			return nil, errors.New("boom")
    52  		},
    53  	)
    54  
    55  	sqlite3.MustRegisterDeterministicScalarFunction(
    56  		"test_empty_byte_slice",
    57  		0,
    58  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    59  			return []byte{}, nil
    60  		},
    61  	)
    62  
    63  	sqlite3.MustRegisterDeterministicScalarFunction(
    64  		"test_nonempty_byte_slice",
    65  		0,
    66  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    67  			return []byte("abcdefg"), nil
    68  		},
    69  	)
    70  
    71  	sqlite3.MustRegisterDeterministicScalarFunction(
    72  		"test_empty_string",
    73  		0,
    74  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    75  			return "", nil
    76  		},
    77  	)
    78  
    79  	sqlite3.MustRegisterDeterministicScalarFunction(
    80  		"test_nonempty_string",
    81  		0,
    82  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    83  			return "abcdefg", nil
    84  		},
    85  	)
    86  
    87  	sqlite3.MustRegisterDeterministicScalarFunction(
    88  		"yesterday",
    89  		1,
    90  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
    91  			var arg time.Time
    92  			switch argTyped := args[0].(type) {
    93  			case int64:
    94  				arg = time.Unix(argTyped, 0)
    95  			default:
    96  				fmt.Println(argTyped)
    97  				return nil, fmt.Errorf("expected argument to be int64, got: %T", argTyped)
    98  			}
    99  			return arg.Add(-24 * time.Hour), nil
   100  		},
   101  	)
   102  
   103  	sqlite3.MustRegisterDeterministicScalarFunction(
   104  		"md5",
   105  		1,
   106  		func(ctx *sqlite3.FunctionContext, args []driver.Value) (driver.Value, error) {
   107  			var arg *bytes.Buffer
   108  			switch argTyped := args[0].(type) {
   109  			case string:
   110  				arg = bytes.NewBuffer([]byte(argTyped))
   111  			case []byte:
   112  				arg = bytes.NewBuffer(argTyped)
   113  			default:
   114  				return nil, fmt.Errorf("expected argument to be a string, got: %T", argTyped)
   115  			}
   116  			w := md5.New()
   117  			if _, err := arg.WriteTo(w); err != nil {
   118  				return nil, fmt.Errorf("unable to compute md5 checksum: %s", err)
   119  			}
   120  			return hex.EncodeToString(w.Sum(nil)), nil
   121  		},
   122  	)
   123  }
   124  
   125  func TestRegisteredFunctions(t *testing.T) {
   126  	withDB := func(test func(db *sql.DB)) {
   127  		db, err := sql.Open("sqlite", "file::memory:")
   128  		if err != nil {
   129  			t.Fatalf("failed to open database: %v", err)
   130  		}
   131  		defer db.Close()
   132  
   133  		test(db)
   134  	}
   135  
   136  	t.Run("int64", func(tt *testing.T) {
   137  		withDB(func(db *sql.DB) {
   138  			row := db.QueryRow("select test_int64()")
   139  
   140  			var a int
   141  			if err := row.Scan(&a); err != nil {
   142  				tt.Fatal(err)
   143  			}
   144  			if g, e := a, 42; g != e {
   145  				tt.Fatal(g, e)
   146  			}
   147  
   148  		})
   149  	})
   150  
   151  	t.Run("float64", func(tt *testing.T) {
   152  		withDB(func(db *sql.DB) {
   153  			row := db.QueryRow("select test_float64()")
   154  
   155  			var a float64
   156  			if err := row.Scan(&a); err != nil {
   157  				tt.Fatal(err)
   158  			}
   159  			if g, e := a, 1e-2; g != e {
   160  				tt.Fatal(g, e)
   161  			}
   162  
   163  		})
   164  	})
   165  
   166  	t.Run("error", func(tt *testing.T) {
   167  		withDB(func(db *sql.DB) {
   168  			_, err := db.Query("select test_error()")
   169  			if err == nil {
   170  				tt.Fatal("expected error, got none")
   171  			}
   172  			if !strings.Contains(err.Error(), "boom") {
   173  				tt.Fatal(err)
   174  			}
   175  		})
   176  	})
   177  
   178  	t.Run("empty_byte_slice", func(tt *testing.T) {
   179  		withDB(func(db *sql.DB) {
   180  			row := db.QueryRow("select test_empty_byte_slice()")
   181  
   182  			var a []byte
   183  			if err := row.Scan(&a); err != nil {
   184  				tt.Fatal(err)
   185  			}
   186  			if len(a) > 0 {
   187  				tt.Fatal("expected empty byte slice")
   188  			}
   189  		})
   190  	})
   191  
   192  	t.Run("nonempty_byte_slice", func(tt *testing.T) {
   193  		withDB(func(db *sql.DB) {
   194  			row := db.QueryRow("select test_nonempty_byte_slice()")
   195  
   196  			var a []byte
   197  			if err := row.Scan(&a); err != nil {
   198  				tt.Fatal(err)
   199  			}
   200  			if g, e := a, []byte("abcdefg"); !bytes.Equal(g, e) {
   201  				tt.Fatal(string(g), string(e))
   202  			}
   203  		})
   204  	})
   205  
   206  	t.Run("empty_string", func(tt *testing.T) {
   207  		withDB(func(db *sql.DB) {
   208  			row := db.QueryRow("select test_empty_string()")
   209  
   210  			var a string
   211  			if err := row.Scan(&a); err != nil {
   212  				tt.Fatal(err)
   213  			}
   214  			if len(a) > 0 {
   215  				tt.Fatal("expected empty string")
   216  			}
   217  		})
   218  	})
   219  
   220  	t.Run("nonempty_string", func(tt *testing.T) {
   221  		withDB(func(db *sql.DB) {
   222  			row := db.QueryRow("select test_nonempty_string()")
   223  
   224  			var a string
   225  			if err := row.Scan(&a); err != nil {
   226  				tt.Fatal(err)
   227  			}
   228  			if g, e := a, "abcdefg"; g != e {
   229  				tt.Fatal(g, e)
   230  			}
   231  		})
   232  	})
   233  
   234  	t.Run("null", func(tt *testing.T) {
   235  		withDB(func(db *sql.DB) {
   236  			row := db.QueryRow("select test_null()")
   237  
   238  			var a interface{}
   239  			if err := row.Scan(&a); err != nil {
   240  				tt.Fatal(err)
   241  			}
   242  			if a != nil {
   243  				tt.Fatal("expected nil")
   244  			}
   245  		})
   246  	})
   247  
   248  	t.Run("dates", func(tt *testing.T) {
   249  		withDB(func(db *sql.DB) {
   250  			row := db.QueryRow("select yesterday(unixepoch('2018-11-01'))")
   251  
   252  			var a int64
   253  			if err := row.Scan(&a); err != nil {
   254  				tt.Fatal(err)
   255  			}
   256  			if g, e := time.Unix(a, 0), time.Date(2018, time.October, 31, 0, 0, 0, 0, time.UTC); !g.Equal(e) {
   257  				tt.Fatal(g, e)
   258  			}
   259  		})
   260  	})
   261  
   262  	t.Run("md5", func(tt *testing.T) {
   263  		withDB(func(db *sql.DB) {
   264  			row := db.QueryRow("select md5('abcdefg')")
   265  
   266  			var a string
   267  			if err := row.Scan(&a); err != nil {
   268  				tt.Fatal(err)
   269  			}
   270  			if g, e := a, "7ac66c0f148de9519b8bd264312c4d64"; g != e {
   271  				tt.Fatal(g, e)
   272  			}
   273  		})
   274  	})
   275  
   276  	t.Run("md5 with blob input", func(tt *testing.T) {
   277  		withDB(func(db *sql.DB) {
   278  			if _, err := db.Exec("create table t(b blob); insert into t values (?)", []byte("abcdefg")); err != nil {
   279  				tt.Fatal(err)
   280  			}
   281  			row := db.QueryRow("select md5(b) from t")
   282  
   283  			var a []byte
   284  			if err := row.Scan(&a); err != nil {
   285  				tt.Fatal(err)
   286  			}
   287  			if g, e := a, []byte("7ac66c0f148de9519b8bd264312c4d64"); !bytes.Equal(g, e) {
   288  				tt.Fatal(string(g), string(e))
   289  			}
   290  		})
   291  	})
   292  }