github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/httpserver/server_test.go (about) 1 package httpserver 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "strings" 12 "testing" 13 "time" 14 15 "github.com/boltdb/bolt" 16 "github.com/nim4/DBShield/dbshield/config" 17 "github.com/nim4/DBShield/dbshield/sql" 18 "github.com/nim4/DBShield/dbshield/training" 19 ) 20 21 func TestMain(m *testing.M) { 22 config.Config.HTTPAddr = ":-1" 23 config.Config.HTTPPassword = "foo" 24 25 tmpfile, err := ioutil.TempFile("", "testdb") 26 if err != nil { 27 panic(err) 28 } 29 defer tmpfile.Close() 30 path := tmpfile.Name() 31 training.DBCon, err = bolt.Open(path, 0600, nil) 32 if err != nil { 33 panic(err) 34 } 35 training.DBCon.Update(func(tx *bolt.Tx) error { 36 tx.CreateBucket([]byte("pattern")) 37 tx.CreateBucket([]byte("abnormal")) 38 tx.CreateBucket([]byte("state")) 39 return nil 40 }) 41 m.Run() 42 } 43 44 func TestServe(t *testing.T) { 45 err := Serve() 46 if err == nil { 47 t.Error("Expected error") 48 } 49 config.Config.HTTPSSL = true 50 err = Serve() 51 if err == nil { 52 t.Error("Expected error") 53 } 54 } 55 56 func TestMainHandler(t *testing.T) { 57 r, err := http.NewRequest("GET", "/", nil) 58 if err != nil { 59 t.Error("Got an error ", err) 60 } 61 62 w := httptest.NewRecorder() 63 mainHandler(w, r) 64 65 if w.Code != 200 { 66 t.Error("Expected 200 got ", w.Code) 67 } 68 69 r, err = http.NewRequest("GET", "/", nil) 70 if err != nil { 71 t.Error("Got an error ", err) 72 } 73 w = httptest.NewRecorder() 74 setSession(w) 75 r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie")) 76 mainHandler(w, r) 77 body, err := ioutil.ReadAll(w.Body) 78 if err != nil { 79 t.Error("Got an error ", err) 80 } 81 if strings.Index(string(body), "\"Logout\"") == -1 { 82 t.Error("Expected report page") 83 } 84 } 85 86 func TestAPIHandler(t *testing.T) { 87 defer recover() 88 r, err := http.NewRequest("GET", "/", nil) 89 if err != nil { 90 t.Error("Got an error ", err) 91 } 92 w := httptest.NewRecorder() 93 apiHandler(w, r) 94 body, err := ioutil.ReadAll(w.Body) 95 if err != nil { 96 t.Error("Got an error ", err) 97 } 98 if len(body) != 0 { 99 t.Error("Expected 0 length got", len(body)) 100 } 101 setSession(w) 102 r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie")) 103 apiHandler(w, r) 104 body, _ = ioutil.ReadAll(w.Body) 105 var j struct { 106 Total int 107 Abnormal int 108 } 109 err = json.Unmarshal(body, &j) 110 if err != nil { 111 t.Error("Got an error ", err) 112 } 113 if j.Total != 0 || j.Abnormal != 0 { 114 t.Error("Expected 0, 0 got", j) 115 } 116 117 j.Total = 1 118 j.Abnormal = 1 119 120 c1 := sql.QueryContext{ 121 Query: []byte("select * from test;"), 122 Database: []byte("test"), 123 User: []byte("test"), 124 Client: []byte("127.0.0.1"), 125 Time: time.Now(), 126 } 127 c2 := sql.QueryContext{ 128 Query: []byte("select * from user;"), 129 Database: []byte("test"), 130 User: []byte("test"), 131 Client: []byte("127.0.0.1"), 132 Time: time.Now(), 133 } 134 err = training.AddToTrainingSet(c1) 135 if err != nil { 136 t.Error("Got an error ", err) 137 } 138 training.CheckQuery(c2) 139 apiHandler(w, r) 140 body, _ = ioutil.ReadAll(w.Body) 141 err = json.Unmarshal(body, &j) 142 fmt.Println(j) 143 if err != nil { 144 t.Error("Got an error ", err) 145 } 146 if j.Total != 2 || j.Abnormal != 1 { 147 t.Error("Expected 1, 1 got", j) 148 } 149 150 tmpCon := training.DBCon 151 defer func() { 152 training.DBCon = tmpCon 153 }() 154 tmpfile, err := ioutil.TempFile("", "testdb") 155 if err != nil { 156 panic(err) 157 } 158 defer tmpfile.Close() 159 path := tmpfile.Name() 160 training.DBCon, err = bolt.Open(path, 0600, nil) 161 if err != nil { 162 t.Error("Got an error ", err) 163 } 164 apiHandler(w, r) 165 body, err = ioutil.ReadAll(w.Body) 166 if err != nil { 167 t.Error("Got an error ", err) 168 } 169 err = json.Unmarshal(body, &j) 170 if err != nil { 171 t.Error("Got an error ", err) 172 } 173 174 apiHandler(w, r) 175 body, err = ioutil.ReadAll(w.Body) 176 if err != nil { 177 t.Error("Got an error ", err) 178 } 179 err = json.Unmarshal(body, &j) 180 if err != nil { 181 t.Error("Got an error ", err) 182 } 183 } 184 185 func TestLoginHandler(t *testing.T) { 186 data := url.Values{} 187 data.Add("password", "bar") 188 r, err := http.NewRequest("POST", "/", bytes.NewBufferString(data.Encode())) 189 if err != nil { 190 t.Error("Got an error ", err) 191 } 192 193 r.Header.Add("Content-Type", "application/x-www-form-urlencoded") 194 195 w := httptest.NewRecorder() 196 loginHandler(w, r) 197 198 if w.Code != 302 { 199 t.Error("Expected 302 got ", w.Code) 200 } 201 202 if w.HeaderMap.Get("Location") != "/" { 203 t.Error("Expected / got ", w.HeaderMap.Get("Location")) 204 } 205 206 data.Set("password", config.Config.HTTPPassword) 207 r, err = http.NewRequest("POST", "/", bytes.NewBufferString(data.Encode())) 208 if err != nil { 209 t.Error("Got an error ", err) 210 } 211 212 r.Header.Add("Content-Type", "application/x-www-form-urlencoded") 213 214 w = httptest.NewRecorder() 215 loginHandler(w, r) 216 217 if w.Code != 302 { 218 t.Error("Expected 302 got ", w.Code) 219 } 220 221 if w.HeaderMap.Get("Location") != "/report.htm" { 222 t.Error("Expected /report.htm got ", w.HeaderMap.Get("Location")) 223 } 224 } 225 226 func TestLogoutHandler(t *testing.T) { 227 r, err := http.NewRequest("POST", "/", nil) 228 if err != nil { 229 t.Error("Got an error ", err) 230 } 231 232 w := httptest.NewRecorder() 233 logoutHandler(w, r) 234 235 if w.Code != 302 { 236 t.Error("Expected 302 got ", w.Code) 237 } 238 } 239 240 func TestCheckLogin(t *testing.T) { 241 r, err := http.NewRequest("GET", "/", nil) 242 if err != nil { 243 t.Error("Got an error ", err) 244 } 245 246 cookie := &http.Cookie{ 247 Name: "session", 248 Value: "XYZ", 249 Path: "/", 250 Secure: true, 251 HttpOnly: true, 252 } 253 r.AddCookie(cookie) 254 if checkLogin(r) { 255 t.Error("Expected false got true") 256 } 257 258 r, err = http.NewRequest("GET", "/", nil) 259 if err != nil { 260 t.Error("Got an error ", err) 261 } 262 w := httptest.NewRecorder() 263 setSession(w) 264 r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie")) 265 266 if !checkLogin(r) { 267 t.Error("Expected true got false") 268 } 269 } 270 271 func TestSetSession(t *testing.T) { 272 w := httptest.NewRecorder() 273 setSession(w) 274 if strings.Index(w.HeaderMap.Get("Set-Cookie"), "session=") == -1 { 275 t.Error("Expected session cookie") 276 } 277 } 278 279 func TestClearSession(t *testing.T) { 280 w := httptest.NewRecorder() 281 clearSession(w) 282 if strings.Index(w.HeaderMap.Get("Set-Cookie"), "session=;") == -1 { 283 t.Error("Expected empty session cookie") 284 } 285 }