github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/dbshield_test.go (about)

     1  // +build !windows
     2  
     3  package dbshield
     4  
     5  import (
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net"
     9  	"os"
    10  	"syscall"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/boltdb/bolt"
    15  	"github.com/nim4/DBShield/dbshield/config"
    16  	"github.com/nim4/DBShield/dbshield/sql"
    17  	"github.com/nim4/DBShield/dbshield/training"
    18  )
    19  
    20  func TestMain(m *testing.M) {
    21  	os.Chdir("../")
    22  	m.Run()
    23  }
    24  
    25  func TestSetConfigFile(t *testing.T) {
    26  	err := SetConfigFile("Invalid.yml")
    27  	if err == nil {
    28  		t.Error("Expected error")
    29  	}
    30  }
    31  
    32  func TestShowConfig(t *testing.T) {
    33  	SetConfigFile("conf/dbshield.yml")
    34  	err := ShowConfig()
    35  	if err != nil {
    36  		t.Error("Got error", err)
    37  	}
    38  }
    39  
    40  func TestPurge(t *testing.T) {
    41  	SetConfigFile("conf/dbshield.yml")
    42  	err := Purge()
    43  	if err == nil {
    44  		t.Error("Expected error")
    45  	}
    46  }
    47  
    48  func TestPostConfig(t *testing.T) {
    49  	SetConfigFile("conf/dbshield.yml")
    50  	config.Config.DBType = "Invalid"
    51  	err := postConfig()
    52  	if err == nil {
    53  		t.Error("Expected error")
    54  	}
    55  
    56  	config.Config.ListenPort = 0
    57  	config.Config.DBType = "mysql"
    58  	err = postConfig()
    59  	if err != nil {
    60  		t.Error("Expected nil got ", err)
    61  	}
    62  }
    63  
    64  func TestEveryThing(t *testing.T) {
    65  	closeHandlers()
    66  	SetConfigFile("conf/dbshield.yml")
    67  	//It should fail if port is already open
    68  	l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", config.Config.ListenIP, config.Config.ListenPort))
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  	defer l.Close()
    73  
    74  	err = mainListner()
    75  	if err == nil {
    76  		t.Error("Expected error")
    77  	}
    78  
    79  	go func() {
    80  		timer := time.NewTimer(time.Second * 2)
    81  		<-timer.C
    82  		syscall.Kill(syscall.Getpid(), syscall.SIGINT)
    83  	}()
    84  	err = Start()
    85  	if err != nil {
    86  		t.Error("Got error", err)
    87  	}
    88  	file, _ := ioutil.TempFile(os.TempDir(), "tempDB")
    89  	defer os.Remove(file.Name())
    90  	training.DBCon, _ = bolt.Open(file.Name(), 0600, nil)
    91  	training.DBCon.Update(func(tx *bolt.Tx) error {
    92  		tx.CreateBucket([]byte("pattern"))
    93  		tx.CreateBucket([]byte("abnormal"))
    94  		tx.CreateBucket([]byte("state"))
    95  		return nil
    96  	})
    97  	query := []byte("select * from test;")
    98  	c := sql.QueryContext{
    99  		Query:    query,
   100  		Database: []byte("test"),
   101  		User:     []byte("test"),
   102  		Client:   []byte("127.0.0.1"),
   103  		Time:     time.Now(),
   104  	}
   105  	training.CheckQuery(c)
   106  	err = training.AddToTrainingSet(c)
   107  	if err != nil {
   108  		t.Error("Got error", err)
   109  	}
   110  	pattern := sql.Pattern(query)
   111  
   112  	count := Patterns()
   113  	if count != 1 {
   114  		t.Error("Expected 1 got", count)
   115  	}
   116  
   117  	count = Abnormals()
   118  	if count != 0 {
   119  		t.Error("Expected 0 got", count)
   120  	}
   121  
   122  	err = RemovePattern(string(pattern))
   123  	if err != nil {
   124  		t.Error("Expected nil got", err)
   125  	}
   126  
   127  	count = Patterns()
   128  	if count != 0 {
   129  		t.Error("Expected 0 got", count)
   130  	}
   131  
   132  	//Test without bucket
   133  	tmpCon := training.DBCon
   134  	defer func() {
   135  		training.DBCon = tmpCon
   136  	}()
   137  	tmpfile, err := ioutil.TempFile("", "testdb")
   138  	if err != nil {
   139  		panic(err)
   140  	}
   141  	defer tmpfile.Close()
   142  	path := tmpfile.Name()
   143  	training.DBCon, err = bolt.Open(path, 0600, nil)
   144  	if err != nil {
   145  		panic(err)
   146  	}
   147  
   148  	count = Patterns()
   149  	if count != 0 {
   150  		t.Error("Expected 0 got", count)
   151  	}
   152  
   153  	count = Abnormals()
   154  	if count != 0 {
   155  		t.Error("Expected 0 got", count)
   156  	}
   157  
   158  	err = RemovePattern(string(pattern))
   159  	if err == nil {
   160  		t.Error("Expected error got", err)
   161  	}
   162  }