github.com/perlchild/DBShield@v0.0.0-20170924200059-c888d9e40e13/dbshield/utils.go (about) 1 package dbshield 2 3 import ( 4 "encoding/binary" 5 "fmt" 6 "io" 7 "net" 8 "os" 9 "os/signal" 10 "strings" 11 "time" 12 13 "github.com/boltdb/bolt" 14 "github.com/nim4/DBShield/dbshield/config" 15 "github.com/nim4/DBShield/dbshield/dbms" 16 "github.com/nim4/DBShield/dbshield/logger" 17 "github.com/nim4/DBShield/dbshield/training" 18 "github.com/nim4/DBShield/dbshield/utils" 19 ) 20 21 const ( 22 mysql = iota 23 mssql 24 postgres 25 db2 26 oracle 27 ) 28 29 //initial boltdb database 30 func initModel(path string) { 31 logger.Infof("Internal DB: %s", path) 32 if training.DBCon == nil { 33 training.DBCon, _ = bolt.Open(path, 0600, nil) 34 training.DBCon.Update(func(tx *bolt.Tx) error { 35 tx.CreateBucketIfNotExists([]byte("pattern")) 36 tx.CreateBucketIfNotExists([]byte("abnormal")) 37 b, _ := tx.CreateBucketIfNotExists([]byte("state")) 38 v := b.Get([]byte("QueryCounter")) 39 if v != nil { 40 training.QueryCounter = binary.BigEndian.Uint64(v) 41 } 42 v = b.Get([]byte("AbnormalCounter")) 43 if v != nil { 44 training.AbnormalCounter = binary.BigEndian.Uint64(v) 45 } 46 return nil 47 }) 48 } 49 50 if config.Config.SyncInterval != 0 { 51 training.DBCon.NoSync = true 52 ticker := time.NewTicker(config.Config.SyncInterval) 53 go func() { 54 for range ticker.C { 55 training.DBCon.Sync() 56 } 57 }() 58 } 59 } 60 61 func closeHandlers() { 62 if training.DBCon != nil { 63 training.DBCon.Update(func(tx *bolt.Tx) error { 64 //Supplied value must remain valid for the life of the transaction 65 qCount := make([]byte, 8) 66 abCount := make([]byte, 8) 67 68 b := tx.Bucket([]byte("state")) 69 binary.BigEndian.PutUint64(qCount, training.QueryCounter) 70 b.Put([]byte("QueryCounter"), qCount) 71 72 binary.BigEndian.PutUint64(abCount, training.AbnormalCounter) 73 b.Put([]byte("AbnormalCounter"), abCount) 74 75 return nil 76 }) 77 training.DBCon.Sync() 78 training.DBCon.Close() 79 } 80 if logger.Output != nil { 81 logger.Output.Close() 82 } 83 } 84 85 //catching Interrupts 86 func signalHandler() { 87 term := make(chan os.Signal) 88 signal.Notify(term, os.Interrupt) 89 <-term 90 logger.Info("Shutting down...") 91 //Closing open handler politely 92 closeHandlers() 93 } 94 95 //initLogging redirect log output to file/stdout/stderr 96 func initLogging() { 97 err := logger.Init(config.Config.LogPath, config.Config.LogLevel) 98 if err != nil { 99 panic(err) 100 } 101 } 102 103 //maps database name to corresponding struct 104 func dbNameToStruct(db string) (d uint, err error) { 105 switch strings.ToLower(db) { 106 case "db2": 107 d = db2 108 case "mssql": 109 d = mssql 110 case "mysql", "mariadb": 111 d = mysql 112 case "oracle": 113 d = oracle 114 case "postgres": 115 d = postgres 116 default: 117 err = fmt.Errorf("Unknown DBMS: %s", db) 118 } 119 return 120 } 121 122 //generateDBMS instantiate a new instance of DBMS 123 func generateDBMS() (utils.DBMS, func(io.Reader) ([]byte, error)) { 124 switch config.Config.DB { 125 case mssql: 126 return new(dbms.MSSQL), dbms.MSSQLReadPacket 127 case mysql: 128 return new(dbms.MySQL), dbms.MySQLReadPacket 129 case postgres: 130 return new(dbms.Postgres), dbms.ReadPacket //TODO: implement explicit reader 131 case oracle: 132 return new(dbms.Oracle), dbms.ReadPacket //TODO: implement explicit reader 133 case db2: 134 return new(dbms.DB2), dbms.ReadPacket //TODO: implement explicit reader 135 default: 136 return nil, nil 137 } 138 } 139 140 func handleClient(listenConn net.Conn, serverAddr *net.TCPAddr) error { 141 d, reader := generateDBMS() 142 logger.Debugf("Connected from: %s", listenConn.RemoteAddr()) 143 serverConn, err := net.DialTCP("tcp", nil, serverAddr) 144 if err != nil { 145 logger.Warning(err) 146 listenConn.Close() 147 return err 148 } 149 if config.Config.Timeout > 0 { 150 if err = listenConn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil { 151 return err 152 } 153 if err = serverConn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil { 154 return err 155 } 156 } 157 logger.Debugf("Connected to: %s", serverConn.RemoteAddr()) 158 d.SetSockets(listenConn, serverConn) 159 d.SetCertificate(config.Config.TLSCertificate, config.Config.TLSPrivateKey) 160 d.SetReader(reader) 161 err = d.Handler() 162 if err != nil { 163 logger.Warning(err) 164 } 165 return err 166 }