vitess.io/vitess@v0.16.2/go/vt/mysqlctl/schema_test.go (about)

     1  package mysqlctl
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"vitess.io/vitess/go/mysql/fakesqldb"
    10  	"vitess.io/vitess/go/sqltypes"
    11  	querypb "vitess.io/vitess/go/vt/proto/query"
    12  )
    13  
    14  var queryMap map[string]*sqltypes.Result
    15  
    16  func mockExec(query string, maxRows int, wantFields bool) (*sqltypes.Result, error) {
    17  	queryMap = make(map[string]*sqltypes.Result)
    18  	getColsQuery := fmt.Sprintf(GetColumnNamesQuery, "'test'", "t1")
    19  	queryMap[getColsQuery] = &sqltypes.Result{
    20  		Fields: []*querypb.Field{{
    21  			Name: "column_name",
    22  			Type: sqltypes.VarChar,
    23  		}},
    24  		Rows: [][]sqltypes.Value{
    25  			{sqltypes.NewVarChar("col1")},
    26  			{sqltypes.NewVarChar("col2")},
    27  			{sqltypes.NewVarChar("col3")},
    28  		},
    29  	}
    30  
    31  	queryMap["SELECT `col1`, `col2`, `col3` FROM `test`.`t1` WHERE 1 != 1"] = &sqltypes.Result{
    32  		Fields: []*querypb.Field{{
    33  			Name: "col1",
    34  			Type: sqltypes.VarChar,
    35  		}, {
    36  			Name: "col2",
    37  			Type: sqltypes.Int64,
    38  		}, {
    39  			Name: "col3",
    40  			Type: sqltypes.VarBinary,
    41  		}},
    42  	}
    43  	getColsQuery = fmt.Sprintf(GetColumnNamesQuery, "database()", "t2")
    44  	queryMap[getColsQuery] = &sqltypes.Result{
    45  		Fields: []*querypb.Field{{
    46  			Name: "column_name",
    47  			Type: sqltypes.VarChar,
    48  		}},
    49  		Rows: [][]sqltypes.Value{
    50  			{sqltypes.NewVarChar("col1")},
    51  		},
    52  	}
    53  
    54  	queryMap["SELECT `col1` FROM `t2` WHERE 1 != 1"] = &sqltypes.Result{
    55  		Fields: []*querypb.Field{{
    56  			Name: "col1",
    57  			Type: sqltypes.VarChar,
    58  		}},
    59  	}
    60  	result, ok := queryMap[query]
    61  	if ok {
    62  		return result, nil
    63  	}
    64  	return nil, fmt.Errorf("query %s not found in mock setup", query)
    65  }
    66  
    67  func TestColumnList(t *testing.T) {
    68  	db := fakesqldb.New(t)
    69  	defer db.Close()
    70  	fields, _, err := GetColumns("test", "t1", mockExec)
    71  	require.NoError(t, err)
    72  	require.Equal(t, `[name:"col1" type:VARCHAR name:"col2" type:INT64 name:"col3" type:VARBINARY]`, fmt.Sprintf("%+v", fields))
    73  
    74  	fields, _, err = GetColumns("", "t2", mockExec)
    75  	require.NoError(t, err)
    76  	require.Equal(t, `[name:"col1" type:VARCHAR]`, fmt.Sprintf("%+v", fields))
    77  }