github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/utils.go (about) 1 package dbshield 2 3 import ( 4 "fmt" 5 "io" 6 "net" 7 "os" 8 "os/signal" 9 "strings" 10 "time" 11 12 "github.com/qiwihui/DBShield/dbshield/config" 13 "github.com/qiwihui/DBShield/dbshield/dbms" 14 "github.com/qiwihui/DBShield/dbshield/logger" 15 "github.com/qiwihui/DBShield/dbshield/utils" 16 ) 17 18 const ( 19 mysql = iota 20 mssql 21 postgres 22 db2 23 oracle 24 ) 25 26 func closeHandlers() { 27 28 //TODO NEED to verify 29 if config.Config.LocalDB != nil { 30 config.Config.LocalDB.UpdateState() 31 config.Config.LocalDB.SyncAndClose() 32 } 33 if logger.Output != nil { 34 logger.Output.Close() 35 } 36 } 37 38 //catching Interrupts 39 func signalHandler() { 40 term := make(chan os.Signal) 41 signal.Notify(term, os.Interrupt) 42 <-term 43 logger.Info("Shutting down...") 44 //Closing open handler politely 45 closeHandlers() 46 } 47 48 //initLogging redirect log output to file/stdout/stderr 49 func initLogging() { 50 err := logger.Init(config.Config.LogPath, config.Config.LogLevel) 51 if err != nil { 52 panic(err) 53 } 54 } 55 56 //maps database name to corresponding struct 57 func dbNameToStruct(db string) (d uint, err error) { 58 switch strings.ToLower(db) { 59 case "db2": 60 d = db2 61 case "mssql": 62 d = mssql 63 case "mysql", "mariadb": 64 d = mysql 65 case "oracle": 66 d = oracle 67 case "postgres": 68 d = postgres 69 default: 70 err = fmt.Errorf("Unknown DBMS: %s", db) 71 } 72 return 73 } 74 75 //generateDBMS instantiate a new instance of DBMS 76 func generateDBMS() (utils.DBMS, func(io.Reader) ([]byte, error)) { 77 switch config.Config.DB { 78 case mssql: 79 return new(dbms.MSSQL), dbms.MSSQLReadPacket 80 case mysql: 81 return new(dbms.MySQL), dbms.MySQLReadPacket 82 case postgres: 83 return new(dbms.Postgres), dbms.ReadPacket //TODO: implement explicit reader 84 case oracle: 85 return new(dbms.Oracle), dbms.ReadPacket //TODO: implement explicit reader 86 case db2: 87 return new(dbms.DB2), dbms.ReadPacket //TODO: implement explicit reader 88 default: 89 return nil, nil 90 } 91 } 92 93 func handleClient(listenConn net.Conn, serverAddr *net.TCPAddr) error { 94 d, reader := generateDBMS() 95 // delay 96 // tcpConn := listenConn.(*net.TCPConn) 97 // tcpConn.SetNoDelay(false) 98 // // tcpConn.SetKeepAlive(true) 99 // listenConn = tcpConn 100 101 logger.Debugf("Connected from: %s", listenConn.RemoteAddr()) 102 serverConn, err := net.DialTCP("tcp", nil, serverAddr) 103 if err != nil { 104 logger.Warning(err) 105 listenConn.Close() 106 return err 107 } 108 // serverConn.SetNoDelay(false) 109 // serverConn.SetKeepAlive(true) 110 // if err := SetConnTimeout(listenConn); err != nil { 111 // return err 112 // } 113 // if err := SetConnTimeout(serverConn); err != nil { 114 // return err 115 // } 116 117 if config.Config.Timeout > 0 { 118 tcpCli := listenConn.(*net.TCPConn) 119 tcpCli.SetNoDelay(false) 120 tcpCli.SetKeepAlive(true) 121 serverConn.SetNoDelay(false) 122 serverConn.SetKeepAlive(true) 123 } 124 125 logger.Debugf("Connected to: %s", serverConn.RemoteAddr()) 126 d.SetSockets(listenConn, serverConn) 127 d.SetCertificate(config.Config.TLSCertificate, config.Config.TLSPrivateKey) 128 d.SetReader(reader) 129 err = d.Handler() 130 if err != nil { 131 logger.Warning(err) 132 } 133 return err 134 } 135 136 // SetConnTimeout for connection 137 func SetConnTimeout(conn net.Conn) error { 138 if config.Config.Timeout > 0 { 139 if err := conn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil { 140 return err 141 } 142 } 143 return nil 144 }