github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/load_file.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "os" 20 "path/filepath" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/types" 24 ) 25 26 type LoadFile struct { 27 fileName sql.Expression 28 } 29 30 var _ sql.FunctionExpression = (*LoadFile)(nil) 31 var _ sql.CollationCoercible = (*LoadFile)(nil) 32 33 // NewLoadFile returns a LoadFile object for the LOAD_FILE() function. 34 func NewLoadFile(fileName sql.Expression) sql.Expression { 35 return &LoadFile{ 36 fileName: fileName, 37 } 38 } 39 40 // Description implements sql.FunctionExpression 41 func (l *LoadFile) Description() string { 42 return "returns a LoadFile object." 43 } 44 45 // Resolved implements sql.Expression. 46 func (l *LoadFile) Resolved() bool { 47 return l.fileName.Resolved() 48 } 49 50 // String implements sql.Expression. 51 func (l *LoadFile) String() string { 52 return fmt.Sprintf("%s(%s)", l.FunctionName(), l.fileName) 53 } 54 55 // Type implements sql.Expression. 56 func (l *LoadFile) Type() sql.Type { 57 return types.LongBlob 58 } 59 60 // CollationCoercibility implements the interface sql.CollationCoercible. 61 func (*LoadFile) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 62 return sql.Collation_binary, 5 63 } 64 65 // IsNullable implements sql.Expression. 66 func (l *LoadFile) IsNullable() bool { 67 return true 68 } 69 70 // TODO: Allow FILE privileges for GRANT 71 // Eval implements sql.Expression. 72 func (l *LoadFile) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 73 dir, err := ctx.Session.GetSessionVariable(ctx, "secure_file_priv") 74 if err != nil { 75 return "", err 76 } 77 78 // Read the file: Ensure it fits the max byte size 79 file, err := l.getFile(ctx, row, dir.(string)) 80 if err != nil { 81 // If the file doesn't exist we swallow that error 82 if os.IsNotExist(err) { 83 return nil, nil 84 } 85 86 return nil, err 87 } 88 if file == nil { 89 return nil, nil 90 } 91 92 defer file.Close() 93 94 size, isTooBig, err := isFileTooBig(ctx, file) 95 if err != nil { 96 return nil, err 97 } 98 // According to the mysql spec we must return NULL if the file is too big. 99 if isTooBig { 100 return nil, nil 101 } 102 103 // Finally, read the file 104 data := make([]byte, size) 105 _, err = file.Read(data) 106 if err != nil { 107 return nil, err 108 } 109 110 return data, nil 111 } 112 113 // getFile returns the file handler for the passed in filename. The file must be in the secure_file_priv 114 // directory. 115 func (l *LoadFile) getFile(ctx *sql.Context, row sql.Row, secureFileDir string) (*os.File, error) { 116 fileName, err := l.fileName.Eval(ctx, row) 117 if err != nil { 118 return nil, err 119 } 120 121 // If the secure_file_priv directory is not set, just read the file from whatever directory it is in 122 // Otherwise determine whether the file is in the secure_file_priv directory. 123 if secureFileDir == "" { 124 return os.Open(fileName.(string)) 125 } 126 127 // Open the two directories (secure_file_priv and the file dir) and validate they are the same. 128 sDir, err := os.Open(secureFileDir) 129 if err != nil { 130 return nil, err 131 } 132 133 sStat, err := sDir.Stat() 134 if err != nil { 135 return nil, err 136 } 137 138 ffDir, err := os.Open(filepath.Dir(fileName.(string))) 139 if err != nil { 140 return nil, err 141 } 142 143 fStat, err := ffDir.Stat() 144 if err != nil { 145 return nil, err 146 } 147 148 // If the two directories are not equivalent we return nil 149 if !os.SameFile(sStat, fStat) { 150 return nil, nil 151 } 152 153 return os.Open(fileName.(string)) 154 } 155 156 // isFileTooBig return the current file size and whether or not it is larger than max_allowed_packet. 157 func isFileTooBig(ctx *sql.Context, file *os.File) (int64, bool, error) { 158 fi, err := file.Stat() 159 if err != nil { 160 return -1, false, err 161 } 162 163 val, err := ctx.Session.GetSessionVariable(ctx, "max_allowed_packet") 164 if err != nil { 165 return -1, false, err 166 } 167 168 return fi.Size(), fi.Size() > val.(int64), nil 169 } 170 171 // Children implements sql.Expression. 172 func (l *LoadFile) Children() []sql.Expression { 173 return []sql.Expression{l.fileName} 174 } 175 176 // WithChildren implements sql.Expression. 177 func (l *LoadFile) WithChildren(children ...sql.Expression) (sql.Expression, error) { 178 if len(children) != 1 { 179 return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 1) 180 } 181 182 return NewLoadFile(children[0]), nil 183 } 184 185 // FunctionName implements sql.FunctionExpression. 186 func (l *LoadFile) FunctionName() string { 187 return "load_file" 188 }