github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/examples/sample-application/secure/auth/auth_test.go (about)

     1  // Copyright 2022 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package auth
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/google/go-safeweb/examples/sample-application/storage"
    21  	"github.com/google/go-safeweb/safehttp"
    22  	"github.com/google/go-safeweb/safehttp/safehttptest"
    23  )
    24  
    25  const TEST_USER = "test"
    26  
    27  func TestInterceptorBefore(t *testing.T) {
    28  	tests := []struct {
    29  		name    string
    30  		cfg     safehttp.InterceptorConfig
    31  		hasAuth bool
    32  		want    safehttp.StatusCode
    33  	}{
    34  		{
    35  			name:    "base case, no error",
    36  			hasAuth: true,
    37  			cfg:     nil,
    38  			want:    safehttp.StatusOK,
    39  		},
    40  		{
    41  			name:    "force skip using config",
    42  			hasAuth: true,
    43  			cfg:     Skip{},
    44  			want:    safehttp.StatusOK,
    45  		},
    46  		{
    47  			name:    "missing auth, error",
    48  			hasAuth: false,
    49  			cfg:     nil,
    50  			want:    safehttp.StatusUnauthorized,
    51  		},
    52  	}
    53  	for _, tt := range tests {
    54  		t.Run(tt.name, func(t *testing.T) {
    55  			withUserDB, token := addTestUser(storage.NewDB())
    56  			ip := newTestInterceptor(withUserDB)
    57  
    58  			req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
    59  			rw, r := safehttptest.NewFakeResponseWriter()
    60  
    61  			if tt.hasAuth {
    62  				addTestUserCookie(req, token)
    63  			}
    64  
    65  			// Note: "Before" return value is not significant
    66  			ip.Before(rw, req, tt.cfg)
    67  
    68  			if got := r.Code; got != int(tt.want) {
    69  				t.Errorf("status code got: %d, want %d", got, tt.want)
    70  			}
    71  		})
    72  	}
    73  }
    74  
    75  func TestInterceptorCommit(t *testing.T) {
    76  	tests := []struct {
    77  		name      string
    78  		action    sessionAction
    79  		hasCookie bool
    80  	}{
    81  		{
    82  			name:      "clear session, no error",
    83  			action:    clearSess,
    84  			hasCookie: false,
    85  		}, {
    86  			name:      "set session, no error",
    87  			action:    setSess,
    88  			hasCookie: true,
    89  		},
    90  		{
    91  			name:      "unexpected action, skip",
    92  			action:    sessionAction("unexpected"),
    93  			hasCookie: false,
    94  		},
    95  	}
    96  	for _, tt := range tests {
    97  		t.Run(tt.name, func(t *testing.T) {
    98  			withUserDB, _ := addTestUser(storage.NewDB())
    99  			ip := newTestInterceptor(withUserDB)
   100  
   101  			req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
   102  			rw, r := safehttptest.NewFakeResponseWriter()
   103  
   104  			safehttp.FlightValues(req.Context()).Put(userKey, "user")
   105  			safehttp.FlightValues(req.Context()).Put(changeSessKey, tt.action)
   106  
   107  			ip.Commit(rw, req, r.Result, nil)
   108  
   109  			var token string
   110  			for _, c := range rw.Cookies {
   111  				if c.Name() == sessionCookie {
   112  					token = c.Value()
   113  				}
   114  			}
   115  
   116  			if tt.hasCookie == (token == "") {
   117  				t.Errorf("token = %q, want %v", token, tt.hasCookie)
   118  			}
   119  		})
   120  	}
   121  }
   122  
   123  func TestInterceptorMatch(t *testing.T) {
   124  	tests := []struct {
   125  		name string
   126  		cfg  safehttp.InterceptorConfig
   127  		want bool
   128  	}{
   129  		{
   130  			name: "basic case, no error",
   131  			cfg:  Skip{},
   132  			want: true,
   133  		},
   134  		{
   135  			name: "no Skip{}, error",
   136  			cfg:  nil,
   137  			want: false,
   138  		},
   139  	}
   140  	for _, tt := range tests {
   141  		t.Run(tt.name, func(t *testing.T) {
   142  			ip := newTestInterceptor(nil)
   143  			if got := ip.Match(tt.cfg); got != tt.want {
   144  				t.Errorf("Interceptor.Match() = %v, want %v", got, tt.want)
   145  			}
   146  		})
   147  	}
   148  }
   149  
   150  func TestInterceptorUserFromCookie(t *testing.T) {
   151  	withUserDB, validToken := addTestUser(storage.NewDB())
   152  	ip := newTestInterceptor(withUserDB)
   153  
   154  	tests := []struct {
   155  		name  string
   156  		token string
   157  		want  string
   158  	}{
   159  		{
   160  			name:  "basic case, no error",
   161  			token: validToken,
   162  			want:  TEST_USER,
   163  		},
   164  		{
   165  			name:  "empty cookie, error",
   166  			token: "",
   167  			want:  "",
   168  		},
   169  		{
   170  			name:  "invalid token, error",
   171  			token: "not_a_valid_token",
   172  			want:  "",
   173  		},
   174  	}
   175  	for _, tt := range tests {
   176  		t.Run(tt.name, func(t *testing.T) {
   177  			req := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
   178  			addTestUserCookie(req, tt.token)
   179  
   180  			if got := ip.userFromCookie(req); got != tt.want {
   181  				t.Errorf("Interceptor.userFromCookie() = %v, want %v", got, tt.want)
   182  			}
   183  		})
   184  	}
   185  }
   186  
   187  func TestSessionManagement(t *testing.T) {
   188  	want := "wanted"
   189  	r := safehttptest.NewRequest(safehttp.MethodGet, "/", nil)
   190  
   191  	CreateSession(r, want)
   192  	if got := User(r); got != want {
   193  		t.Errorf("user id got: %q, want %q", got, want)
   194  	}
   195  
   196  	ClearSession(r)
   197  	// Note: `ctxSessionAction` already tested inside ctx_test.go
   198  	if got := ctxSessionAction(r.Context()); got != clearSess {
   199  		t.Errorf("no clearSess action found in context after ClearSession")
   200  	}
   201  }
   202  
   203  func addTestUserCookie(r *safehttp.IncomingRequest, v string) {
   204  	r.Header.Add("Cookie", safehttp.NewCookie(sessionCookie, v).String())
   205  }
   206  
   207  func newTestInterceptor(db *storage.DB) Interceptor {
   208  	if db == nil {
   209  		db = storage.NewDB()
   210  	}
   211  	return Interceptor{
   212  		DB: db,
   213  	}
   214  }
   215  
   216  func addTestUser(db *storage.DB) (*storage.DB, string) {
   217  	token := (*db).GetToken(TEST_USER)
   218  	return db, token
   219  }