github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/load_file_test.go (about) 1 package function 2 3 import ( 4 "os" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 9 "github.com/dolthub/go-mysql-server/sql" 10 "github.com/dolthub/go-mysql-server/sql/expression" 11 "github.com/dolthub/go-mysql-server/sql/types" 12 ) 13 14 // createTempDirAndFile returns the temporary directory, as well creates a new file (var filename) that lives in it. 15 func createTempDirAndFile(fileName string) (string, *os.File, error) { 16 dir := os.TempDir() 17 18 file, err := os.CreateTemp(dir, fileName) 19 if err != nil { 20 return "", nil, err 21 } 22 23 return dir, file, nil 24 } 25 26 func TestLoadFileNoSecurePriv(t *testing.T) { 27 // Create a valid temp file and temp directory 28 _, file, err := createTempDirAndFile("myfile.txt") 29 assert.NoError(t, err) 30 31 defer file.Close() 32 defer os.Remove(file.Name()) 33 34 _, err = file.Write([]byte("my data")) 35 assert.NoError(t, err) 36 37 fileName := expression.NewLiteral(file.Name(), types.Text) 38 fn := NewLoadFile(fileName) 39 40 // Assert that Load File returns the regardless since secure_file_priv is set to an empty directory 41 res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{}) 42 assert.NoError(t, err) 43 assert.Equal(t, []byte("my data"), res) 44 } 45 46 func TestLoadFileBadDir(t *testing.T) { 47 // Create a valid temp file and temp directory 48 _, file, err := createTempDirAndFile("myfile.txt") 49 assert.NoError(t, err) 50 51 defer file.Close() 52 defer os.Remove(file.Name()) 53 54 // Set the secure_file_priv var but make it different than the file directory 55 vars := make(map[string]interface{}) 56 vars["secure_file_priv"] = "/not/a/real/directory" 57 err = sql.SystemVariables.AssignValues(vars) 58 assert.NoError(t, err) 59 60 _, err = file.Write([]byte("my data")) 61 assert.NoError(t, err) 62 63 fileName := expression.NewLiteral(file.Name(), types.Text) 64 fn := NewLoadFile(fileName) 65 66 // Assert that Load File returns nil since the file is not in secure_file_priv directory 67 res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{}) 68 assert.NoError(t, err) 69 assert.Equal(t, nil, res) 70 } 71 72 type loadFileTestCase struct { 73 name string 74 fileData []byte 75 fileName string 76 } 77 78 func TestLoadFile(t *testing.T) { 79 testCases := []loadFileTestCase{ 80 { 81 "simple example", 82 []byte("important test case"), 83 "myfile.txt", 84 }, 85 { 86 "blob", 87 []byte("\\xFF\\xD8\\xFF\\xE1\\x00"), 88 "myfile.jpg", 89 }, 90 } 91 92 // create the temp dir 93 dir := os.TempDir() 94 95 // Set the secure_file_priv var 96 vars := make(map[string]interface{}) 97 vars["secure_file_priv"] = dir 98 err := sql.SystemVariables.AssignValues(vars) 99 assert.NoError(t, err) 100 101 for _, tt := range testCases { 102 runLoadFileTest(t, tt, dir) 103 } 104 } 105 106 // runLoadFileTest takes in a loadFileTestCase and its relevant directory and validates whether LOAD_FILE is reading 107 // the file accordingly. 108 func runLoadFileTest(t *testing.T, tt loadFileTestCase, dir string) { 109 file, err := os.CreateTemp(dir, tt.fileName) 110 assert.NoError(t, err) 111 112 defer file.Close() 113 defer os.Remove(file.Name()) 114 115 // Write some data to the file 116 _, err = file.Write(tt.fileData) 117 assert.NoError(t, err) 118 119 // Setup the file data 120 fileName := expression.NewLiteral(file.Name(), types.Text) 121 fn := NewLoadFile(fileName) 122 123 // Load the file in 124 res, err := fn.Eval(sql.NewEmptyContext(), sql.Row{}) 125 assert.NoError(t, err) 126 assert.Equal(t, tt.fileData, res) 127 }