vitess.io/vitess@v0.16.2/go/vt/mysqlctl/schema_test.go (about) 1 package mysqlctl 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 9 "vitess.io/vitess/go/mysql/fakesqldb" 10 "vitess.io/vitess/go/sqltypes" 11 querypb "vitess.io/vitess/go/vt/proto/query" 12 ) 13 14 var queryMap map[string]*sqltypes.Result 15 16 func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, error) { 17 queryMap = make(map[string]*sqltypes.Result) 18 getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "t1") 19 queryMap[getColsQuery] = &sqltypes.Result{ 20 Fields: []*querypb.Field{{ 21 Name: "column_name", 22 Type: sqltypes.VarChar, 23 }}, 24 Rows: [][]sqltypes.Value{ 25 {sqltypes.NewVarChar("col1")}, 26 {sqltypes.NewVarChar("col2")}, 27 {sqltypes.NewVarChar("col3")}, 28 }, 29 } 30 31 queryMap["SELECT `col1`, `col2`, `col3` FROM `test`.`t1` WHERE 1 != 1"] = &sqltypes.Result{ 32 Fields: []*querypb.Field{{ 33 Name: "col1", 34 Type: sqltypes.VarChar, 35 }, { 36 Name: "col2", 37 Type: sqltypes.Int64, 38 }, { 39 Name: "col3", 40 Type: sqltypes.VarBinary, 41 }}, 42 } 43 getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "t2") 44 queryMap[getColsQuery] = &sqltypes.Result{ 45 Fields: []*querypb.Field{{ 46 Name: "column_name", 47 Type: sqltypes.VarChar, 48 }}, 49 Rows: [][]sqltypes.Value{ 50 {sqltypes.NewVarChar("col1")}, 51 }, 52 } 53 54 queryMap["SELECT `col1` FROM `t2` WHERE 1 != 1"] = &sqltypes.Result{ 55 Fields: []*querypb.Field{{ 56 Name: "col1", 57 Type: sqltypes.VarChar, 58 }}, 59 } 60 result, ok := queryMap[query] 61 if ok { 62 return result, nil 63 } 64 return nil, fmt.Errorf("query %s not found in mock setup", query) 65 } 66 67 func TestColumnList(t *testing.T) { 68 db := fakesqldb.New(t) 69 defer db.Close() 70 fields, _, err := GetColumns("test", "t1", mockExec) 71 require.NoError(t, err) 72 require.Equal(t, `[name:"col1" type:VARCHAR name:"col2" type:INT64 name:"col3" type:VARBINARY]`, fmt.Sprintf("%+v", fields)) 73 74 fields, _, err = GetColumns("", "t2", mockExec) 75 require.NoError(t, err) 76 require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields)) 77 }