github.com/perlchild/DBShield@v0.0.0-20170924200059-c888d9e40e13/dbshield/dbshield.go (about)

     1  /*
     2  Package dbshield implements the database firewall functionality
     3  */
     4  package dbshield
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"path"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/boltdb/bolt"
    17  	"github.com/nim4/DBShield/dbshield/config"
    18  	"github.com/nim4/DBShield/dbshield/httpserver"
    19  	"github.com/nim4/DBShield/dbshield/logger"
    20  	"github.com/nim4/DBShield/dbshield/sql"
    21  	"github.com/nim4/DBShield/dbshield/training"
    22  )
    23  
    24  //Version of the library
    25  var Version = "1.0.0-beta4"
    26  
    27  var configFile string
    28  
    29  //SetConfigFile of DBShield
    30  func SetConfigFile(cf string) error {
    31  	configFile = cf
    32  	err := config.ParseConfig(configFile)
    33  	if err != nil {
    34  		return err
    35  	}
    36  	return postConfig()
    37  }
    38  
    39  //ShowConfig writes parsed config file as JSON to STDUT
    40  func ShowConfig() error {
    41  	confJSON, err := json.MarshalIndent(config.Config, "", "    ")
    42  	fmt.Println(string(confJSON))
    43  	return err
    44  }
    45  
    46  //Purge local database
    47  func Purge() error {
    48  	return os.Remove(path.Join(config.Config.DBDir,
    49  		config.Config.TargetIP+"_"+config.Config.DBType) + ".db")
    50  }
    51  
    52  //Patterns lists the captured patterns
    53  func Patterns() (count int) {
    54  	initModel(
    55  		path.Join(config.Config.DBDir,
    56  			config.Config.TargetIP+"_"+config.Config.DBType) + ".db")
    57  
    58  	training.DBCon.View(func(tx *bolt.Tx) error {
    59  		b := tx.Bucket([]byte("pattern"))
    60  		if b != nil {
    61  			return b.ForEach(func(k, v []byte) error {
    62  				if strings.Index(string(k), "_client_") == -1 && strings.Index(string(k), "_user_") == -1 {
    63  					fmt.Printf(
    64  						`-----Pattern: 0x%x
    65  Sample: %s
    66  `,
    67  						k,
    68  						v,
    69  					)
    70  					count++
    71  				}
    72  				return nil
    73  			})
    74  		}
    75  		return nil
    76  	})
    77  	return
    78  }
    79  
    80  //Abnormals detected querties
    81  func Abnormals() (count int) {
    82  	initModel(
    83  		path.Join(config.Config.DBDir,
    84  			config.Config.TargetIP+"_"+config.Config.DBType) + ".db")
    85  
    86  	training.DBCon.View(func(tx *bolt.Tx) error {
    87  		b := tx.Bucket([]byte("abnormal"))
    88  		if b != nil {
    89  			return b.ForEach(func(k, v []byte) error {
    90  				var c sql.QueryContext
    91  				c.Unmarshal(v)
    92  				fmt.Printf("[%s] [User: %s] [Database: %s] %s\n",
    93  					c.Time.Format(time.RFC1123),
    94  					c.User,
    95  					c.Database,
    96  					c.Query)
    97  				count++
    98  				return nil
    99  			})
   100  		}
   101  		return nil
   102  	})
   103  	return count
   104  }
   105  
   106  //RemovePattern deletes a pattern from captured patterns DB
   107  func RemovePattern(pattern string) error {
   108  	initModel(
   109  		path.Join(config.Config.DBDir,
   110  			config.Config.TargetIP+"_"+config.Config.DBType) + ".db")
   111  
   112  	return training.DBCon.Update(func(tx *bolt.Tx) error {
   113  		b := tx.Bucket([]byte("pattern"))
   114  		if b != nil {
   115  			return b.Delete([]byte(pattern))
   116  		}
   117  		return nil
   118  	})
   119  }
   120  
   121  func postConfig() (err error) {
   122  
   123  	config.Config.DB, err = dbNameToStruct(config.Config.DBType)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	tmpDBMS, _ := generateDBMS()
   129  	if config.Config.ListenPort == 0 {
   130  		config.Config.ListenPort = tmpDBMS.DefaultPort()
   131  	}
   132  	if config.Config.TargetPort == 0 {
   133  		config.Config.TargetPort = tmpDBMS.DefaultPort()
   134  	}
   135  	return
   136  }
   137  
   138  func mainListner() error {
   139  	if config.Config.HTTP {
   140  		proto := "http"
   141  		if config.Config.HTTPSSL {
   142  			proto = "https"
   143  		}
   144  		logger.Infof("Web interface on %s://%s/", proto, config.Config.HTTPAddr)
   145  		go httpserver.Serve()
   146  	}
   147  	serverAddr, _ := net.ResolveTCPAddr("tcp", config.Config.TargetIP+":"+strconv.Itoa(int(config.Config.TargetPort)))
   148  	l, err := net.Listen("tcp", config.Config.ListenIP+":"+strconv.Itoa(int(config.Config.ListenPort)))
   149  	if err != nil {
   150  		return err
   151  	}
   152  	// Close the listener when the application closes.
   153  	defer l.Close()
   154  
   155  	for {
   156  		// Listen for an incoming connection.
   157  		listenConn, err := l.Accept()
   158  		if err != nil {
   159  			logger.Warningf("Error accepting connection: %v", err)
   160  			continue
   161  		}
   162  		go handleClient(listenConn, serverAddr)
   163  	}
   164  }
   165  
   166  //Start the proxy
   167  func Start() (err error) {
   168  	initModel(
   169  		path.Join(config.Config.DBDir,
   170  			config.Config.TargetIP+"_"+config.Config.DBType) + ".db")
   171  
   172  	initLogging()
   173  	logger.Infof("Config file: %s", configFile)
   174  	logger.Infof("Listening: %s:%v",
   175  		config.Config.ListenIP,
   176  		config.Config.ListenPort)
   177  	logger.Infof("Backend: %s (%s:%v)",
   178  		config.Config.DBType,
   179  		config.Config.TargetIP,
   180  		config.Config.TargetPort)
   181  	logger.Infof("Protect: %v", !config.Config.Learning)
   182  	go mainListner()
   183  	signalHandler()
   184  	return nil
   185  }