github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/utils.go (about)

     1  package dbshield
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"os"
     8  	"os/signal"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/qiwihui/DBShield/dbshield/config"
    13  	"github.com/qiwihui/DBShield/dbshield/dbms"
    14  	"github.com/qiwihui/DBShield/dbshield/logger"
    15  	"github.com/qiwihui/DBShield/dbshield/utils"
    16  )
    17  
    18  const (
    19  	mysql = iota
    20  	mssql
    21  	postgres
    22  	db2
    23  	oracle
    24  )
    25  
    26  func closeHandlers() {
    27  
    28  	//TODO NEED to verify
    29  	if config.Config.LocalDB != nil {
    30  		config.Config.LocalDB.UpdateState()
    31  		config.Config.LocalDB.SyncAndClose()
    32  	}
    33  	if logger.Output != nil {
    34  		logger.Output.Close()
    35  	}
    36  }
    37  
    38  //catching Interrupts
    39  func signalHandler() {
    40  	term := make(chan os.Signal)
    41  	signal.Notify(term, os.Interrupt)
    42  	<-term
    43  	logger.Info("Shutting down...")
    44  	//Closing open handler politely
    45  	closeHandlers()
    46  }
    47  
    48  //initLogging redirect log output to file/stdout/stderr
    49  func initLogging() {
    50  	err := logger.Init(config.Config.LogPath, config.Config.LogLevel)
    51  	if err != nil {
    52  		panic(err)
    53  	}
    54  }
    55  
    56  //maps database name to corresponding struct
    57  func dbNameToStruct(db string) (d uint, err error) {
    58  	switch strings.ToLower(db) {
    59  	case "db2":
    60  		d = db2
    61  	case "mssql":
    62  		d = mssql
    63  	case "mysql", "mariadb":
    64  		d = mysql
    65  	case "oracle":
    66  		d = oracle
    67  	case "postgres":
    68  		d = postgres
    69  	default:
    70  		err = fmt.Errorf("Unknown DBMS: %s", db)
    71  	}
    72  	return
    73  }
    74  
    75  //generateDBMS instantiate a new instance of DBMS
    76  func generateDBMS() (utils.DBMS, func(io.Reader) ([]byte, error)) {
    77  	switch config.Config.DB {
    78  	case mssql:
    79  		return new(dbms.MSSQL), dbms.MSSQLReadPacket
    80  	case mysql:
    81  		return new(dbms.MySQL), dbms.MySQLReadPacket
    82  	case postgres:
    83  		return new(dbms.Postgres), dbms.ReadPacket //TODO: implement explicit reader
    84  	case oracle:
    85  		return new(dbms.Oracle), dbms.ReadPacket //TODO: implement explicit reader
    86  	case db2:
    87  		return new(dbms.DB2), dbms.ReadPacket //TODO: implement explicit reader
    88  	default:
    89  		return nil, nil
    90  	}
    91  }
    92  
    93  func handleClient(listenConn net.Conn, serverAddr *net.TCPAddr) error {
    94  	d, reader := generateDBMS()
    95  	// delay
    96  	// tcpConn := listenConn.(*net.TCPConn)
    97  	// tcpConn.SetNoDelay(false)
    98  	// // tcpConn.SetKeepAlive(true)
    99  	// listenConn = tcpConn
   100  
   101  	logger.Debugf("Connected from: %s", listenConn.RemoteAddr())
   102  	serverConn, err := net.DialTCP("tcp", nil, serverAddr)
   103  	if err != nil {
   104  		logger.Warning(err)
   105  		listenConn.Close()
   106  		return err
   107  	}
   108  	// serverConn.SetNoDelay(false)
   109  	// serverConn.SetKeepAlive(true)
   110  	// if err := SetConnTimeout(listenConn); err != nil {
   111  	// 	return err
   112  	// }
   113  	// if err := SetConnTimeout(serverConn); err != nil {
   114  	// 	return err
   115  	// }
   116  
   117  	if config.Config.Timeout > 0 {
   118  		tcpCli := listenConn.(*net.TCPConn)
   119  		tcpCli.SetNoDelay(false)
   120  		tcpCli.SetKeepAlive(true)
   121  		serverConn.SetNoDelay(false)
   122  		serverConn.SetKeepAlive(true)
   123  	}
   124  
   125  	logger.Debugf("Connected to: %s", serverConn.RemoteAddr())
   126  	d.SetSockets(listenConn, serverConn)
   127  	d.SetCertificate(config.Config.TLSCertificate, config.Config.TLSPrivateKey)
   128  	d.SetReader(reader)
   129  	err = d.Handler()
   130  	if err != nil {
   131  		logger.Warning(err)
   132  	}
   133  	return err
   134  }
   135  
   136  // SetConnTimeout for connection
   137  func SetConnTimeout(conn net.Conn) error {
   138  	if config.Config.Timeout > 0 {
   139  		if err := conn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil {
   140  			return err
   141  		}
   142  	}
   143  	return nil
   144  }