github.com/perlchild/DBShield@v0.0.0-20170924200059-c888d9e40e13/dbshield/utils.go (about)

     1  package dbshield
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"os"
     9  	"os/signal"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/boltdb/bolt"
    14  	"github.com/nim4/DBShield/dbshield/config"
    15  	"github.com/nim4/DBShield/dbshield/dbms"
    16  	"github.com/nim4/DBShield/dbshield/logger"
    17  	"github.com/nim4/DBShield/dbshield/training"
    18  	"github.com/nim4/DBShield/dbshield/utils"
    19  )
    20  
    21  const (
    22  	mysql = iota
    23  	mssql
    24  	postgres
    25  	db2
    26  	oracle
    27  )
    28  
    29  //initial boltdb database
    30  func initModel(path string) {
    31  	logger.Infof("Internal DB: %s", path)
    32  	if training.DBCon == nil {
    33  		training.DBCon, _ = bolt.Open(path, 0600, nil)
    34  		training.DBCon.Update(func(tx *bolt.Tx) error {
    35  			tx.CreateBucketIfNotExists([]byte("pattern"))
    36  			tx.CreateBucketIfNotExists([]byte("abnormal"))
    37  			b, _ := tx.CreateBucketIfNotExists([]byte("state"))
    38  			v := b.Get([]byte("QueryCounter"))
    39  			if v != nil {
    40  				training.QueryCounter = binary.BigEndian.Uint64(v)
    41  			}
    42  			v = b.Get([]byte("AbnormalCounter"))
    43  			if v != nil {
    44  				training.AbnormalCounter = binary.BigEndian.Uint64(v)
    45  			}
    46  			return nil
    47  		})
    48  	}
    49  
    50  	if config.Config.SyncInterval != 0 {
    51  		training.DBCon.NoSync = true
    52  		ticker := time.NewTicker(config.Config.SyncInterval)
    53  		go func() {
    54  			for range ticker.C {
    55  				training.DBCon.Sync()
    56  			}
    57  		}()
    58  	}
    59  }
    60  
    61  func closeHandlers() {
    62  	if training.DBCon != nil {
    63  		training.DBCon.Update(func(tx *bolt.Tx) error {
    64  			//Supplied value must remain valid for the life of the transaction
    65  			qCount := make([]byte, 8)
    66  			abCount := make([]byte, 8)
    67  
    68  			b := tx.Bucket([]byte("state"))
    69  			binary.BigEndian.PutUint64(qCount, training.QueryCounter)
    70  			b.Put([]byte("QueryCounter"), qCount)
    71  
    72  			binary.BigEndian.PutUint64(abCount, training.AbnormalCounter)
    73  			b.Put([]byte("AbnormalCounter"), abCount)
    74  
    75  			return nil
    76  		})
    77  		training.DBCon.Sync()
    78  		training.DBCon.Close()
    79  	}
    80  	if logger.Output != nil {
    81  		logger.Output.Close()
    82  	}
    83  }
    84  
    85  //catching Interrupts
    86  func signalHandler() {
    87  	term := make(chan os.Signal)
    88  	signal.Notify(term, os.Interrupt)
    89  	<-term
    90  	logger.Info("Shutting down...")
    91  	//Closing open handler politely
    92  	closeHandlers()
    93  }
    94  
    95  //initLogging redirect log output to file/stdout/stderr
    96  func initLogging() {
    97  	err := logger.Init(config.Config.LogPath, config.Config.LogLevel)
    98  	if err != nil {
    99  		panic(err)
   100  	}
   101  }
   102  
   103  //maps database name to corresponding struct
   104  func dbNameToStruct(db string) (d uint, err error) {
   105  	switch strings.ToLower(db) {
   106  	case "db2":
   107  		d = db2
   108  	case "mssql":
   109  		d = mssql
   110  	case "mysql", "mariadb":
   111  		d = mysql
   112  	case "oracle":
   113  		d = oracle
   114  	case "postgres":
   115  		d = postgres
   116  	default:
   117  		err = fmt.Errorf("Unknown DBMS: %s", db)
   118  	}
   119  	return
   120  }
   121  
   122  //generateDBMS instantiate a new instance of DBMS
   123  func generateDBMS() (utils.DBMS, func(io.Reader) ([]byte, error)) {
   124  	switch config.Config.DB {
   125  	case mssql:
   126  		return new(dbms.MSSQL), dbms.MSSQLReadPacket
   127  	case mysql:
   128  		return new(dbms.MySQL), dbms.MySQLReadPacket
   129  	case postgres:
   130  		return new(dbms.Postgres), dbms.ReadPacket //TODO: implement explicit reader
   131  	case oracle:
   132  		return new(dbms.Oracle), dbms.ReadPacket //TODO: implement explicit reader
   133  	case db2:
   134  		return new(dbms.DB2), dbms.ReadPacket //TODO: implement explicit reader
   135  	default:
   136  		return nil, nil
   137  	}
   138  }
   139  
   140  func handleClient(listenConn net.Conn, serverAddr *net.TCPAddr) error {
   141  	d, reader := generateDBMS()
   142  	logger.Debugf("Connected from: %s", listenConn.RemoteAddr())
   143  	serverConn, err := net.DialTCP("tcp", nil, serverAddr)
   144  	if err != nil {
   145  		logger.Warning(err)
   146  		listenConn.Close()
   147  		return err
   148  	}
   149  	if config.Config.Timeout > 0 {
   150  		if err = listenConn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil {
   151  			return err
   152  		}
   153  		if err = serverConn.SetDeadline(time.Now().Add(config.Config.Timeout)); err != nil {
   154  			return err
   155  		}
   156  	}
   157  	logger.Debugf("Connected to: %s", serverConn.RemoteAddr())
   158  	d.SetSockets(listenConn, serverConn)
   159  	d.SetCertificate(config.Config.TLSCertificate, config.Config.TLSPrivateKey)
   160  	d.SetReader(reader)
   161  	err = d.Handler()
   162  	if err != nil {
   163  		logger.Warning(err)
   164  	}
   165  	return err
   166  }