github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/client/client_test.go (about)

     1  // Copyright 2024 Block, Inc.
     2  
     3  package client_test
     4  
     5  import (
     6  	"context"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/go-test/deep"
    11  
    12  	"github.com/square/finch"
    13  	"github.com/square/finch/client"
    14  	"github.com/square/finch/data"
    15  	"github.com/square/finch/stats"
    16  	"github.com/square/finch/test"
    17  	"github.com/square/finch/trx"
    18  )
    19  
    20  var rl = finch.RunLevel{
    21  	Stage:       1,
    22  	ExecGroup:   1,
    23  	ClientGroup: 1,
    24  	Client:      1,
    25  }
    26  
    27  func TestClient_SELECT_1(t *testing.T) {
    28  	if test.Build {
    29  		t.Skip("GitHub Actions build")
    30  	}
    31  
    32  	_, db, err := test.Connection()
    33  	if err != nil {
    34  		t.Fatal(err)
    35  	}
    36  	defer db.Close()
    37  
    38  	doneChan := make(chan *client.Client, 1)
    39  
    40  	c := &client.Client{
    41  		DB:       db,
    42  		RunLevel: rl,
    43  		DoneChan: doneChan,
    44  		Statements: []*trx.Statement{
    45  			{
    46  				Query:     "SELECT 1",
    47  				ResultSet: true,
    48  			},
    49  		},
    50  		Data: []client.StatementData{
    51  			{
    52  				TrxBoundary: trx.BEGIN | trx.END,
    53  			},
    54  		},
    55  		Stats: []*stats.Trx{nil},
    56  		// --
    57  		Iter: 1, // need some runtime limit
    58  	}
    59  
    60  	err = c.Init()
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  
    65  	c.Run(context.Background())
    66  
    67  	timeout := time.After(2 * time.Second)
    68  	var ret *client.Client
    69  	select {
    70  	case ret = <-doneChan:
    71  	case <-timeout:
    72  		t.Fatal("Client timeout after 2s")
    73  	}
    74  
    75  	if ret != c {
    76  		t.Errorf("returned *Client != run *Client")
    77  	}
    78  
    79  	if ret.Error.Err != nil {
    80  		t.Errorf("Client error: %v", ret.Error.Err)
    81  	}
    82  }
    83  
    84  func TestClient_Write(t *testing.T) {
    85  	if test.Build {
    86  		t.Skip("GitHub Actions build")
    87  	}
    88  
    89  	_, db, err := test.Connection()
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	defer db.Close()
    94  
    95  	queries := []string{
    96  		"CREATE DATABASE IF NOT EXISTS finch",
    97  		"USE finch",
    98  		"DROP TABLE IF EXISTS writetest",
    99  		"CREATE TABLE writetest (i int auto_increment primary key not null, d int)",
   100  	}
   101  	for _, q := range queries {
   102  		if _, err := db.Exec(q); err != nil {
   103  			t.Fatalf("%s: %s", q, err)
   104  		}
   105  	}
   106  
   107  	// Returns value for input for @d to INSERT
   108  	vals := []interface{}{int64(1)}
   109  	valueFunc := func(_ data.RunCount) []interface{} {
   110  		return vals
   111  	}
   112  	// Receives last insertt ID (1 as well since it's first row)
   113  	col := data.NewColumn(nil)
   114  
   115  	doneChan := make(chan *client.Client, 1)
   116  
   117  	c := &client.Client{
   118  		DB:       db,
   119  		RunLevel: rl,
   120  		Iter:     1,
   121  		DoneChan: doneChan,
   122  		Statements: []*trx.Statement{
   123  			{
   124  				Query:  "INSERT INTO writetest VALUES (NULL, %d)",
   125  				Write:  true,
   126  				Inputs: []string{"@d"},
   127  			},
   128  		},
   129  		Data: []client.StatementData{
   130  			{
   131  				TrxBoundary: trx.BEGIN | trx.END,
   132  				Inputs:      []data.ValueFunc{valueFunc},
   133  				InsertId:    col,
   134  			},
   135  		},
   136  		Stats: []*stats.Trx{nil},
   137  	}
   138  
   139  	err = c.Init()
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  
   144  	c.Run(context.Background())
   145  
   146  	timeout := time.After(2 * time.Second)
   147  	var ret *client.Client
   148  	select {
   149  	case ret = <-doneChan:
   150  	case <-timeout:
   151  		t.Fatal("Client timeout after 2s")
   152  	}
   153  
   154  	if ret != c {
   155  		t.Errorf("returned *Client != run *Client")
   156  	}
   157  
   158  	if ret.Error.Err != nil {
   159  		t.Errorf("Client error: %v", ret.Error.Err)
   160  	}
   161  
   162  	// Auot inc insert id == 1
   163  	got := col.Values(data.RunCount{})
   164  	if diff := deep.Equal(got, vals); diff != nil {
   165  		t.Error(diff)
   166  	}
   167  }