vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/planbuilder/plan_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package planbuilder
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"log"
    26  	"os"
    27  	"path/filepath"
    28  	"strings"
    29  	"testing"
    30  
    31  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    32  
    33  	"github.com/stretchr/testify/require"
    34  
    35  	"vitess.io/vitess/go/vt/sqlparser"
    36  	"vitess.io/vitess/go/vt/tableacl"
    37  	"vitess.io/vitess/go/vt/vttablet/tabletserver/schema"
    38  )
    39  
    40  // MarshalJSON returns a JSON of the given Plan.
    41  // This is only for testing.
    42  func (p *Plan) MarshalJSON() ([]byte, error) {
    43  	mplan := struct {
    44  		PlanID            PlanType
    45  		TableName         sqlparser.IdentifierCS `json:",omitempty"`
    46  		Permissions       []Permission           `json:",omitempty"`
    47  		FieldQuery        *sqlparser.ParsedQuery `json:",omitempty"`
    48  		FullQuery         *sqlparser.ParsedQuery `json:",omitempty"`
    49  		NextCount         string                 `json:",omitempty"`
    50  		WhereClause       *sqlparser.ParsedQuery `json:",omitempty"`
    51  		NeedsReservedConn bool                   `json:",omitempty"`
    52  	}{
    53  		PlanID:      p.PlanID,
    54  		TableName:   p.TableName(),
    55  		Permissions: p.Permissions,
    56  		FullQuery:   p.FullQuery,
    57  		WhereClause: p.WhereClause,
    58  	}
    59  	if p.NextCount != nil {
    60  		mplan.NextCount = evalengine.FormatExpr(p.NextCount)
    61  	}
    62  	if p.NeedsReservedConn {
    63  		mplan.NeedsReservedConn = true
    64  	}
    65  	return json.Marshal(&mplan)
    66  }
    67  
    68  func TestPlan(t *testing.T) {
    69  	testPlan(t, "exec_cases.txt")
    70  }
    71  
    72  func TestDDLPlan(t *testing.T) {
    73  	testPlan(t, "ddl_cases.txt")
    74  }
    75  
    76  func testPlan(t *testing.T, fileName string) {
    77  	t.Helper()
    78  	testSchema := loadSchema("schema_test.json")
    79  	for tcase := range iterateExecFile(fileName) {
    80  		t.Run(tcase.input, func(t *testing.T) {
    81  			if strings.Contains(tcase.options, "PassthroughDMLs") {
    82  				PassthroughDMLs = true
    83  			}
    84  			var plan *Plan
    85  			var err error
    86  			statement, err := sqlparser.Parse(tcase.input)
    87  			if err == nil {
    88  				plan, err = Build(statement, testSchema, "dbName", false)
    89  			}
    90  			PassthroughDMLs = false
    91  
    92  			var out string
    93  			if err != nil {
    94  				out = err.Error()
    95  			} else {
    96  				bout, err := json.Marshal(plan)
    97  				require.NoError(t, err, "Error marshalling %v: %v", plan, err)
    98  				out = string(bout)
    99  			}
   100  			if out != tcase.output {
   101  				t.Errorf("Line:%v\ngot  = %s\nwant = %s", tcase.lineno, out, tcase.output)
   102  				if err != nil {
   103  					out = fmt.Sprintf("\"%s\"", out)
   104  				} else {
   105  					bout, _ := json.MarshalIndent(plan, "", "  ")
   106  					out = string(bout)
   107  				}
   108  				fmt.Printf("\"in> %s\"\nout>%s\nexpected: %s\n\n", tcase.input, out, tcase.output)
   109  			}
   110  		})
   111  	}
   112  }
   113  
   114  func TestPlanInReservedConn(t *testing.T) {
   115  	testSchema := loadSchema("schema_test.json")
   116  	for tcase := range iterateExecFile("exec_cases.txt") {
   117  		t.Run(tcase.input, func(t *testing.T) {
   118  			if strings.Contains(tcase.options, "PassthroughDMLs") {
   119  				PassthroughDMLs = true
   120  			}
   121  			var plan *Plan
   122  			var err error
   123  			statement, err := sqlparser.Parse(tcase.input)
   124  			if err == nil {
   125  				plan, err = Build(statement, testSchema, "dbName", false)
   126  			}
   127  			PassthroughDMLs = false
   128  
   129  			var out string
   130  			if err != nil {
   131  				out = err.Error()
   132  			} else {
   133  				bout, err := json.Marshal(plan)
   134  				if err != nil {
   135  					t.Fatalf("Error marshalling %v: %v", plan, err)
   136  				}
   137  				out = string(bout)
   138  			}
   139  			if out != tcase.output {
   140  				t.Errorf("Line:%v\ngot  = %s\nwant = %s", tcase.lineno, out, tcase.output)
   141  				if err != nil {
   142  					out = fmt.Sprintf("\"%s\"", out)
   143  				} else {
   144  					bout, _ := json.MarshalIndent(plan, "", "  ")
   145  					out = string(bout)
   146  				}
   147  				fmt.Printf("\"%s\"\n%s\n\n", tcase.input, out)
   148  			}
   149  		})
   150  	}
   151  }
   152  
   153  func TestCustom(t *testing.T) {
   154  	testSchemas, _ := filepath.Glob("testdata/*_schema.json")
   155  	if len(testSchemas) == 0 {
   156  		t.Log("No schemas to test")
   157  		return
   158  	}
   159  	for _, schemFile := range testSchemas {
   160  		schem := loadSchema(schemFile)
   161  		t.Logf("Testing schema %s", schemFile)
   162  		files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1))
   163  		if err != nil {
   164  			log.Fatal(err)
   165  		}
   166  		if len(files) == 0 {
   167  			t.Fatalf("No test files for %s", schemFile)
   168  		}
   169  		for _, file := range files {
   170  			t.Logf("Testing file %s", file)
   171  			for tcase := range iterateExecFile(file) {
   172  				statement, err := sqlparser.Parse(tcase.input)
   173  				if err != nil {
   174  					t.Fatalf("Got error: %v, parsing sql: %v", err.Error(), tcase.input)
   175  				}
   176  				plan, err := Build(statement, schem, "dbName", false)
   177  				var out string
   178  				if err != nil {
   179  					out = err.Error()
   180  				} else {
   181  					bout, err := json.Marshal(plan)
   182  					if err != nil {
   183  						t.Fatalf("Error marshalling %v: %v", plan, err)
   184  					}
   185  					out = string(bout)
   186  				}
   187  				if out != tcase.output {
   188  					t.Errorf("File: %s: Line:%v\ngot  = %s\nwant = %s", file, tcase.lineno, out, tcase.output)
   189  				}
   190  			}
   191  		}
   192  	}
   193  }
   194  
   195  func TestStreamPlan(t *testing.T) {
   196  	testSchema := loadSchema("schema_test.json")
   197  	for tcase := range iterateExecFile("stream_cases.txt") {
   198  		plan, err := BuildStreaming(tcase.input, testSchema)
   199  		var out string
   200  		if err != nil {
   201  			out = err.Error()
   202  		} else {
   203  			bout, err := json.Marshal(plan)
   204  			if err != nil {
   205  				t.Fatalf("Error marshalling %v: %v", plan, err)
   206  			}
   207  			out = string(bout)
   208  		}
   209  		if out != tcase.output {
   210  			t.Errorf("Line:%v\ngot  = %s\nwant = %s", tcase.lineno, out, tcase.output)
   211  		}
   212  	}
   213  }
   214  
   215  func TestMessageStreamingPlan(t *testing.T) {
   216  	testSchema := loadSchema("schema_test.json")
   217  	plan, err := BuildMessageStreaming("msg", testSchema)
   218  	require.NoError(t, err)
   219  	bout, _ := json.Marshal(plan)
   220  	planJSON := string(bout)
   221  
   222  	wantPlan := &Plan{
   223  		PlanID: PlanMessageStream,
   224  		Table:  testSchema["msg"],
   225  		Permissions: []Permission{{
   226  			TableName: "msg",
   227  			Role:      tableacl.WRITER,
   228  		}},
   229  	}
   230  	bout, _ = json.Marshal(wantPlan)
   231  	wantJSON := string(bout)
   232  
   233  	if planJSON != wantJSON {
   234  		t.Errorf("BuildMessageStreaming: \n%s, want\n%s", planJSON, wantJSON)
   235  	}
   236  
   237  	_, err = BuildMessageStreaming("absent", testSchema)
   238  	want := "table absent not found in schema"
   239  	if err == nil || err.Error() != want {
   240  		t.Errorf("BuildMessageStreaming(absent) error: %v, want %s", err, want)
   241  	}
   242  
   243  	_, err = BuildMessageStreaming("a", testSchema)
   244  	want = "'a' is not a message table"
   245  	if err == nil || err.Error() != want {
   246  		t.Errorf("BuildMessageStreaming(absent) error: %v, want %s", err, want)
   247  	}
   248  }
   249  
   250  func TestLockPlan(t *testing.T) {
   251  	testSchema := loadSchema("schema_test.json")
   252  	for tcase := range iterateExecFile("lock_cases.txt") {
   253  		t.Run(tcase.input, func(t *testing.T) {
   254  			var plan *Plan
   255  			var err error
   256  			statement, err := sqlparser.Parse(tcase.input)
   257  			if err == nil {
   258  				plan, err = Build(statement, testSchema, "dbName", false)
   259  			}
   260  
   261  			var out string
   262  			if err != nil {
   263  				out = err.Error()
   264  			} else {
   265  				bout, err := json.Marshal(plan)
   266  				if err != nil {
   267  					t.Fatalf("Error marshalling %v: %v", plan, err)
   268  				}
   269  				out = string(bout)
   270  			}
   271  			if out != tcase.output {
   272  				t.Errorf("Line:%v\ngot  = %s\nwant = %s", tcase.lineno, out, tcase.output)
   273  				if err != nil {
   274  					out = fmt.Sprintf("\"%s\"", out)
   275  				} else {
   276  					bout, _ := json.MarshalIndent(plan, "", "  ")
   277  					out = string(bout)
   278  				}
   279  				fmt.Printf("\"in> %s\"\nout>%s\nexpected: %s\n\n", tcase.input, out, tcase.output)
   280  			}
   281  		})
   282  	}
   283  }
   284  
   285  func loadSchema(name string) map[string]*schema.Table {
   286  	b, err := os.ReadFile(locateFile(name))
   287  	if err != nil {
   288  		panic(err)
   289  	}
   290  	tables := make([]*schema.Table, 0, 10)
   291  	err = json.Unmarshal(b, &tables)
   292  	if err != nil {
   293  		panic(err)
   294  	}
   295  	s := make(map[string]*schema.Table)
   296  	for _, t := range tables {
   297  		s[t.Name.String()] = t
   298  	}
   299  	return s
   300  }
   301  
   302  type testCase struct {
   303  	file    string
   304  	lineno  int
   305  	options string
   306  	input   string
   307  	output  string
   308  }
   309  
   310  func iterateExecFile(name string) (testCaseIterator chan testCase) {
   311  	name = locateFile(name)
   312  	fd, err := os.OpenFile(name, os.O_RDONLY, 0)
   313  	if err != nil {
   314  		panic(fmt.Sprintf("Could not open file %s", name))
   315  	}
   316  	testCaseIterator = make(chan testCase)
   317  	go func() {
   318  		defer close(testCaseIterator)
   319  
   320  		r := bufio.NewReader(fd)
   321  		lineno := 0
   322  		options := ""
   323  		for {
   324  			binput, err := r.ReadBytes('\n')
   325  			if err != nil {
   326  				if err != io.EOF {
   327  					fmt.Printf("Line: %d\n", lineno)
   328  					panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
   329  				}
   330  				break
   331  			}
   332  			lineno++
   333  			input := string(binput)
   334  			if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") {
   335  				// fmt.Printf("%s\n", input)
   336  				continue
   337  			}
   338  
   339  			if strings.HasPrefix(input, "options:") {
   340  				options = input[8:]
   341  				continue
   342  			}
   343  			err = json.Unmarshal(binput, &input)
   344  			if err != nil {
   345  				fmt.Printf("Line: %d, input: %s\n", lineno, binput)
   346  				panic(err)
   347  			}
   348  			input = strings.Trim(input, "\"")
   349  			var output []byte
   350  			for {
   351  				l, err := r.ReadBytes('\n')
   352  				lineno++
   353  				if err != nil {
   354  					fmt.Printf("Line: %d\n", lineno)
   355  					panic(fmt.Errorf("Error reading file %s: %s", name, err.Error()))
   356  				}
   357  				output = append(output, l...)
   358  				if l[0] == '}' {
   359  					output = output[:len(output)-1]
   360  					b := bytes.NewBuffer(make([]byte, 0, 64))
   361  					if err := json.Compact(b, output); err == nil {
   362  						output = b.Bytes()
   363  					}
   364  					break
   365  				}
   366  				if l[0] == '"' {
   367  					output = output[1 : len(output)-2]
   368  					break
   369  				}
   370  			}
   371  			testCaseIterator <- testCase{name, lineno, options, input, string(output)}
   372  			options = ""
   373  		}
   374  	}()
   375  	return testCaseIterator
   376  }
   377  
   378  func locateFile(name string) string {
   379  	return "testdata/" + name
   380  }