github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/database/mysql/infile.go (about)

     1  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     2  //
     3  // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysql
    10  
    11  import (
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strings"
    16  )
    17  
    18  var (
    19  	fileRegister   map[string]bool
    20  	readerRegister map[string]func() io.Reader
    21  )
    22  
    23  // RegisterLocalFile adds the given file to the file whitelist,
    24  // so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
    25  // Alternatively you can allow the use of all local files with
    26  // the DSN parameter 'allowAllFiles=true'
    27  //
    28  //  filePath := "/home/gopher/data.csv"
    29  //  mysql.RegisterLocalFile(filePath)
    30  //  err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
    31  //  if err != nil {
    32  //  ...
    33  //
    34  func RegisterLocalFile(filePath string) {
    35  	// lazy map init
    36  	if fileRegister == nil {
    37  		fileRegister = make(map[string]bool)
    38  	}
    39  
    40  	fileRegister[strings.Trim(filePath, `"`)] = true
    41  }
    42  
    43  // DeregisterLocalFile removes the given filepath from the whitelist.
    44  func DeregisterLocalFile(filePath string) {
    45  	delete(fileRegister, strings.Trim(filePath, `"`))
    46  }
    47  
    48  // RegisterReaderHandler registers a handler function which is used
    49  // to receive a io.Reader.
    50  // The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
    51  // If the handler returns a io.ReadCloser Close() is called when the
    52  // request is finished.
    53  //
    54  //  mysql.RegisterReaderHandler("data", func() io.Reader {
    55  //  	var csvReader io.Reader // Some Reader that returns CSV data
    56  //  	... // Open Reader here
    57  //  	return csvReader
    58  //  })
    59  //  err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
    60  //  if err != nil {
    61  //  ...
    62  //
    63  func RegisterReaderHandler(name string, handler func() io.Reader) {
    64  	// lazy map init
    65  	if readerRegister == nil {
    66  		readerRegister = make(map[string]func() io.Reader)
    67  	}
    68  
    69  	readerRegister[name] = handler
    70  }
    71  
    72  // DeregisterReaderHandler removes the ReaderHandler function with
    73  // the given name from the registry.
    74  func DeregisterReaderHandler(name string) {
    75  	delete(readerRegister, name)
    76  }
    77  
    78  func deferredClose(err *error, closer io.Closer) {
    79  	closeErr := closer.Close()
    80  	if *err == nil {
    81  		*err = closeErr
    82  	}
    83  }
    84  
    85  func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
    86  	var rdr io.Reader
    87  	var data []byte
    88  
    89  	if strings.HasPrefix(name, "Reader::") { // io.Reader
    90  		name = name[8:]
    91  		if handler, inMap := readerRegister[name]; inMap {
    92  			rdr = handler()
    93  			if rdr != nil {
    94  				data = make([]byte, 4+mc.maxWriteSize)
    95  
    96  				if cl, ok := rdr.(io.Closer); ok {
    97  					defer deferredClose(&err, cl)
    98  				}
    99  			} else {
   100  				err = fmt.Errorf("Reader '%s' is <nil>", name)
   101  			}
   102  		} else {
   103  			err = fmt.Errorf("Reader '%s' is not registered", name)
   104  		}
   105  	} else { // File
   106  		name = strings.Trim(name, `"`)
   107  		if mc.cfg.allowAllFiles || fileRegister[name] {
   108  			var file *os.File
   109  			var fi os.FileInfo
   110  
   111  			if file, err = os.Open(name); err == nil {
   112  				defer deferredClose(&err, file)
   113  
   114  				// get file size
   115  				if fi, err = file.Stat(); err == nil {
   116  					rdr = file
   117  					if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
   118  						data = make([]byte, 4+fileSize)
   119  					} else if fileSize <= mc.maxPacketAllowed {
   120  						data = make([]byte, 4+mc.maxWriteSize)
   121  					} else {
   122  						err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
   123  					}
   124  				}
   125  			}
   126  		} else {
   127  			err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
   128  		}
   129  	}
   130  
   131  	// send content packets
   132  	if err == nil {
   133  		var n int
   134  		for err == nil {
   135  			n, err = rdr.Read(data[4:])
   136  			if n > 0 {
   137  				if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
   138  					return ioErr
   139  				}
   140  			}
   141  		}
   142  		if err == io.EOF {
   143  			err = nil
   144  		}
   145  	}
   146  
   147  	// send empty packet (termination)
   148  	if data == nil {
   149  		data = make([]byte, 4)
   150  	}
   151  	if ioErr := mc.writePacket(data[:4]); ioErr != nil {
   152  		return ioErr
   153  	}
   154  
   155  	// read OK packet
   156  	if err == nil {
   157  		return mc.readResultOK()
   158  	} else {
   159  		mc.readPacket()
   160  	}
   161  	return err
   162  }