github.com/annwntech/go-micro/v2@v2.9.5/util/wrapper/wrapper_test.go (about)

     1  package wrapper
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"reflect"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/annwntech/go-micro/v2/auth"
    11  	"github.com/annwntech/go-micro/v2/client"
    12  	"github.com/annwntech/go-micro/v2/errors"
    13  	"github.com/annwntech/go-micro/v2/metadata"
    14  	"github.com/annwntech/go-micro/v2/server"
    15  )
    16  
    17  func TestWrapper(t *testing.T) {
    18  	testData := []struct {
    19  		existing  metadata.Metadata
    20  		headers   metadata.Metadata
    21  		overwrite bool
    22  	}{
    23  		{
    24  			existing: metadata.Metadata{},
    25  			headers: metadata.Metadata{
    26  				"Foo": "bar",
    27  			},
    28  			overwrite: true,
    29  		},
    30  		{
    31  			existing: metadata.Metadata{
    32  				"Foo": "bar",
    33  			},
    34  			headers: metadata.Metadata{
    35  				"Foo": "baz",
    36  			},
    37  			overwrite: false,
    38  		},
    39  	}
    40  
    41  	for _, d := range testData {
    42  		c := &fromServiceWrapper{
    43  			headers: d.headers,
    44  		}
    45  
    46  		ctx := metadata.NewContext(context.Background(), d.existing)
    47  		ctx = c.setHeaders(ctx)
    48  		md, _ := metadata.FromContext(ctx)
    49  
    50  		for k, v := range d.headers {
    51  			if d.overwrite && md[k] != v {
    52  				t.Fatalf("Expected %s=%s got %s=%s", k, v, k, md[k])
    53  			}
    54  			if !d.overwrite && md[k] != d.existing[k] {
    55  				t.Fatalf("Expected %s=%s got %s=%s", k, d.existing[k], k, md[k])
    56  			}
    57  		}
    58  	}
    59  }
    60  
    61  type testAuth struct {
    62  	verifyCount    int
    63  	inspectCount   int
    64  	namespace      string
    65  	inspectAccount *auth.Account
    66  	verifyError    error
    67  
    68  	auth.Auth
    69  }
    70  
    71  func (a *testAuth) Verify(acc *auth.Account, res *auth.Resource, opts ...auth.VerifyOption) error {
    72  	a.verifyCount = a.verifyCount + 1
    73  	return a.verifyError
    74  }
    75  
    76  func (a *testAuth) Inspect(token string) (*auth.Account, error) {
    77  	a.inspectCount = a.inspectCount + 1
    78  	return a.inspectAccount, nil
    79  }
    80  
    81  func (a *testAuth) Options() auth.Options {
    82  	return auth.Options{Issuer: a.namespace}
    83  }
    84  
    85  type testRequest struct {
    86  	service  string
    87  	endpoint string
    88  
    89  	server.Request
    90  }
    91  
    92  func (r testRequest) Service() string {
    93  	return r.service
    94  }
    95  
    96  func (r testRequest) Endpoint() string {
    97  	return r.endpoint
    98  }
    99  
   100  func TestAuthHandler(t *testing.T) {
   101  	h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   102  		return nil
   103  	}
   104  
   105  	debugReq := testRequest{service: "go.micro.service.foo", endpoint: "Debug.Foo"}
   106  	serviceReq := testRequest{service: "go.micro.service.foo", endpoint: "Foo.Bar"}
   107  
   108  	// Debug endpoints should be excluded from auth so auth.Verify should never get called
   109  	t.Run("DebugEndpoint", func(t *testing.T) {
   110  		a := testAuth{}
   111  		handler := AuthHandler(func() auth.Auth {
   112  			return &a
   113  		})
   114  
   115  		err := handler(h)(context.TODO(), debugReq, nil)
   116  		if err != nil {
   117  			t.Errorf("Expected nil error but got %v", err)
   118  		}
   119  		if a.verifyCount != 0 {
   120  			t.Errorf("Did not expect verify to be called")
   121  		}
   122  	})
   123  
   124  	// If the Authorization header is blank, no error should be returned and verify not called
   125  	t.Run("BlankAuthorizationHeader", func(t *testing.T) {
   126  		a := testAuth{}
   127  		handler := AuthHandler(func() auth.Auth {
   128  			return &a
   129  		})
   130  
   131  		err := handler(h)(context.TODO(), serviceReq, nil)
   132  		if err != nil {
   133  			t.Errorf("Expected nil error but got %v", err)
   134  		}
   135  		if a.inspectCount != 0 {
   136  			t.Errorf("Did not expect inspect to be called")
   137  		}
   138  	})
   139  
   140  	// If the Authorization header is invalid, an error should be returned and verify not called
   141  	t.Run("InvalidAuthorizationHeader", func(t *testing.T) {
   142  		a := testAuth{}
   143  		handler := AuthHandler(func() auth.Auth {
   144  			return &a
   145  		})
   146  
   147  		ctx := metadata.Set(context.TODO(), "Authorization", "Invalid")
   148  		err := handler(h)(ctx, serviceReq, nil)
   149  		if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized {
   150  			t.Errorf("Expected unauthorized error but got %v", err)
   151  		}
   152  		if a.inspectCount != 0 {
   153  			t.Errorf("Did not expect inspect to be called")
   154  		}
   155  	})
   156  
   157  	// If the Authorization header is valid, no error should be returned and verify should called
   158  	t.Run("ValidAuthorizationHeader", func(t *testing.T) {
   159  		a := testAuth{}
   160  		handler := AuthHandler(func() auth.Auth {
   161  			return &a
   162  		})
   163  
   164  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   165  		err := handler(h)(ctx, serviceReq, nil)
   166  		if err != nil {
   167  			t.Errorf("Expected nil error but got %v", err)
   168  		}
   169  		if a.inspectCount != 1 {
   170  			t.Errorf("Expected inspect to be called")
   171  		}
   172  	})
   173  
   174  	// If the namespace header was not set on the request, the wrapper should set it to the auths
   175  	// own namespace
   176  	t.Run("BlankNamespaceHeader", func(t *testing.T) {
   177  		a := testAuth{namespace: "mynamespace"}
   178  		handler := AuthHandler(func() auth.Auth {
   179  			return &a
   180  		})
   181  
   182  		inCtx := context.TODO()
   183  		h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   184  			inCtx = ctx
   185  			return nil
   186  		}
   187  
   188  		err := handler(h)(inCtx, serviceReq, nil)
   189  		if err != nil {
   190  			t.Errorf("Expected nil error but got %v", err)
   191  		}
   192  		if ns, _ := metadata.Get(inCtx, "Micro-Namespace"); ns != a.namespace {
   193  			t.Errorf("Expected namespace to be set to %v but was %v", a.namespace, ns)
   194  		}
   195  	})
   196  	t.Run("ValidNamespaceHeader", func(t *testing.T) {
   197  		a := testAuth{namespace: "mynamespace"}
   198  		handler := AuthHandler(func() auth.Auth {
   199  			return &a
   200  		})
   201  
   202  		inNs := "reqnamespace"
   203  		inCtx := metadata.Set(context.TODO(), "Micro-Namespace", inNs)
   204  		h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   205  			inCtx = ctx
   206  			return nil
   207  		}
   208  
   209  		err := handler(h)(inCtx, serviceReq, nil)
   210  		if err != nil {
   211  			t.Errorf("Expected nil error but got %v", err)
   212  		}
   213  		if ns, _ := metadata.Get(inCtx, "Micro-Namespace"); ns != inNs {
   214  			t.Errorf("Expected namespace to remain as %v but was set to %v", inNs, ns)
   215  		}
   216  	})
   217  
   218  	// If the callers account was set but the issuer didn't match that of the request, the request
   219  	// should be forbidden
   220  	t.Run("InvalidAccountIssuer", func(t *testing.T) {
   221  		a := testAuth{
   222  			namespace:      "validnamespace",
   223  			inspectAccount: &auth.Account{Issuer: "invalidnamespace"},
   224  		}
   225  
   226  		handler := AuthHandler(func() auth.Auth {
   227  			return &a
   228  		})
   229  
   230  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   231  		err := handler(h)(ctx, serviceReq, nil)
   232  		if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusForbidden {
   233  			t.Errorf("Expected forbidden error but got %v", err)
   234  		}
   235  	})
   236  	t.Run("ValidAccountIssuer", func(t *testing.T) {
   237  		a := testAuth{
   238  			namespace:      "validnamespace",
   239  			inspectAccount: &auth.Account{Issuer: "validnamespace"},
   240  		}
   241  
   242  		handler := AuthHandler(func() auth.Auth {
   243  			return &a
   244  		})
   245  
   246  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   247  		err := handler(h)(ctx, serviceReq, nil)
   248  		if err != nil {
   249  			t.Errorf("Expected nil error but got %v", err)
   250  		}
   251  	})
   252  
   253  	// If the caller had a nil account and verify returns an error, the request should be unauthorised
   254  	t.Run("NilAccountUnauthorized", func(t *testing.T) {
   255  		a := testAuth{verifyError: auth.ErrForbidden}
   256  
   257  		handler := AuthHandler(func() auth.Auth {
   258  			return &a
   259  		})
   260  
   261  		err := handler(h)(context.TODO(), serviceReq, nil)
   262  		if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized {
   263  			t.Errorf("Expected unauthorizard error but got %v", err)
   264  		}
   265  	})
   266  	t.Run("AccountForbidden", func(t *testing.T) {
   267  		a := testAuth{verifyError: auth.ErrForbidden, inspectAccount: &auth.Account{}}
   268  
   269  		handler := AuthHandler(func() auth.Auth {
   270  			return &a
   271  		})
   272  
   273  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   274  		err := handler(h)(ctx, serviceReq, nil)
   275  		if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusForbidden {
   276  			t.Errorf("Expected forbidden error but got %v", err)
   277  		}
   278  	})
   279  	t.Run("AccountValid", func(t *testing.T) {
   280  		a := testAuth{inspectAccount: &auth.Account{}}
   281  
   282  		handler := AuthHandler(func() auth.Auth {
   283  			return &a
   284  		})
   285  
   286  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   287  		err := handler(h)(ctx, serviceReq, nil)
   288  		if err != nil {
   289  			t.Errorf("Expected nil error but got %v", err)
   290  		}
   291  	})
   292  
   293  	// If an account is returned from inspecting the token, it should be set in the context
   294  	t.Run("ContextWithAccount", func(t *testing.T) {
   295  		accID := "myaccountid"
   296  		a := testAuth{inspectAccount: &auth.Account{ID: accID}}
   297  
   298  		handler := AuthHandler(func() auth.Auth {
   299  			return &a
   300  		})
   301  
   302  		inCtx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   303  		h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   304  			inCtx = ctx
   305  			return nil
   306  		}
   307  
   308  		err := handler(h)(inCtx, serviceReq, nil)
   309  		if err != nil {
   310  			t.Errorf("Expected nil error but got %v", err)
   311  		}
   312  		if acc, ok := auth.AccountFromContext(inCtx); !ok {
   313  			t.Errorf("Expected an account to be set in the context")
   314  		} else if acc.ID != accID {
   315  			t.Errorf("Expected the account in the context to have the ID %v but it actually had %v", accID, acc.ID)
   316  		}
   317  	})
   318  
   319  	// If verify returns an error the handler should not be called
   320  	t.Run("HandlerNotCalled", func(t *testing.T) {
   321  		a := testAuth{verifyError: auth.ErrForbidden}
   322  
   323  		handler := AuthHandler(func() auth.Auth {
   324  			return &a
   325  		})
   326  
   327  		var handlerCalled bool
   328  		h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   329  			handlerCalled = true
   330  			return nil
   331  		}
   332  
   333  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   334  		err := handler(h)(ctx, serviceReq, nil)
   335  		if verr, ok := err.(*errors.Error); !ok || verr.Code != http.StatusUnauthorized {
   336  			t.Errorf("Expected unauthorizard error but got %v", err)
   337  		}
   338  		if handlerCalled {
   339  			t.Errorf("Expected the handler to not be called")
   340  		}
   341  	})
   342  
   343  	// If verify does not return an error the handler should be called
   344  	t.Run("HandlerNotCalled", func(t *testing.T) {
   345  		a := testAuth{}
   346  
   347  		handler := AuthHandler(func() auth.Auth {
   348  			return &a
   349  		})
   350  
   351  		var handlerCalled bool
   352  		h := func(ctx context.Context, req server.Request, rsp interface{}) error {
   353  			handlerCalled = true
   354  			return nil
   355  		}
   356  
   357  		ctx := metadata.Set(context.TODO(), "Authorization", auth.BearerScheme+"Token")
   358  		err := handler(h)(ctx, serviceReq, nil)
   359  		if err != nil {
   360  			t.Errorf("Expected nil error but got %v", err)
   361  		}
   362  		if !handlerCalled {
   363  			t.Errorf("Expected the handler be called")
   364  		}
   365  	})
   366  }
   367  
   368  type testClient struct {
   369  	callCount int
   370  	callRsp   interface{}
   371  	client.Client
   372  }
   373  
   374  func (c *testClient) Call(ctx context.Context, req client.Request, rsp interface{}, opts ...client.CallOption) error {
   375  	c.callCount++
   376  
   377  	if c.callRsp != nil {
   378  		val := reflect.ValueOf(rsp).Elem()
   379  		val.Set(reflect.ValueOf(c.callRsp).Elem())
   380  	}
   381  
   382  	return nil
   383  }
   384  
   385  type testRsp struct {
   386  	value string
   387  }
   388  
   389  func TestCacheWrapper(t *testing.T) {
   390  	req := client.NewRequest("go.micro.service.foo", "Foo.Bar", nil)
   391  
   392  	t.Run("NilCache", func(t *testing.T) {
   393  		cli := new(testClient)
   394  
   395  		w := CacheClient(func() *client.Cache {
   396  			return nil
   397  		}, cli)
   398  
   399  		// perfroming two requests should increment the call count by two indicating the cache wasn't
   400  		// used even though the WithCache option was passed.
   401  		w.Call(context.TODO(), req, nil, client.WithCache(time.Minute))
   402  		w.Call(context.TODO(), req, nil, client.WithCache(time.Minute))
   403  
   404  		if cli.callCount != 2 {
   405  			t.Errorf("Expected the client to have been called twice")
   406  		}
   407  	})
   408  
   409  	t.Run("OptionNotSet", func(t *testing.T) {
   410  		cli := new(testClient)
   411  		cache := client.NewCache()
   412  
   413  		w := CacheClient(func() *client.Cache {
   414  			return cache
   415  		}, cli)
   416  
   417  		// perfroming two requests should increment the call count by two since we didn't pass the WithCache
   418  		// option to Call.
   419  		w.Call(context.TODO(), req, nil)
   420  		w.Call(context.TODO(), req, nil)
   421  
   422  		if cli.callCount != 2 {
   423  			t.Errorf("Expected the client to have been called twice")
   424  		}
   425  	})
   426  
   427  	t.Run("OptionSet", func(t *testing.T) {
   428  		val := "foo"
   429  		cli := &testClient{callRsp: &testRsp{value: val}}
   430  		cache := client.NewCache()
   431  
   432  		w := CacheClient(func() *client.Cache {
   433  			return cache
   434  		}, cli)
   435  
   436  		// perfroming two requests should increment the call count by once since the second request should
   437  		// have used the cache. The correct value should be set on both responses and no errors should
   438  		// be returned.
   439  		rsp1 := &testRsp{}
   440  		rsp2 := &testRsp{}
   441  		err1 := w.Call(context.TODO(), req, rsp1, client.WithCache(time.Minute))
   442  		err2 := w.Call(context.TODO(), req, rsp2, client.WithCache(time.Minute))
   443  
   444  		if err1 != nil {
   445  			t.Errorf("Expected nil error, got %v", err1)
   446  		}
   447  		if err2 != nil {
   448  			t.Errorf("Expected nil error, got %v", err2)
   449  		}
   450  
   451  		if rsp1.value != val {
   452  			t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp1.value)
   453  		}
   454  		if rsp2.value != val {
   455  			t.Errorf("Expected %v to be assigned to the value, got %v", val, rsp2.value)
   456  		}
   457  
   458  		if cli.callCount != 1 {
   459  			t.Errorf("Expected the client to be called 1 time, was actually called %v time(s)", cli.callCount)
   460  		}
   461  	})
   462  }