github.com/qiwihui/DBShield@v0.0.0-20171107092910-fb8553bed8ef/dbshield/utils_test.go (about)

     1  // +build !windows
     2  
     3  package dbshield
     4  
     5  import (
     6  	"net"
     7  	"os"
     8  	"testing"
     9  
    10  	"github.com/nim4/mock"
    11  	"github.com/qiwihui/DBShield/dbshield/config"
    12  	"github.com/qiwihui/DBShield/dbshield/logger"
    13  )
    14  
    15  func TestDbNameToStruct(t *testing.T) {
    16  	_, err := dbNameToStruct("db2")
    17  	if err != nil {
    18  		t.Error("Expected struct, got ", err)
    19  		return
    20  	}
    21  	_, err = dbNameToStruct("mysql")
    22  	if err != nil {
    23  		t.Error("Expected struct, got ", err)
    24  		return
    25  	}
    26  	_, err = dbNameToStruct("oracle")
    27  	if err != nil {
    28  		t.Error("Expected struct, got ", err)
    29  		return
    30  	}
    31  	_, err = dbNameToStruct("postgres")
    32  	if err != nil {
    33  		t.Error("Expected struct, got ", err)
    34  		return
    35  	}
    36  	//Invalid case is tested in postConfig test
    37  }
    38  
    39  func TestInitLogging(t *testing.T) {
    40  	config.Config.LogPath = "stdout"
    41  	initLogging()
    42  }
    43  
    44  func TestHandleClient(t *testing.T) {
    45  	var s mock.ConnMock
    46  	err := handleClient(&s, nil)
    47  	if err == nil {
    48  		t.Error("Expected error got nil")
    49  	}
    50  	ls, _ := net.Listen("tcp4", "localhost:0")
    51  	go func() {
    52  		for {
    53  			conn, _ := ls.Accept()
    54  			conn.Close()
    55  		}
    56  	}()
    57  
    58  	ra, err := net.ResolveTCPAddr("tcp4", ls.Addr().String())
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	err = handleClient(&s, ra)
    63  	if err == nil {
    64  		t.Error("Expected error got nil")
    65  	}
    66  }
    67  
    68  func TestCloseHandlers(t *testing.T) {
    69  	logger.Output = os.Stderr
    70  	defer func() {
    71  		if r := recover(); r != nil {
    72  			t.Error("Panic!")
    73  		}
    74  	}()
    75  	closeHandlers()
    76  }
    77  
    78  func TestGenerateDBMS(t *testing.T) {
    79  	config.Config.DB = 0
    80  	v, _ := generateDBMS()
    81  	if v == nil {
    82  		t.Error("Got nil")
    83  	}
    84  
    85  	config.Config.DB++
    86  	v, _ = generateDBMS()
    87  	if v == nil {
    88  		t.Error("Got nil")
    89  	}
    90  
    91  	config.Config.DB++
    92  	v, _ = generateDBMS()
    93  	if v == nil {
    94  		t.Error("Got nil")
    95  	}
    96  
    97  	config.Config.DB++
    98  	v, _ = generateDBMS()
    99  	if v == nil {
   100  		t.Error("Got nil")
   101  	}
   102  
   103  	config.Config.DB = 100
   104  	v, _ = generateDBMS()
   105  	if v != nil {
   106  		t.Error("Expected nil")
   107  	}
   108  }