github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/tests/integration/basic_example_database_sql_bindings_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package integration
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"path"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/stretchr/testify/require"
    18  	"google.golang.org/grpc/metadata"
    19  
    20  	"github.com/ydb-platform/ydb-go-sdk/v3"
    21  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    22  	"github.com/ydb-platform/ydb-go-sdk/v3/meta"
    23  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    24  	"github.com/ydb-platform/ydb-go-sdk/v3/sugar"
    25  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    26  )
    27  
    28  func TestBasicExampleDatabaseSqlBindings(t *testing.T) {
    29  	folder := t.Name()
    30  
    31  	ctx, cancel := context.WithTimeout(xtest.Context(t), 42*time.Second)
    32  	defer cancel()
    33  
    34  	var totalConsumedUnits atomic.Uint64
    35  	defer func() {
    36  		t.Logf("total consumed units: %d", totalConsumedUnits.Load())
    37  	}()
    38  
    39  	ctx = meta.WithTrailerCallback(ctx, func(md metadata.MD) {
    40  		totalConsumedUnits.Add(meta.ConsumedUnits(md))
    41  	})
    42  
    43  	t.Run("sql.Open", func(t *testing.T) {
    44  		db, err := sql.Open("ydb", os.Getenv("YDB_CONNECTION_STRING"))
    45  		require.NoError(t, err)
    46  
    47  		err = db.PingContext(ctx)
    48  		require.NoError(t, err)
    49  
    50  		_, err = ydb.Unwrap(db)
    51  		require.NoError(t, err)
    52  
    53  		err = db.Close()
    54  		require.NoError(t, err)
    55  	})
    56  
    57  	t.Run("sql.OpenDB", func(t *testing.T) {
    58  		nativeDriver, err := ydb.Open(ctx, os.Getenv("YDB_CONNECTION_STRING"),
    59  			withMetrics(t, trace.DetailsAll, 0),
    60  			ydb.WithDiscoveryInterval(time.Second),
    61  		)
    62  		require.NoError(t, err)
    63  
    64  		defer func() {
    65  			// cleanup
    66  			_ = nativeDriver.Close(ctx)
    67  		}()
    68  
    69  		c, err := ydb.Connector(nativeDriver,
    70  			ydb.WithTablePathPrefix(path.Join(nativeDriver.Name(), folder)),
    71  			ydb.WithAutoDeclare(),
    72  			ydb.WithPositionalArgs(),
    73  		)
    74  		require.NoError(t, err)
    75  
    76  		defer func() {
    77  			// cleanup
    78  			_ = c.Close()
    79  		}()
    80  
    81  		db := sql.OpenDB(c)
    82  		defer func() {
    83  			// cleanup
    84  			_ = db.Close()
    85  		}()
    86  
    87  		err = db.PingContext(ctx)
    88  		require.NoError(t, err)
    89  
    90  		db.SetMaxOpenConns(50)
    91  		db.SetMaxIdleConns(50)
    92  
    93  		t.Run("prepare", func(t *testing.T) {
    94  			t.Run("scheme", func(t *testing.T) {
    95  				err = sugar.RemoveRecursive(ctx, nativeDriver, folder)
    96  				require.NoError(t, err)
    97  
    98  				err = sugar.MakeRecursive(ctx, nativeDriver, folder)
    99  				require.NoError(t, err)
   100  
   101  				t.Run("series", func(t *testing.T) {
   102  					var (
   103  						ctx    = ydb.WithQueryMode(ctx, ydb.SchemeQueryMode)
   104  						exists bool
   105  					)
   106  
   107  					exists, err = sugar.IsTableExists(ctx, nativeDriver.Scheme(), path.Join(nativeDriver.Name(), folder, "series"))
   108  					require.NoError(t, err)
   109  
   110  					if exists {
   111  						_, err = db.ExecContext(ctx, `DROP TABLE series;`)
   112  						require.NoError(t, err)
   113  					}
   114  
   115  					_, err = db.ExecContext(ctx, `
   116  						CREATE TABLE series (
   117  							series_id Uint64,
   118  							title UTF8,
   119  							series_info UTF8,
   120  							release_date Date,
   121  							comment UTF8,
   122  							PRIMARY KEY (
   123  								series_id
   124  							)
   125  						);
   126  					`)
   127  					require.NoError(t, err)
   128  				})
   129  				t.Run("seasons", func(t *testing.T) {
   130  					var (
   131  						ctx    = ydb.WithQueryMode(ctx, ydb.SchemeQueryMode)
   132  						exists bool
   133  					)
   134  
   135  					exists, err = sugar.IsTableExists(ctx, nativeDriver.Scheme(), path.Join(nativeDriver.Name(), folder, "seasons"))
   136  					require.NoError(t, err)
   137  
   138  					if exists {
   139  						_, err = db.ExecContext(ctx, `DROP TABLE seasons;`)
   140  						require.NoError(t, err)
   141  					}
   142  
   143  					_, err = db.ExecContext(ctx, `
   144  						CREATE TABLE seasons (
   145  							series_id Uint64,
   146  							season_id Uint64,
   147  							title UTF8,
   148  							first_aired Date,
   149  							last_aired Date,
   150  							PRIMARY KEY (
   151  								series_id,
   152  								season_id
   153  							)
   154  						);
   155  					`)
   156  					require.NoError(t, err)
   157  				})
   158  				t.Run("episodes", func(t *testing.T) {
   159  					var (
   160  						ctx    = ydb.WithQueryMode(ctx, ydb.SchemeQueryMode)
   161  						exists bool
   162  					)
   163  
   164  					exists, err = sugar.IsTableExists(ctx, nativeDriver.Scheme(), path.Join(nativeDriver.Name(), folder, "episodes"))
   165  					require.NoError(t, err)
   166  
   167  					if exists {
   168  						_, err = db.ExecContext(ctx, `DROP TABLE episodes;`)
   169  						require.NoError(t, err)
   170  					}
   171  
   172  					_, err = db.ExecContext(ctx, `
   173  						CREATE TABLE episodes (
   174  							series_id Uint64,
   175  							season_id Uint64,
   176  							episode_id Uint64,
   177  							title UTF8,
   178  							air_date Date,
   179  							views Uint64,
   180  							PRIMARY KEY (
   181  								series_id,
   182  								season_id,
   183  								episode_id
   184  							)
   185  						);
   186  					`)
   187  					require.NoError(t, err)
   188  				})
   189  			})
   190  		})
   191  
   192  		t.Run("batch", func(t *testing.T) {
   193  			t.Run("upsert", func(t *testing.T) {
   194  				err = retry.Do(ctx, db, func(ctx context.Context, cc *sql.Conn) error {
   195  					stmt, err := cc.PrepareContext(ctx, `
   196  						REPLACE INTO series SELECT * FROM AS_TABLE(?);
   197  						REPLACE INTO seasons SELECT * FROM AS_TABLE(?);
   198  						REPLACE INTO episodes SELECT * FROM AS_TABLE(?);
   199  					`)
   200  					if err != nil {
   201  						return fmt.Errorf("failed to prepare query: %w", err)
   202  					}
   203  					_, err = stmt.ExecContext(ctx,
   204  						getSeriesData(),
   205  						getSeasonsData(),
   206  						getEpisodesData(),
   207  					)
   208  					if err != nil {
   209  						return fmt.Errorf("failed to execute statement: %w", err)
   210  					}
   211  					return nil
   212  				}, retry.WithIdempotent(true))
   213  				require.NoError(t, err)
   214  			})
   215  		})
   216  
   217  		t.Run("query", func(t *testing.T) {
   218  			t.Run("explain", func(t *testing.T) {
   219  				row := db.QueryRowContext(
   220  					ydb.WithQueryMode(ctx, ydb.ExplainQueryMode), `
   221  						SELECT views FROM episodes WHERE series_id = ? AND season_id = ? AND episode_id = ?;
   222  					`,
   223  					uint64(1),
   224  					uint64(1),
   225  					uint64(1),
   226  				)
   227  				var (
   228  					ast  string
   229  					plan string
   230  				)
   231  
   232  				err = row.Scan(&ast, &plan)
   233  				require.NoError(t, err)
   234  
   235  				t.Logf("ast = %v", ast)
   236  				t.Logf("plan = %v", plan)
   237  			})
   238  			t.Run("increment", func(t *testing.T) {
   239  				t.Run("views", func(t *testing.T) {
   240  					err = retry.DoTx(ctx, db, func(ctx context.Context, tx *sql.Tx) (err error) {
   241  						var stmt *sql.Stmt
   242  						stmt, err = tx.PrepareContext(ctx, `
   243  							SELECT views FROM episodes WHERE series_id = ? AND season_id = ? AND episode_id = ?;
   244  						`)
   245  						if err != nil {
   246  							return fmt.Errorf("cannot prepare query: %w", err)
   247  						}
   248  
   249  						row := stmt.QueryRowContext(ctx,
   250  							uint64(1),
   251  							uint64(1),
   252  							uint64(1),
   253  						)
   254  						var views sql.NullFloat64
   255  						if err = row.Scan(&views); err != nil {
   256  							return fmt.Errorf("cannot scan views: %w", err)
   257  						}
   258  						if views.Valid {
   259  							return fmt.Errorf("unexpected valid views: %v", views.Float64)
   260  						}
   261  						// increment `views`
   262  						_, err = tx.ExecContext(ctx, `
   263  								UPSERT INTO episodes ( series_id, season_id, episode_id, views )
   264  								VALUES ( ?, ?, ?, ? );
   265  							`,
   266  							uint64(1),
   267  							uint64(1),
   268  							uint64(1),
   269  							uint64(views.Float64+1), // increment views
   270  						)
   271  						if err != nil {
   272  							return fmt.Errorf("cannot upsert views: %w", err)
   273  						}
   274  						return nil
   275  					}, retry.WithIdempotent(true))
   276  					require.NoError(t, err)
   277  				})
   278  			})
   279  			t.Run("select", func(t *testing.T) {
   280  				t.Run("isolation", func(t *testing.T) {
   281  					t.Run("snapshot", func(t *testing.T) {
   282  						query := `
   283  							SELECT views FROM episodes 
   284  							WHERE 
   285  								series_id = ? AND 
   286  								season_id = ? AND 
   287  								episode_id = ?;
   288  						`
   289  						err = retry.DoTx(ctx, db,
   290  							func(ctx context.Context, tx *sql.Tx) error {
   291  								row := tx.QueryRowContext(ctx, query,
   292  									uint64(1),
   293  									uint64(1),
   294  									uint64(1),
   295  								)
   296  								var views sql.NullFloat64
   297  								if err = row.Scan(&views); err != nil {
   298  									return fmt.Errorf("cannot select current views: %w", err)
   299  								}
   300  								if !views.Valid {
   301  									return fmt.Errorf("unexpected invalid views: %v", views)
   302  								}
   303  								if views.Float64 != 1 {
   304  									return fmt.Errorf("unexpected views value: %v", views)
   305  								}
   306  								return nil
   307  							},
   308  							retry.WithIdempotent(true),
   309  							retry.WithTxOptions(&sql.TxOptions{
   310  								Isolation: sql.LevelSnapshot,
   311  								ReadOnly:  true,
   312  							}),
   313  						)
   314  						if !errors.Is(err, context.DeadlineExceeded) {
   315  							require.NoError(t, err)
   316  						}
   317  					})
   318  				})
   319  				t.Run("scan", func(t *testing.T) {
   320  					t.Run("query", func(t *testing.T) {
   321  						var (
   322  							seriesID  *uint64
   323  							seasonID  *uint64
   324  							episodeID *uint64
   325  							title     *string
   326  							airDate   *time.Time
   327  							views     sql.NullFloat64
   328  							query     = `
   329  								SELECT 
   330  									series_id,
   331  									season_id,
   332  									episode_id,
   333  									title,
   334  									air_date,
   335  									views
   336  								FROM episodes
   337  								WHERE 
   338  									(series_id >= ? OR ? IS NULL) AND
   339  									(season_id >= ? OR ? IS NULL) AND
   340  									(episode_id >= ? OR ? IS NULL) 
   341  								ORDER BY 
   342  									series_id, season_id, episode_id;
   343  							`
   344  						)
   345  						err := retry.DoTx(ctx, db,
   346  							func(ctx context.Context, cc *sql.Tx) error {
   347  								rows, err := cc.QueryContext(ctx, query,
   348  									seriesID,
   349  									seriesID,
   350  									seasonID,
   351  									seasonID,
   352  									episodeID,
   353  									episodeID,
   354  								)
   355  								if err != nil {
   356  									return err
   357  								}
   358  								defer func() {
   359  									_ = rows.Close()
   360  								}()
   361  								for rows.NextResultSet() {
   362  									for rows.Next() {
   363  										if err = rows.Scan(&seriesID, &seasonID, &episodeID, &title, &airDate, &views); err != nil {
   364  											return fmt.Errorf("cannot select current views: %w", err)
   365  										}
   366  										t.Logf("[%d][%d][%d] - %s %q (%d views)",
   367  											*seriesID, *seasonID, *episodeID, airDate.Format("2006-01-02"),
   368  											*title, uint64(views.Float64),
   369  										)
   370  									}
   371  								}
   372  								return rows.Err()
   373  							},
   374  							retry.WithIdempotent(true),
   375  							retry.WithTxOptions(&sql.TxOptions{Isolation: sql.LevelSnapshot, ReadOnly: true}),
   376  						)
   377  						if !errors.Is(err, context.DeadlineExceeded) {
   378  							require.NoError(t, err)
   379  						}
   380  					})
   381  				})
   382  			})
   383  		})
   384  	})
   385  }