github.com/GGP1/kure@v0.8.4/commands/2fa/add/add_test.go (about)

     1  package add
     2  
     3  import (
     4  	"bytes"
     5  	"net/url"
     6  	"testing"
     7  
     8  	cmdutil "github.com/GGP1/kure/commands"
     9  	"github.com/GGP1/kure/db/totp"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  func TestAdd(t *testing.T) {
    15  	db := cmdutil.SetContext(t)
    16  
    17  	cases := []struct {
    18  		desc   string
    19  		name   string
    20  		input  string
    21  		digits string
    22  		url    string
    23  	}{
    24  		{
    25  			desc:   "Key with 6 digits",
    26  			name:   "6_digits",
    27  			input:  "IFGEWRKSIFJUMR2R",
    28  			digits: "6",
    29  		},
    30  		{
    31  			desc:   "Key with 7 digits",
    32  			name:   "7_digits",
    33  			input:  "IFGEWRKSIFJUMR2R",
    34  			digits: "7",
    35  		},
    36  		{
    37  			desc:   "Key with 8 digits",
    38  			name:   "8_digits",
    39  			input:  "IFGEWRKSIFJUMR2R",
    40  			digits: "8",
    41  		},
    42  		{
    43  			desc:  "URL",
    44  			name:  "url",
    45  			input: "otpauth://totp/Test?secret=IFGEWRKSIFJUMR2R",
    46  			url:   "true",
    47  		},
    48  	}
    49  
    50  	for _, tc := range cases {
    51  		t.Run(tc.desc, func(t *testing.T) {
    52  			buf := bytes.NewBufferString(tc.input)
    53  			cmd := NewCmd(db, buf)
    54  			cmd.SetArgs([]string{tc.name})
    55  
    56  			f := cmd.Flags()
    57  			f.Set("digits", tc.digits)
    58  			f.Set("url", tc.url)
    59  
    60  			err := cmd.Execute()
    61  			assert.NoError(t, err, "Failed adding TOTP")
    62  		})
    63  	}
    64  }
    65  
    66  func TestAddErrors(t *testing.T) {
    67  	db := cmdutil.SetContext(t)
    68  
    69  	name := "test"
    70  	err := createTOTP(db, name, "", 0)
    71  	assert.NoError(t, err)
    72  
    73  	cases := []struct {
    74  		desc   string
    75  		input  string
    76  		name   string
    77  		digits string
    78  		url    string
    79  	}{
    80  		{
    81  			desc: "Invalid name",
    82  			name: "",
    83  		},
    84  		{
    85  			desc: "Key already exists",
    86  			name: name,
    87  		},
    88  		{
    89  			desc:   "Invalid digits",
    90  			name:   "fail",
    91  			digits: "10",
    92  		},
    93  		{
    94  			desc:   "Invalid key",
    95  			name:   "fail",
    96  			digits: "7",
    97  			input:  "invalid/%",
    98  		},
    99  		{
   100  			desc:  "URL already exists",
   101  			input: "otpauth://totp/Test?secret=IFGEWRKSIFJUMR2R",
   102  			url:   "true",
   103  		},
   104  		{
   105  			desc:  "Invalid url secret",
   106  			input: "otpauth://totp/Testing?secret=not-base32",
   107  			url:   "true",
   108  		},
   109  		{
   110  			desc:  "Invalid url format",
   111  			input: "otpauth://hotp/Tests?secret=IFGEWRKSIFJUMR2R",
   112  			url:   "true",
   113  		},
   114  	}
   115  
   116  	for _, tc := range cases {
   117  		t.Run(tc.desc, func(t *testing.T) {
   118  			buf := bytes.NewBufferString(tc.input)
   119  			cmd := NewCmd(db, buf)
   120  			cmd.SetArgs([]string{tc.name})
   121  
   122  			f := cmd.Flags()
   123  			f.Set("digits", tc.digits)
   124  			f.Set("url", tc.url)
   125  
   126  			err := cmd.Execute()
   127  			assert.Error(t, err)
   128  		})
   129  	}
   130  }
   131  
   132  func TestCreateTOTP(t *testing.T) {
   133  	db := cmdutil.SetContext(t)
   134  
   135  	t.Run("Success", func(t *testing.T) {
   136  		name := "test"
   137  		err := createTOTP(db, name, "secret", 6)
   138  		assert.NoError(t, err, "Failed creating TOTP")
   139  
   140  		_, err = totp.Get(db, name)
   141  		assert.NoErrorf(t, err, "%q TOTP not found", name)
   142  	})
   143  
   144  	t.Run("Fail", func(t *testing.T) {
   145  		err := createTOTP(db, "", "", 0)
   146  		assert.Error(t, err)
   147  	})
   148  }
   149  
   150  func TestGetName(t *testing.T) {
   151  	expected := "test"
   152  	path := "/Test:mail@enterprise.com"
   153  	got := getName(path)
   154  
   155  	assert.Equal(t, expected, got)
   156  }
   157  
   158  func TestStringDigits(t *testing.T) {
   159  	cases := []struct {
   160  		desc     string
   161  		digits   string
   162  		expected int32
   163  	}{
   164  		{
   165  			desc:     "Empty",
   166  			digits:   "",
   167  			expected: 6,
   168  		},
   169  		{
   170  			desc:     "Six digits",
   171  			digits:   "6",
   172  			expected: 6,
   173  		},
   174  		{
   175  			desc:     "Seven digits",
   176  			digits:   "7",
   177  			expected: 7,
   178  		},
   179  		{
   180  			desc:     "Eight digits",
   181  			digits:   "8",
   182  			expected: 8,
   183  		},
   184  	}
   185  
   186  	for _, tc := range cases {
   187  		t.Run(tc.desc, func(t *testing.T) {
   188  			got := stringDigits(tc.digits)
   189  			assert.Equal(t, tc.expected, got)
   190  		})
   191  	}
   192  }
   193  
   194  func TestValidateURL(t *testing.T) {
   195  	t.Run("Valid", func(t *testing.T) {
   196  		u := &url.URL{
   197  			Scheme: "otpauth",
   198  			Host:   "totp",
   199  		}
   200  		query := u.Query()
   201  		query.Add("algorithm", "SHA1")
   202  		query.Add("period", "30")
   203  		err := validateURL(u, query)
   204  		assert.NoError(t, err)
   205  	})
   206  
   207  	t.Run("Invalid", func(t *testing.T) {
   208  		cases := []struct {
   209  			desc      string
   210  			scheme    string
   211  			host      string
   212  			period    string
   213  			algorithm string
   214  		}{
   215  			{
   216  				desc:   "Invalid scheme",
   217  				scheme: "otpauth-migration",
   218  			},
   219  			{
   220  				desc:   "Invalid host",
   221  				scheme: "otpauth",
   222  				host:   "hotp",
   223  			},
   224  			{
   225  				desc:      "Invalid algorithm",
   226  				scheme:    "otpauth",
   227  				host:      "totp",
   228  				algorithm: "SHA256",
   229  			},
   230  			{
   231  				desc:   "Invalid period",
   232  				scheme: "otpauth",
   233  				host:   "totp",
   234  				period: "60",
   235  			},
   236  		}
   237  
   238  		for _, tc := range cases {
   239  			t.Run(tc.desc, func(t *testing.T) {
   240  				u := &url.URL{
   241  					Scheme: tc.scheme,
   242  					Host:   tc.host,
   243  				}
   244  				query := u.Query()
   245  				query.Add("algorithm", tc.algorithm)
   246  				query.Add("period", tc.period)
   247  				err := validateURL(u, query)
   248  				assert.Error(t, err)
   249  			})
   250  		}
   251  	})
   252  }
   253  
   254  func TestPostRun(t *testing.T) {
   255  	NewCmd(nil, nil).PostRun(nil, nil)
   256  }