github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/tests/logic_test.go (about)

     1  // Copyright 2023 zGraph Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tests_test
    16  
    17  import (
    18  	"bufio"
    19  	"context"
    20  	"crypto/md5"
    21  	"database/sql"
    22  	"fmt"
    23  	"os"
    24  	"path/filepath"
    25  	"sort"
    26  	"strings"
    27  	"testing"
    28  	"unicode"
    29  
    30  	"github.com/stretchr/testify/require"
    31  	_ "github.com/vescale/zgraph"
    32  )
    33  
    34  // This file implements an end-to-end test framework for zGraph. The test files are located
    35  // in the testdata/logic_test directory. Each file is written in Test-Script format, which
    36  // is designed by SQLite. See https://www.sqlite.org/sqllogictest/.
    37  //
    38  // Test scripts are line-oriented ASCII text files. Lines starting with a '#' character are
    39  // comments and are ignored. Test scripts consist of zero or more records. A record is a single
    40  // statement or query or a control record. Each record is separated from its neighbors by one
    41  // or more blank lines.
    42  //
    43  // Currently, TestLogic only supports statement and query records. The syntax is as follows:
    44  //
    45  //  - statement ok
    46  //    Run the statement and expect it to succeed.
    47  //    e.g.
    48  //      statement ok
    49  //      CREATE GRAPH g
    50  //
    51  //  - statement|query error <regexp>
    52  //    Run the statement or query and expect it to fail with an error message matching the regexp.
    53  //    e.g.
    54  //      statement error graph g already exists
    55  //      CREATE GRAPH g
    56  //
    57  //  - query <type-string> <sort-mode> <label>
    58  //    Run the query and expect it to succeed. The result is compared with the expected results.
    59  //    The expected results follow a '----' line after the query text. If '----' is omitted,
    60  //    the query is expected to return an empty set, or, when a label is given, results whose
    61  //    hash matches that of the previous query with the same label (if there is one).
    62  //
    63  //    The <type-string> argument to the query statement is a short string that specifies the
    64  //    number of result columns and the expected datatype of each result column. There is one
    65  //    character in the <type-string> for each result column. The character codes are "T" for
    66  //    a text result, "I" for an integer result, and "R" for a floating-point result.
    67  //
    68  //    The <sort-mode> argument is optional, which specifies how the result rows should be
    69  //    sorted before comparing with the expected results. If <sort-mode> is present, it must
    70  //    be one of "nosort", "rowsort", or "valuesort".
    71  //    - "nosort" means that the result rows should not be sorted before comparing with the
    72  //      expected results, which is the default behavior.
    73  //    - "rowsort" means that the result rows should be sorted by rows before comparing with
    74  //      the expected results.
    75  //    - "valuesort" is similar to "rowsort", but the results are sorted by values, regardless
    76  //      of row groupings.
    77  //
    78  //    The <label> argument is optional. If present, the test runner will compute a hash of
    79  //    the results. If the same label is reused, the results must be the same.
    80  //
    81  //    In the results section, integer values are rendered as if by printf("%d"). Floating
    82  //    point values are rendered as if by printf("%.3f"). NULL values are rendered as "NULL".
    83  //    Empty strings are rendered as "(empty)". Within non-empty strings, all control characters
    84  //    and unprintable characters are rendered as "@".
    85  //
    86  //    e.g.
    87  //      query T rowsort
    88  //      SELECT n.name FROM MATCH (n)
    89  //      ----
    90  //      Alice
    91  //      Bob
    92  
    93  const logicTestPath = "testdata/logic_test"
    94  
    95  func TestLogic(t *testing.T) {
    96  	err := filepath.Walk(logicTestPath, func(path string, info os.FileInfo, err error) error {
    97  		if err != nil {
    98  			return err
    99  		}
   100  		if !info.Mode().IsRegular() {
   101  			return nil
   102  		}
   103  
   104  		t.Run(info.Name(), func(t *testing.T) {
   105  			t.Parallel()
   106  
   107  			db, err := sql.Open("zgraph", t.TempDir())
   108  			require.NoError(t, err)
   109  			defer db.Close()
   110  
   111  			conn, err := db.Conn(context.Background())
   112  			require.NoError(t, err)
   113  			defer conn.Close()
   114  
   115  			lt := logicTest{
   116  				t:        t,
   117  				conn:     conn,
   118  				labelMap: make(map[string]string),
   119  			}
   120  			lt.run(path)
   121  		})
   122  		return nil
   123  	})
   124  	require.NoError(t, err)
   125  }
   126  
   127  type lineScanner struct {
   128  	*bufio.Scanner
   129  	line int
   130  }
   131  
   132  func (ls *lineScanner) Scan() bool {
   133  	if ls.Scanner.Scan() {
   134  		ls.line++
   135  		return true
   136  	}
   137  	return false
   138  }
   139  
   140  type logicTest struct {
   141  	t    *testing.T
   142  	conn *sql.Conn
   143  
   144  	// labelMap is a map from label to hash.
   145  	labelMap map[string]string
   146  }
   147  
   148  func (lt *logicTest) run(path string) {
   149  	t := lt.t
   150  	f, err := os.Open(path)
   151  	require.NoError(t, err)
   152  	defer f.Close()
   153  
   154  	s := lineScanner{Scanner: bufio.NewScanner(f)}
   155  	for s.Scan() {
   156  		line := s.Text()
   157  		if strings.HasPrefix(line, "#") {
   158  			continue
   159  		}
   160  		fields := strings.Fields(line)
   161  		if len(fields) == 0 {
   162  			continue
   163  		}
   164  
   165  		cmd := fields[0]
   166  		switch cmd {
   167  		case "statement":
   168  			stmt := logicStatement{
   169  				pos: fmt.Sprintf("%s:%d", path, s.line),
   170  			}
   171  
   172  			if len(fields) < 2 {
   173  				t.Fatalf("%s: statement command should have at least 2 arguments", stmt.pos)
   174  			}
   175  
   176  			if fields[1] == "error" {
   177  				expectedErr := strings.TrimSpace(strings.TrimPrefix(line, "statement"))
   178  				expectedErr = strings.TrimSpace(strings.TrimPrefix(expectedErr, "error"))
   179  				stmt.expectedErr = expectedErr
   180  			}
   181  
   182  			var sqlStr strings.Builder
   183  			for s.Scan() {
   184  				line = s.Text()
   185  				if strings.TrimSpace(line) == "" {
   186  					break
   187  				}
   188  				if line == "----" {
   189  					t.Fatalf("%s:%d: unexpected '----' after a statement", path, s.line)
   190  				}
   191  				fmt.Fprintf(&sqlStr, "\n%s", line)
   192  			}
   193  			stmt.sql = sqlStr.String()
   194  			lt.execStatement(stmt)
   195  		case "query":
   196  			query := logicQuery{}
   197  			query.pos = fmt.Sprintf("%s:%d", path, s.line)
   198  
   199  			if len(fields) < 2 {
   200  				t.Fatalf("%s: query command should have at least 2 arguments", query.pos)
   201  			}
   202  
   203  			if fields[1] == "error" {
   204  				expectedErr := strings.TrimSpace(strings.TrimPrefix(line, "query"))
   205  				expectedErr = strings.TrimSpace(strings.TrimPrefix(expectedErr, "error"))
   206  				query.expectedErr = expectedErr
   207  			} else {
   208  				query.typeStr = fields[1]
   209  				query.sorter = noSort
   210  				if len(fields) >= 3 {
   211  					switch fields[2] {
   212  					case "nosort":
   213  					case "rowsort":
   214  						query.sorter = rowSort
   215  					case "valuesort":
   216  						query.sorter = valueSort
   217  					default:
   218  						t.Fatalf("%s:%d unknown sort mode: %s", path, s.line, fields[2])
   219  					}
   220  				}
   221  				if len(fields) >= 4 {
   222  					query.label = fields[3]
   223  				}
   224  
   225  				// Parse SQL query.
   226  				var sqlStr strings.Builder
   227  				var hasSeparator bool
   228  				for s.Scan() {
   229  					line = s.Text()
   230  					if strings.TrimSpace(line) == "" {
   231  						break
   232  					}
   233  					if line == "----" {
   234  						if query.expectedErr != "" {
   235  							t.Fatalf("%s:%d unexpected '----' after a query that expects an error", path, s.line)
   236  						}
   237  						hasSeparator = true
   238  						break
   239  					}
   240  					fmt.Fprintf(&sqlStr, "\n%s", line)
   241  				}
   242  				query.sql = sqlStr.String()
   243  
   244  				// Parse expected results.
   245  				if hasSeparator {
   246  					for s.Scan() {
   247  						line = s.Text()
   248  						if strings.TrimSpace(line) == "" {
   249  							break
   250  						}
   251  						query.expectedResults = append(query.expectedResults, strings.Fields(line)...)
   252  					}
   253  				}
   254  			}
   255  
   256  			lt.execQuery(query)
   257  		default:
   258  			t.Fatalf("%s:%d unknown command: %s", path, s.line, cmd)
   259  		}
   260  
   261  	}
   262  }
   263  
   264  func (lt *logicTest) execStatement(stmt logicStatement) {
   265  	t := lt.t
   266  
   267  	_, err := lt.conn.ExecContext(context.Background(), stmt.sql)
   268  	if stmt.expectedErr != "" {
   269  		require.Error(t, err)
   270  		require.Regexp(t, stmt.expectedErr, err.Error())
   271  	} else {
   272  		require.NoError(t, err)
   273  	}
   274  }
   275  
   276  func (lt *logicTest) execQuery(query logicQuery) {
   277  	t := lt.t
   278  
   279  	rows, err := lt.conn.QueryContext(context.Background(), query.sql)
   280  	if query.expectedErr != "" {
   281  		require.Error(t, err)
   282  		require.Regexp(t, query.expectedErr, err.Error())
   283  	} else {
   284  		require.NoError(t, err)
   285  		defer rows.Close()
   286  	}
   287  
   288  	var values []string
   289  	numCols := len(query.typeStr)
   290  	for rows.Next() {
   291  		cols, err := rows.Columns()
   292  		require.NoError(t, err)
   293  		require.Equal(t, numCols, len(cols), "number of columns mismatch")
   294  
   295  		dest := make([]interface{}, len(cols))
   296  		for i := range query.typeStr {
   297  			switch query.typeStr[i] {
   298  			case 'T':
   299  				dest[i] = &sql.NullString{}
   300  			case 'I':
   301  				dest[i] = &sql.NullInt64{}
   302  			case 'R':
   303  				dest[i] = &sql.NullFloat64{}
   304  			default:
   305  				t.Fatalf("unknown type character: %c", query.typeStr[i])
   306  			}
   307  		}
   308  		require.NoError(t, rows.Scan(dest...))
   309  
   310  		for i := range dest {
   311  			switch query.typeStr[i] {
   312  			case 'T':
   313  				val := *dest[i].(*sql.NullString)
   314  				s := val.String
   315  				if !val.Valid {
   316  					s = "NULL"
   317  				}
   318  				if s == "" {
   319  					s = "(empty)"
   320  				}
   321  				s = strings.Map(func(r rune) rune {
   322  					if unicode.IsControl(r) {
   323  						return '@'
   324  					}
   325  					return r
   326  				}, s)
   327  				// Replace consecutive spaces with a single space.
   328  				s = strings.Join(strings.Fields(s), " ")
   329  				values = append(values, s)
   330  			case 'I':
   331  				val := *dest[i].(*sql.NullInt64)
   332  				if val.Valid {
   333  					values = append(values, fmt.Sprintf("%d", val.Int64))
   334  				} else {
   335  					values = append(values, "NULL")
   336  				}
   337  			case 'R':
   338  				val := *dest[i].(*sql.NullFloat64)
   339  				if val.Valid {
   340  					values = append(values, fmt.Sprintf("%.3f", val.Float64))
   341  				} else {
   342  					values = append(values, "NULL")
   343  				}
   344  			}
   345  		}
   346  	}
   347  	require.NoError(t, rows.Err())
   348  
   349  	values = query.sorter(numCols, values)
   350  	// Format values so that they can be compared with expected results.
   351  	values = strings.Fields(strings.Join(values, " "))
   352  
   353  	if len(query.expectedResults) > 0 || query.label == "" {
   354  		// If there are expected results, or if there is no label, then the results must match.
   355  		require.Equal(t, query.expectedResults, values, "%s: results mismatch", query.pos)
   356  	}
   357  
   358  	if query.label != "" {
   359  		hash := hashResults(values)
   360  		if prevHash, ok := lt.labelMap[query.label]; ok {
   361  			require.Equal(t, prevHash, hash, "%s: results for label %s mismatch", query.pos, query.label)
   362  		}
   363  		lt.labelMap[query.label] = hash
   364  	}
   365  }
   366  
   367  func hashResults(results []string) string {
   368  	return fmt.Sprintf("%x", md5.Sum([]byte(strings.Join(results, " "))))
   369  }
   370  
   371  type logicStatement struct {
   372  	pos         string
   373  	sql         string
   374  	expectedErr string
   375  }
   376  
   377  type logicSorter func(numCols int, values []string) []string
   378  
   379  func noSort(_ int, values []string) []string {
   380  	return values
   381  }
   382  
   383  type rowSorter struct {
   384  	numCols int
   385  	values  []string
   386  }
   387  
   388  func (rs *rowSorter) Len() int {
   389  	return len(rs.values) / rs.numCols
   390  }
   391  
   392  func (rs *rowSorter) Less(i, j int) bool {
   393  	a := rs.row(i)
   394  	b := rs.row(j)
   395  	for k := 0; k < rs.numCols; k++ {
   396  		if a[k] != b[k] {
   397  			return a[k] < b[k]
   398  		}
   399  	}
   400  	return false
   401  }
   402  
   403  func (rs *rowSorter) Swap(i, j int) {
   404  	a := rs.row(i)
   405  	b := rs.row(j)
   406  	for k := 0; k < rs.numCols; k++ {
   407  		a[k], b[k] = b[k], a[k]
   408  	}
   409  }
   410  
   411  func (rs *rowSorter) row(i int) []string {
   412  	return rs.values[i*rs.numCols : (i+1)*rs.numCols]
   413  }
   414  
   415  func rowSort(numCols int, values []string) []string {
   416  	rs := rowSorter{
   417  		numCols: numCols,
   418  		values:  values,
   419  	}
   420  	sort.Sort(&rs)
   421  	return rs.values
   422  }
   423  
   424  func valueSort(_ int, values []string) []string {
   425  	sort.Strings(values)
   426  	return values
   427  }
   428  
   429  type logicQuery struct {
   430  	logicStatement
   431  
   432  	typeStr         string
   433  	sorter          logicSorter
   434  	label           string
   435  	expectedResults []string
   436  }