github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/load_file_test.go (about)

     1  package function
     2  
     3  import (
     4  	"os"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  
     9  	"github.com/dolthub/go-mysql-server/sql"
    10  	"github.com/dolthub/go-mysql-server/sql/expression"
    11  	"github.com/dolthub/go-mysql-server/sql/types"
    12  )
    13  
    14  // createTempDirAndFile returns the temporary directory, as well creates a new file (var filename) that lives in it.
    15  func createTempDirAndFile(fileName string) (string, *os.File, error) {
    16  	dir := os.TempDir()
    17  
    18  	file, err := os.CreateTemp(dir, fileName)
    19  	if err != nil {
    20  		return "", nil, err
    21  	}
    22  
    23  	return dir, file, nil
    24  }
    25  
    26  func TestLoadFileNoSecurePriv(t *testing.T) {
    27  	// Create a valid temp file and temp directory
    28  	_, file, err := createTempDirAndFile("myfile.txt")
    29  	assert.NoError(t, err)
    30  
    31  	defer file.Close()
    32  	defer os.Remove(file.Name())
    33  
    34  	_, err = file.Write([]byte("my data"))
    35  	assert.NoError(t, err)
    36  
    37  	fileName := expression.NewLiteral(file.Name(), types.Text)
    38  	fn := NewLoadFile(fileName)
    39  
    40  	// Assert that Load File returns the regardless since secure_file_priv is set to an empty directory
    41  	res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{})
    42  	assert.NoError(t, err)
    43  	assert.Equal(t, []byte("my data"), res)
    44  }
    45  
    46  func TestLoadFileBadDir(t *testing.T) {
    47  	// Create a valid temp file and temp directory
    48  	_, file, err := createTempDirAndFile("myfile.txt")
    49  	assert.NoError(t, err)
    50  
    51  	defer file.Close()
    52  	defer os.Remove(file.Name())
    53  
    54  	// Set the secure_file_priv var but make it different than the file directory
    55  	vars := make(map[string]interface{})
    56  	vars["secure_file_priv"] = "/not/a/real/directory"
    57  	err = sql.SystemVariables.AssignValues(vars)
    58  	assert.NoError(t, err)
    59  
    60  	_, err = file.Write([]byte("my data"))
    61  	assert.NoError(t, err)
    62  
    63  	fileName := expression.NewLiteral(file.Name(), types.Text)
    64  	fn := NewLoadFile(fileName)
    65  
    66  	// Assert that Load File returns nil since the file is not in secure_file_priv directory
    67  	res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{})
    68  	assert.NoError(t, err)
    69  	assert.Equal(t, nil, res)
    70  }
    71  
// loadFileTestCase describes one LOAD_FILE round trip checked by runLoadFileTest.
// Test tables build it with positional (unkeyed) literals, so the field order is significant.
type loadFileTestCase struct {
	name     string // human-readable label for the case
	fileData []byte // bytes written to the temp file and expected back from LOAD_FILE
	fileName string // pattern passed to os.CreateTemp; a random suffix is appended to it
}
    77  
    78  func TestLoadFile(t *testing.T) {
    79  	testCases := []loadFileTestCase{
    80  		{
    81  			"simple example",
    82  			[]byte("important test case"),
    83  			"myfile.txt",
    84  		},
    85  		{
    86  			"blob",
    87  			[]byte("\\xFF\\xD8\\xFF\\xE1\\x00"),
    88  			"myfile.jpg",
    89  		},
    90  	}
    91  
    92  	// create the temp dir
    93  	dir := os.TempDir()
    94  
    95  	// Set the secure_file_priv var
    96  	vars := make(map[string]interface{})
    97  	vars["secure_file_priv"] = dir
    98  	err := sql.SystemVariables.AssignValues(vars)
    99  	assert.NoError(t, err)
   100  
   101  	for _, tt := range testCases {
   102  		runLoadFileTest(t, tt, dir)
   103  	}
   104  }
   105  
   106  // runLoadFileTest takes in a loadFileTestCase and its relevant directory and validates whether LOAD_FILE is reading
   107  // the file accordingly.
   108  func runLoadFileTest(t *testing.T, tt loadFileTestCase, dir string) {
   109  	file, err := os.CreateTemp(dir, tt.fileName)
   110  	assert.NoError(t, err)
   111  
   112  	defer file.Close()
   113  	defer os.Remove(file.Name())
   114  
   115  	// Write some data to the file
   116  	_, err = file.Write(tt.fileData)
   117  	assert.NoError(t, err)
   118  
   119  	// Setup the file data
   120  	fileName := expression.NewLiteral(file.Name(), types.Text)
   121  	fn := NewLoadFile(fileName)
   122  
   123  	// Load the file in
   124  	res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{})
   125  	assert.NoError(t, err)
   126  	assert.Equal(t, tt.fileData, res)
   127  }