github.com/perlchild/DBShield@v0.0.0-20170924200059-c888d9e40e13/dbshield/dbshield_test.go (about) 1 // +build !windows 2 3 package dbshield 4 5 import ( 6 "fmt" 7 "io/ioutil" 8 "net" 9 "os" 10 "syscall" 11 "testing" 12 "time" 13 14 "github.com/boltdb/bolt" 15 "github.com/nim4/DBShield/dbshield/config" 16 "github.com/nim4/DBShield/dbshield/sql" 17 "github.com/nim4/DBShield/dbshield/training" 18 ) 19 20 func TestMain(m *testing.M) { 21 os.Chdir("../") 22 m.Run() 23 } 24 25 func TestSetConfigFile(t *testing.T) { 26 err := SetConfigFile("Invalid.yml") 27 if err == nil { 28 t.Error("Expected error") 29 } 30 } 31 32 func TestShowConfig(t *testing.T) { 33 SetConfigFile("conf/dbshield.yml") 34 err := ShowConfig() 35 if err != nil { 36 t.Error("Got error", err) 37 } 38 } 39 40 func TestPurge(t *testing.T) { 41 SetConfigFile("conf/dbshield.yml") 42 err := Purge() 43 if err == nil { 44 t.Error("Expected error") 45 } 46 } 47 48 func TestPostConfig(t *testing.T) { 49 SetConfigFile("conf/dbshield.yml") 50 config.Config.DBType = "Invalid" 51 err := postConfig() 52 if err == nil { 53 t.Error("Expected error") 54 } 55 56 config.Config.ListenPort = 0 57 config.Config.DBType = "mysql" 58 err = postConfig() 59 if err != nil { 60 t.Error("Expected nil got ", err) 61 } 62 } 63 64 func TestEveryThing(t *testing.T) { 65 closeHandlers() 66 SetConfigFile("conf/dbshield.yml") 67 //It should fail if port is already open 68 l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", config.Config.ListenIP, config.Config.ListenPort)) 69 if err != nil { 70 t.Fatal(err) 71 } 72 defer l.Close() 73 74 err = mainListner() 75 if err == nil { 76 t.Error("Expected error") 77 } 78 79 go func() { 80 timer := time.NewTimer(time.Second * 2) 81 <-timer.C 82 syscall.Kill(syscall.Getpid(), syscall.SIGINT) 83 }() 84 err = Start() 85 if err != nil { 86 t.Error("Got error", err) 87 } 88 file, _ := ioutil.TempFile(os.TempDir(), "tempDB") 89 defer os.Remove(file.Name()) 90 training.DBCon, _ = bolt.Open(file.Name(), 0600, nil) 91 training.DBCon.Update(func(tx *bolt.Tx) error { 92 tx.CreateBucket([]byte("pattern")) 93 tx.CreateBucket([]byte("abnormal")) 94 tx.CreateBucket([]byte("state")) 95 return nil 96 }) 97 query := []byte("select * from test;") 98 c := sql.QueryContext{ 99 Query: query, 100 Database: []byte("test"), 101 User: []byte("test"), 102 Client: []byte("127.0.0.1"), 103 Time: time.Now(), 104 } 105 training.CheckQuery(c) 106 err = training.AddToTrainingSet(c) 107 if err != nil { 108 t.Error("Got error", err) 109 } 110 pattern := sql.Pattern(query) 111 112 count := Patterns() 113 if count != 1 { 114 t.Error("Expected 1 got", count) 115 } 116 117 count = Abnormals() 118 if count != 0 { 119 t.Error("Expected 0 got", count) 120 } 121 122 err = RemovePattern(string(pattern)) 123 if err != nil { 124 t.Error("Expected nil got", err) 125 } 126 127 count = Patterns() 128 if count != 0 { 129 t.Error("Expected 0 got", count) 130 } 131 132 //Test without bucket 133 tmpCon := training.DBCon 134 defer func() { 135 training.DBCon = tmpCon 136 }() 137 tmpfile, err := ioutil.TempFile("", "testdb") 138 if err != nil { 139 panic(err) 140 } 141 defer tmpfile.Close() 142 path := tmpfile.Name() 143 training.DBCon, err = bolt.Open(path, 0600, nil) 144 if err != nil { 145 panic(err) 146 } 147 148 count = Patterns() 149 if count != 0 { 150 t.Error("Expected 0 got", count) 151 } 152 153 count = Abnormals() 154 if count != 0 { 155 t.Error("Expected 0 got", count) 156 } 157 158 err = RemovePattern(string(pattern)) 159 if err == nil { 160 t.Error("Expected error got", err) 161 } 162 }