github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/httpserver/server_test.go (about)

     1  package httpserver
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/boltdb/bolt"
    16  	"github.com/nim4/DBShield/dbshield/config"
    17  	"github.com/nim4/DBShield/dbshield/sql"
    18  	"github.com/nim4/DBShield/dbshield/training"
    19  )
    20  
    21  func TestMain(m *testing.M) {
    22  	config.Config.HTTPAddr = ":-1"
    23  	config.Config.HTTPPassword = "foo"
    24  
    25  	tmpfile, err := ioutil.TempFile("", "testdb")
    26  	if err != nil {
    27  		panic(err)
    28  	}
    29  	defer tmpfile.Close()
    30  	path := tmpfile.Name()
    31  	training.DBCon, err = bolt.Open(path, 0600, nil)
    32  	if err != nil {
    33  		panic(err)
    34  	}
    35  	training.DBCon.Update(func(tx *bolt.Tx) error {
    36  		tx.CreateBucket([]byte("pattern"))
    37  		tx.CreateBucket([]byte("abnormal"))
    38  		tx.CreateBucket([]byte("state"))
    39  		return nil
    40  	})
    41  	m.Run()
    42  }
    43  
    44  func TestServe(t *testing.T) {
    45  	err := Serve()
    46  	if err == nil {
    47  		t.Error("Expected error")
    48  	}
    49  	config.Config.HTTPSSL = true
    50  	err = Serve()
    51  	if err == nil {
    52  		t.Error("Expected error")
    53  	}
    54  }
    55  
    56  func TestMainHandler(t *testing.T) {
    57  	r, err := http.NewRequest("GET", "/", nil)
    58  	if err != nil {
    59  		t.Error("Got an error ", err)
    60  	}
    61  
    62  	w := httptest.NewRecorder()
    63  	mainHandler(w, r)
    64  
    65  	if w.Code != 200 {
    66  		t.Error("Expected 200 got ", w.Code)
    67  	}
    68  
    69  	r, err = http.NewRequest("GET", "/", nil)
    70  	if err != nil {
    71  		t.Error("Got an error ", err)
    72  	}
    73  	w = httptest.NewRecorder()
    74  	setSession(w)
    75  	r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie"))
    76  	mainHandler(w, r)
    77  	body, err := ioutil.ReadAll(w.Body)
    78  	if err != nil {
    79  		t.Error("Got an error ", err)
    80  	}
    81  	if strings.Index(string(body), "\"Logout\"") == -1 {
    82  		t.Error("Expected report page")
    83  	}
    84  }
    85  
    86  func TestAPIHandler(t *testing.T) {
    87  	defer recover()
    88  	r, err := http.NewRequest("GET", "/", nil)
    89  	if err != nil {
    90  		t.Error("Got an error ", err)
    91  	}
    92  	w := httptest.NewRecorder()
    93  	apiHandler(w, r)
    94  	body, err := ioutil.ReadAll(w.Body)
    95  	if err != nil {
    96  		t.Error("Got an error ", err)
    97  	}
    98  	if len(body) != 0 {
    99  		t.Error("Expected 0 length got", len(body))
   100  	}
   101  	setSession(w)
   102  	r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie"))
   103  	apiHandler(w, r)
   104  	body, _ = ioutil.ReadAll(w.Body)
   105  	var j struct {
   106  		Total    int
   107  		Abnormal int
   108  	}
   109  	err = json.Unmarshal(body, &j)
   110  	if err != nil {
   111  		t.Error("Got an error ", err)
   112  	}
   113  	if j.Total != 0 || j.Abnormal != 0 {
   114  		t.Error("Expected 0, 0 got", j)
   115  	}
   116  
   117  	j.Total = 1
   118  	j.Abnormal = 1
   119  
   120  	c1 := sql.QueryContext{
   121  		Query:    []byte("select * from test;"),
   122  		Database: []byte("test"),
   123  		User:     []byte("test"),
   124  		Client:   []byte("127.0.0.1"),
   125  		Time:     time.Now(),
   126  	}
   127  	c2 := sql.QueryContext{
   128  		Query:    []byte("select * from user;"),
   129  		Database: []byte("test"),
   130  		User:     []byte("test"),
   131  		Client:   []byte("127.0.0.1"),
   132  		Time:     time.Now(),
   133  	}
   134  	err = training.AddToTrainingSet(c1)
   135  	if err != nil {
   136  		t.Error("Got an error ", err)
   137  	}
   138  	training.CheckQuery(c2)
   139  	apiHandler(w, r)
   140  	body, _ = ioutil.ReadAll(w.Body)
   141  	err = json.Unmarshal(body, &j)
   142  	fmt.Println(j)
   143  	if err != nil {
   144  		t.Error("Got an error ", err)
   145  	}
   146  	if j.Total != 2 || j.Abnormal != 1 {
   147  		t.Error("Expected 1, 1 got", j)
   148  	}
   149  
   150  	tmpCon := training.DBCon
   151  	defer func() {
   152  		training.DBCon = tmpCon
   153  	}()
   154  	tmpfile, err := ioutil.TempFile("", "testdb")
   155  	if err != nil {
   156  		panic(err)
   157  	}
   158  	defer tmpfile.Close()
   159  	path := tmpfile.Name()
   160  	training.DBCon, err = bolt.Open(path, 0600, nil)
   161  	if err != nil {
   162  		t.Error("Got an error ", err)
   163  	}
   164  	apiHandler(w, r)
   165  	body, err = ioutil.ReadAll(w.Body)
   166  	if err != nil {
   167  		t.Error("Got an error ", err)
   168  	}
   169  	err = json.Unmarshal(body, &j)
   170  	if err != nil {
   171  		t.Error("Got an error ", err)
   172  	}
   173  
   174  	apiHandler(w, r)
   175  	body, err = ioutil.ReadAll(w.Body)
   176  	if err != nil {
   177  		t.Error("Got an error ", err)
   178  	}
   179  	err = json.Unmarshal(body, &j)
   180  	if err != nil {
   181  		t.Error("Got an error ", err)
   182  	}
   183  }
   184  
   185  func TestLoginHandler(t *testing.T) {
   186  	data := url.Values{}
   187  	data.Add("password", "bar")
   188  	r, err := http.NewRequest("POST", "/", bytes.NewBufferString(data.Encode()))
   189  	if err != nil {
   190  		t.Error("Got an error ", err)
   191  	}
   192  
   193  	r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   194  
   195  	w := httptest.NewRecorder()
   196  	loginHandler(w, r)
   197  
   198  	if w.Code != 302 {
   199  		t.Error("Expected 302 got ", w.Code)
   200  	}
   201  
   202  	if w.HeaderMap.Get("Location") != "/" {
   203  		t.Error("Expected / got ", w.HeaderMap.Get("Location"))
   204  	}
   205  
   206  	data.Set("password", config.Config.HTTPPassword)
   207  	r, err = http.NewRequest("POST", "/", bytes.NewBufferString(data.Encode()))
   208  	if err != nil {
   209  		t.Error("Got an error ", err)
   210  	}
   211  
   212  	r.Header.Add("Content-Type", "application/x-www-form-urlencoded")
   213  
   214  	w = httptest.NewRecorder()
   215  	loginHandler(w, r)
   216  
   217  	if w.Code != 302 {
   218  		t.Error("Expected 302 got ", w.Code)
   219  	}
   220  
   221  	if w.HeaderMap.Get("Location") != "/report.htm" {
   222  		t.Error("Expected /report.htm got ", w.HeaderMap.Get("Location"))
   223  	}
   224  }
   225  
   226  func TestLogoutHandler(t *testing.T) {
   227  	r, err := http.NewRequest("POST", "/", nil)
   228  	if err != nil {
   229  		t.Error("Got an error ", err)
   230  	}
   231  
   232  	w := httptest.NewRecorder()
   233  	logoutHandler(w, r)
   234  
   235  	if w.Code != 302 {
   236  		t.Error("Expected 302 got ", w.Code)
   237  	}
   238  }
   239  
   240  func TestCheckLogin(t *testing.T) {
   241  	r, err := http.NewRequest("GET", "/", nil)
   242  	if err != nil {
   243  		t.Error("Got an error ", err)
   244  	}
   245  
   246  	cookie := &http.Cookie{
   247  		Name:     "session",
   248  		Value:    "XYZ",
   249  		Path:     "/",
   250  		Secure:   true,
   251  		HttpOnly: true,
   252  	}
   253  	r.AddCookie(cookie)
   254  	if checkLogin(r) {
   255  		t.Error("Expected false got true")
   256  	}
   257  
   258  	r, err = http.NewRequest("GET", "/", nil)
   259  	if err != nil {
   260  		t.Error("Got an error ", err)
   261  	}
   262  	w := httptest.NewRecorder()
   263  	setSession(w)
   264  	r.Header.Set("Cookie", w.HeaderMap.Get("Set-Cookie"))
   265  
   266  	if !checkLogin(r) {
   267  		t.Error("Expected true got false")
   268  	}
   269  }
   270  
   271  func TestSetSession(t *testing.T) {
   272  	w := httptest.NewRecorder()
   273  	setSession(w)
   274  	if strings.Index(w.HeaderMap.Get("Set-Cookie"), "session=") == -1 {
   275  		t.Error("Expected session cookie")
   276  	}
   277  }
   278  
   279  func TestClearSession(t *testing.T) {
   280  	w := httptest.NewRecorder()
   281  	clearSession(w)
   282  	if strings.Index(w.HeaderMap.Get("Set-Cookie"), "session=;") == -1 {
   283  		t.Error("Expected empty session cookie")
   284  	}
   285  }