github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-sql-driver/mysql/infile.go (about)

     1  // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
     2  //
     3  // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
     4  //
     5  // This Source Code Form is subject to the terms of the Mozilla Public
     6  // License, v. 2.0. If a copy of the MPL was not distributed with this file,
     7  // You can obtain one at http://mozilla.org/MPL/2.0/.
     8  
     9  package mysql
    10  
    11  import (
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strings"
    16  	"sync"
    17  )
    18  
    19  var (
    20  	fileRegister       map[string]bool
    21  	fileRegisterLock   sync.RWMutex
    22  	readerRegister     map[string]func() io.Reader
    23  	readerRegisterLock sync.RWMutex
    24  )
    25  
    26  // RegisterLocalFile adds the given file to the file whitelist,
    27  // so that it can be used by "LOAD DATA LOCAL INFILE <filepath>".
    28  // Alternatively you can allow the use of all local files with
    29  // the DSN parameter 'allowAllFiles=true'
    30  //
    31  //  filePath := "/home/gopher/data.csv"
    32  //  mysql.RegisterLocalFile(filePath)
    33  //  err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo")
    34  //  if err != nil {
    35  //  ...
    36  //
    37  func RegisterLocalFile(filePath string) {
    38  	fileRegisterLock.Lock()
    39  	// lazy map init
    40  	if fileRegister == nil {
    41  		fileRegister = make(map[string]bool)
    42  	}
    43  
    44  	fileRegister[strings.Trim(filePath, `"`)] = true
    45  	fileRegisterLock.Unlock()
    46  }
    47  
    48  // DeregisterLocalFile removes the given filepath from the whitelist.
    49  func DeregisterLocalFile(filePath string) {
    50  	fileRegisterLock.Lock()
    51  	delete(fileRegister, strings.Trim(filePath, `"`))
    52  	fileRegisterLock.Unlock()
    53  }
    54  
    55  // RegisterReaderHandler registers a handler function which is used
    56  // to receive a io.Reader.
    57  // The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>".
    58  // If the handler returns a io.ReadCloser Close() is called when the
    59  // request is finished.
    60  //
    61  //  mysql.RegisterReaderHandler("data", func() io.Reader {
    62  //  	var csvReader io.Reader // Some Reader that returns CSV data
    63  //  	... // Open Reader here
    64  //  	return csvReader
    65  //  })
    66  //  err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo")
    67  //  if err != nil {
    68  //  ...
    69  //
    70  func RegisterReaderHandler(name string, handler func() io.Reader) {
    71  	readerRegisterLock.Lock()
    72  	// lazy map init
    73  	if readerRegister == nil {
    74  		readerRegister = make(map[string]func() io.Reader)
    75  	}
    76  
    77  	readerRegister[name] = handler
    78  	readerRegisterLock.Unlock()
    79  }
    80  
    81  // DeregisterReaderHandler removes the ReaderHandler function with
    82  // the given name from the registry.
    83  func DeregisterReaderHandler(name string) {
    84  	readerRegisterLock.Lock()
    85  	delete(readerRegister, name)
    86  	readerRegisterLock.Unlock()
    87  }
    88  
    89  func deferredClose(err *error, closer io.Closer) {
    90  	closeErr := closer.Close()
    91  	if *err == nil {
    92  		*err = closeErr
    93  	}
    94  }
    95  
    96  func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
    97  	var rdr io.Reader
    98  	var data []byte
    99  	packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
   100  	if mc.maxWriteSize < packetSize {
   101  		packetSize = mc.maxWriteSize
   102  	}
   103  
   104  	if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
   105  		// The server might return an an absolute path. See issue #355.
   106  		name = name[idx+8:]
   107  
   108  		readerRegisterLock.RLock()
   109  		handler, inMap := readerRegister[name]
   110  		readerRegisterLock.RUnlock()
   111  
   112  		if inMap {
   113  			rdr = handler()
   114  			if rdr != nil {
   115  				if cl, ok := rdr.(io.Closer); ok {
   116  					defer deferredClose(&err, cl)
   117  				}
   118  			} else {
   119  				err = fmt.Errorf("Reader '%s' is <nil>", name)
   120  			}
   121  		} else {
   122  			err = fmt.Errorf("Reader '%s' is not registered", name)
   123  		}
   124  	} else { // File
   125  		name = strings.Trim(name, `"`)
   126  		fileRegisterLock.RLock()
   127  		fr := fileRegister[name]
   128  		fileRegisterLock.RUnlock()
   129  		if mc.cfg.AllowAllFiles || fr {
   130  			var file *os.File
   131  			var fi os.FileInfo
   132  
   133  			if file, err = os.Open(name); err == nil {
   134  				defer deferredClose(&err, file)
   135  
   136  				// get file size
   137  				if fi, err = file.Stat(); err == nil {
   138  					rdr = file
   139  					if fileSize := int(fi.Size()); fileSize < packetSize {
   140  						packetSize = fileSize
   141  					}
   142  				}
   143  			}
   144  		} else {
   145  			err = fmt.Errorf("local file '%s' is not registered", name)
   146  		}
   147  	}
   148  
   149  	// send content packets
   150  	// if packetSize == 0, the Reader contains no data
   151  	if err == nil && packetSize > 0 {
   152  		data := make([]byte, 4+packetSize)
   153  		var n int
   154  		for err == nil {
   155  			n, err = rdr.Read(data[4:])
   156  			if n > 0 {
   157  				if ioErr := mc.writePacket(data[:4+n]); ioErr != nil {
   158  					return ioErr
   159  				}
   160  			}
   161  		}
   162  		if err == io.EOF {
   163  			err = nil
   164  		}
   165  	}
   166  
   167  	// send empty packet (termination)
   168  	if data == nil {
   169  		data = make([]byte, 4)
   170  	}
   171  	if ioErr := mc.writePacket(data[:4]); ioErr != nil {
   172  		return ioErr
   173  	}
   174  
   175  	// read OK packet
   176  	if err == nil {
   177  		return mc.readResultOK()
   178  	}
   179  
   180  	mc.readPacket()
   181  	return err
   182  }