github.com/ncruces/go-sqlite3@v0.15.1-0.20240520133447-53eef1510ff0/tests/func_test.go (about)

     1  package tests
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/ncruces/go-sqlite3"
     8  	_ "github.com/ncruces/go-sqlite3/embed"
     9  	_ "github.com/ncruces/go-sqlite3/tests/testcfg"
    10  )
    11  
    12  func TestCreateFunction(t *testing.T) {
    13  	t.Parallel()
    14  
    15  	db, err := sqlite3.Open(":memory:")
    16  	if err != nil {
    17  		t.Fatal(err)
    18  	}
    19  	defer db.Close()
    20  
    21  	err = db.CreateFunction("test", 1, sqlite3.INNOCUOUS, func(ctx sqlite3.Context, arg ...sqlite3.Value) {
    22  		switch arg := arg[0]; arg.Int() {
    23  		case 0:
    24  			ctx.ResultInt(arg.Int())
    25  		case 1:
    26  			ctx.ResultInt64(arg.Int64())
    27  		case 2:
    28  			ctx.ResultBool(arg.Bool())
    29  		case 3:
    30  			ctx.ResultFloat(arg.Float())
    31  		case 4:
    32  			ctx.ResultText(arg.Text())
    33  		case 5:
    34  			ctx.ResultBlob(arg.Blob(nil))
    35  		case 6:
    36  			ctx.ResultZeroBlob(arg.Int64())
    37  		case 7:
    38  			ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
    39  		case 8:
    40  			var v any
    41  			if err := arg.JSON(&v); err != nil {
    42  				ctx.ResultError(err)
    43  			} else {
    44  				ctx.ResultJSON(v)
    45  			}
    46  		case 9:
    47  			ctx.ResultValue(arg)
    48  		case 10:
    49  			ctx.ResultNull()
    50  		case 11:
    51  			ctx.ResultError(sqlite3.FULL)
    52  		}
    53  	})
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  
    58  	stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
    59  	if err != nil {
    60  		t.Error(err)
    61  	}
    62  	defer stmt.Close()
    63  
    64  	if stmt.Step() {
    65  		if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
    66  			t.Errorf("got %v, want INTEGER", got)
    67  		}
    68  		if got := stmt.ColumnInt(0); got != 0 {
    69  			t.Errorf("got %v, want 1", got)
    70  		}
    71  	}
    72  
    73  	if stmt.Step() {
    74  		if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
    75  			t.Errorf("got %v, want INTEGER", got)
    76  		}
    77  		if got := stmt.ColumnInt64(0); got != 1 {
    78  			t.Errorf("got %v, want 2", got)
    79  		}
    80  	}
    81  
    82  	if stmt.Step() {
    83  		if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
    84  			t.Errorf("got %v, want INTEGER", got)
    85  		}
    86  		if got := stmt.ColumnBool(0); got != true {
    87  			t.Errorf("got %v, want true", got)
    88  		}
    89  	}
    90  
    91  	if stmt.Step() {
    92  		if got := stmt.ColumnType(0); got != sqlite3.FLOAT {
    93  			t.Errorf("got %v, want FLOAT", got)
    94  		}
    95  		if got := stmt.ColumnInt64(0); got != 3 {
    96  			t.Errorf("got %v, want 3", got)
    97  		}
    98  	}
    99  
   100  	if stmt.Step() {
   101  		if got := stmt.ColumnType(0); got != sqlite3.TEXT {
   102  			t.Errorf("got %v, want TEXT", got)
   103  		}
   104  		if got := stmt.ColumnText(0); got != "4" {
   105  			t.Errorf("got %s, want 4", got)
   106  		}
   107  	}
   108  
   109  	if stmt.Step() {
   110  		if got := stmt.ColumnType(0); got != sqlite3.BLOB {
   111  			t.Errorf("got %v, want BLOB", got)
   112  		}
   113  		if got := stmt.ColumnRawBlob(0); string(got) != "5" {
   114  			t.Errorf("got %s, want 5", got)
   115  		}
   116  	}
   117  
   118  	if stmt.Step() {
   119  		if got := stmt.ColumnType(0); got != sqlite3.BLOB {
   120  			t.Errorf("got %v, want BLOB", got)
   121  		}
   122  		if got := stmt.ColumnRawBlob(0); len(got) != 6 {
   123  			t.Errorf("got %v, want 6", got)
   124  		}
   125  	}
   126  
   127  	if stmt.Step() {
   128  		if got := stmt.ColumnType(0); got != sqlite3.TEXT {
   129  			t.Errorf("got %v, want TEXT", got)
   130  		}
   131  		if got := stmt.ColumnTime(0, sqlite3.TimeFormatAuto); got.Unix() != 7 {
   132  			t.Errorf("got %v, want 7", got)
   133  		}
   134  	}
   135  
   136  	if stmt.Step() {
   137  		if got := stmt.ColumnType(0); got != sqlite3.TEXT {
   138  			t.Errorf("got %v, want TEXT", got)
   139  		}
   140  		var got int
   141  		if err := stmt.ColumnJSON(0, &got); err != nil {
   142  			t.Error(err)
   143  		} else if got != 8 {
   144  			t.Errorf("got %v, want 8", got)
   145  		}
   146  	}
   147  
   148  	if stmt.Step() {
   149  		if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
   150  			t.Errorf("got %v, want INTEGER", got)
   151  		}
   152  		if got := stmt.ColumnInt64(0); got != 9 {
   153  			t.Errorf("got %v, want 9", got)
   154  		}
   155  	}
   156  
   157  	if stmt.Step() {
   158  		if got := stmt.ColumnType(0); got != sqlite3.NULL {
   159  			t.Errorf("got %v, want NULL", got)
   160  		}
   161  	}
   162  
   163  	if stmt.Step() {
   164  		t.Error("want error")
   165  	}
   166  	if err := stmt.Err(); !errors.Is(err, sqlite3.FULL) {
   167  		t.Errorf("got %v, want sqlite3.FULL", err)
   168  	}
   169  }
   170  
   171  func TestOverloadFunction(t *testing.T) {
   172  	t.Parallel()
   173  
   174  	db, err := sqlite3.Open(":memory:")
   175  	if err != nil {
   176  		t.Fatal(err)
   177  	}
   178  	defer db.Close()
   179  
   180  	err = db.OverloadFunction("test", 0)
   181  	if err != nil {
   182  		t.Fatal(err)
   183  	}
   184  
   185  	err = db.Exec(`SELECT test()`)
   186  	if err == nil {
   187  		t.Fatal("want error")
   188  	}
   189  }
   190  
   191  func TestAnyCollationNeeded(t *testing.T) {
   192  	t.Parallel()
   193  
   194  	db, err := sqlite3.Open(":memory:")
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  	defer db.Close()
   199  
   200  	err = db.Exec(`CREATE TABLE users (id INT, name VARCHAR(10))`)
   201  	if err != nil {
   202  		t.Fatal(err)
   203  	}
   204  
   205  	err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	db.AnyCollationNeeded()
   211  
   212  	stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
   213  	if err != nil {
   214  		t.Fatal(err)
   215  	}
   216  	defer stmt.Close()
   217  
   218  	row := 0
   219  	ids := []int{0, 2, 1}
   220  	names := []string{"go", "whatever", "zig"}
   221  	for ; stmt.Step(); row++ {
   222  		id := stmt.ColumnInt(0)
   223  		name := stmt.ColumnText(1)
   224  
   225  		if id != ids[row] {
   226  			t.Errorf("got %d, want %d", id, ids[row])
   227  		}
   228  		if name != names[row] {
   229  			t.Errorf("got %q, want %q", name, names[row])
   230  		}
   231  	}
   232  	if row != 3 {
   233  		t.Errorf("got %d, want %d", row, len(ids))
   234  	}
   235  
   236  	if err := stmt.Err(); err != nil {
   237  		t.Fatal(err)
   238  	}
   239  }