github.com/weaviate/weaviate@v1.24.6/usecases/objects/authorization_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package objects
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"fmt"
    18  	"reflect"
    19  	"testing"
    20  
    21  	"github.com/go-openapi/strfmt"
    22  	"github.com/sirupsen/logrus/hooks/test"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  	"github.com/weaviate/weaviate/entities/additional"
    26  	"github.com/weaviate/weaviate/entities/models"
    27  	"github.com/weaviate/weaviate/usecases/config"
    28  )
    29  
// A component-test-like suite that makes sure every available use case (UC) is
// potentially protected by the authorization plugin.
    32  
    33  func Test_Kinds_Authorization(t *testing.T) {
    34  	type testCase struct {
    35  		methodName       string
    36  		additionalArgs   []interface{}
    37  		expectedVerb     string
    38  		expectedResource string
    39  	}
    40  
    41  	tests := []testCase{
    42  		// single kind
    43  		{
    44  			methodName:       "AddObject",
    45  			additionalArgs:   []interface{}{(*models.Object)(nil)},
    46  			expectedVerb:     "create",
    47  			expectedResource: "objects",
    48  		},
    49  		{
    50  			methodName:       "ValidateObject",
    51  			additionalArgs:   []interface{}{(*models.Object)(nil)},
    52  			expectedVerb:     "validate",
    53  			expectedResource: "objects",
    54  		},
    55  		{
    56  			methodName:       "GetObject",
    57  			additionalArgs:   []interface{}{"", strfmt.UUID("foo"), additional.Properties{}},
    58  			expectedVerb:     "get",
    59  			expectedResource: "objects/foo",
    60  		},
    61  		{
    62  			methodName:       "DeleteObject",
    63  			additionalArgs:   []interface{}{"class", strfmt.UUID("foo")},
    64  			expectedVerb:     "delete",
    65  			expectedResource: "objects/class/foo",
    66  		},
    67  		{ // deprecated by the one above
    68  			methodName:       "DeleteObject",
    69  			additionalArgs:   []interface{}{"", strfmt.UUID("foo")},
    70  			expectedVerb:     "delete",
    71  			expectedResource: "objects/foo",
    72  		},
    73  		{
    74  			methodName:       "UpdateObject",
    75  			additionalArgs:   []interface{}{"class", strfmt.UUID("foo"), (*models.Object)(nil)},
    76  			expectedVerb:     "update",
    77  			expectedResource: "objects/class/foo",
    78  		},
    79  		{ // deprecated by the one above
    80  			methodName:       "UpdateObject",
    81  			additionalArgs:   []interface{}{"", strfmt.UUID("foo"), (*models.Object)(nil)},
    82  			expectedVerb:     "update",
    83  			expectedResource: "objects/foo",
    84  		},
    85  		{
    86  			methodName: "MergeObject",
    87  			additionalArgs: []interface{}{
    88  				&models.Object{Class: "class", ID: "foo"},
    89  				(*additional.ReplicationProperties)(nil),
    90  			},
    91  			expectedVerb:     "update",
    92  			expectedResource: "objects/class/foo",
    93  		},
    94  		{
    95  			methodName:       "GetObjectsClass",
    96  			additionalArgs:   []interface{}{strfmt.UUID("foo")},
    97  			expectedVerb:     "get",
    98  			expectedResource: "objects/foo",
    99  		},
   100  		{
   101  			methodName:       "GetObjectClassFromName",
   102  			additionalArgs:   []interface{}{strfmt.UUID("foo")},
   103  			expectedVerb:     "get",
   104  			expectedResource: "objects/foo",
   105  		},
   106  		{
   107  			methodName:       "HeadObject",
   108  			additionalArgs:   []interface{}{"class", strfmt.UUID("foo")},
   109  			expectedVerb:     "head",
   110  			expectedResource: "objects/class/foo",
   111  		},
   112  		{ // deprecated by the one above
   113  			methodName:       "HeadObject",
   114  			additionalArgs:   []interface{}{"", strfmt.UUID("foo")},
   115  			expectedVerb:     "head",
   116  			expectedResource: "objects/foo",
   117  		},
   118  
   119  		// query objects
   120  		{
   121  			methodName:       "Query",
   122  			additionalArgs:   []interface{}{new(QueryParams)},
   123  			expectedVerb:     "list",
   124  			expectedResource: "objects",
   125  		},
   126  
   127  		{ // list objects is deprecated by query
   128  			methodName:       "GetObjects",
   129  			additionalArgs:   []interface{}{(*int64)(nil), (*int64)(nil), (*string)(nil), (*string)(nil), additional.Properties{}},
   130  			expectedVerb:     "list",
   131  			expectedResource: "objects",
   132  		},
   133  
   134  		// reference on objects
   135  		{
   136  			methodName:       "AddObjectReference",
   137  			additionalArgs:   []interface{}{AddReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}, (*models.SingleRef)(nil)},
   138  			expectedVerb:     "update",
   139  			expectedResource: "objects/class/foo",
   140  		},
   141  		{
   142  			methodName:       "DeleteObjectReference",
   143  			additionalArgs:   []interface{}{strfmt.UUID("foo"), "some prop", (*models.SingleRef)(nil)},
   144  			expectedVerb:     "update",
   145  			expectedResource: "objects/foo",
   146  		},
   147  		{
   148  			methodName:       "UpdateObjectReferences",
   149  			additionalArgs:   []interface{}{&PutReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}},
   150  			expectedVerb:     "update",
   151  			expectedResource: "objects/class/foo",
   152  		},
   153  	}
   154  
   155  	t.Run("verify that a test for every public method exists", func(t *testing.T) {
   156  		testedMethods := make([]string, len(tests))
   157  		for i, test := range tests {
   158  			testedMethods[i] = test.methodName
   159  		}
   160  
   161  		for _, method := range allExportedMethods(&Manager{}) {
   162  			assert.Contains(t, testedMethods, method)
   163  		}
   164  	})
   165  
   166  	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
   167  		principal := &models.Principal{}
   168  		logger, _ := test.NewNullLogger()
   169  		for _, test := range tests {
   170  			if test.methodName != "MergeObject" {
   171  				continue
   172  			}
   173  			t.Run(test.methodName, func(t *testing.T) {
   174  				schemaManager := &fakeSchemaManager{}
   175  				locks := &fakeLocks{}
   176  				cfg := &config.WeaviateConfig{}
   177  				authorizer := &authDenier{}
   178  				vectorRepo := &fakeVectorRepo{}
   179  				manager := NewManager(locks, schemaManager,
   180  					cfg, logger, authorizer,
   181  					vectorRepo, getFakeModulesProvider(), nil)
   182  
   183  				args := append([]interface{}{context.Background(), principal}, test.additionalArgs...)
   184  				out, _ := callFuncByName(manager, test.methodName, args...)
   185  
   186  				require.Len(t, authorizer.calls, 1, "authorizer must be called")
   187  				aerr := out[len(out)-1].Interface().(error)
   188  				if err, ok := aerr.(*Error); !ok || !err.Forbidden() {
   189  					assert.Equal(t, errors.New("just a test fake"), aerr,
   190  						"execution must abort with authorizer error")
   191  				}
   192  
   193  				assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource},
   194  					authorizer.calls[0], "correct parameters must have been used on authorizer")
   195  			})
   196  		}
   197  	})
   198  }
   199  
   200  func Test_BatchKinds_Authorization(t *testing.T) {
   201  	type testCase struct {
   202  		methodName       string
   203  		additionalArgs   []interface{}
   204  		expectedVerb     string
   205  		expectedResource string
   206  	}
   207  
   208  	tests := []testCase{
   209  		{
   210  			methodName: "AddObjects",
   211  			additionalArgs: []interface{}{
   212  				[]*models.Object{},
   213  				[]*string{},
   214  				&additional.ReplicationProperties{},
   215  			},
   216  			expectedVerb:     "create",
   217  			expectedResource: "batch/objects",
   218  		},
   219  
   220  		{
   221  			methodName: "AddReferences",
   222  			additionalArgs: []interface{}{
   223  				[]*models.BatchReference{},
   224  				&additional.ReplicationProperties{},
   225  			},
   226  			expectedVerb:     "update",
   227  			expectedResource: "batch/*",
   228  		},
   229  
   230  		{
   231  			methodName: "DeleteObjects",
   232  			additionalArgs: []interface{}{
   233  				&models.BatchDeleteMatch{},
   234  				(*bool)(nil),
   235  				(*string)(nil),
   236  				&additional.ReplicationProperties{},
   237  				"",
   238  			},
   239  			expectedVerb:     "delete",
   240  			expectedResource: "batch/objects",
   241  		},
   242  		{
   243  			methodName: "DeleteObjectsFromGRPC",
   244  			additionalArgs: []interface{}{
   245  				BatchDeleteParams{},
   246  				&additional.ReplicationProperties{},
   247  				"",
   248  			},
   249  			expectedVerb:     "delete",
   250  			expectedResource: "batch/objects",
   251  		},
   252  	}
   253  
   254  	t.Run("verify that a test for every public method exists", func(t *testing.T) {
   255  		testedMethods := make([]string, len(tests))
   256  		for i, test := range tests {
   257  			testedMethods[i] = test.methodName
   258  		}
   259  
   260  		for _, method := range allExportedMethods(&BatchManager{}) {
   261  			assert.Contains(t, testedMethods, method)
   262  		}
   263  	})
   264  
   265  	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
   266  		principal := &models.Principal{}
   267  		logger, _ := test.NewNullLogger()
   268  		for _, test := range tests {
   269  			schemaManager := &fakeSchemaManager{}
   270  			locks := &fakeLocks{}
   271  			cfg := &config.WeaviateConfig{}
   272  			authorizer := &authDenier{}
   273  			vectorRepo := &fakeVectorRepo{}
   274  			modulesProvider := getFakeModulesProvider()
   275  			manager := NewBatchManager(vectorRepo, modulesProvider, locks, schemaManager, cfg, logger, authorizer, nil)
   276  
   277  			args := append([]interface{}{context.Background(), principal}, test.additionalArgs...)
   278  			out, _ := callFuncByName(manager, test.methodName, args...)
   279  
   280  			require.Len(t, authorizer.calls, 1, "authorizer must be called")
   281  			assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
   282  				"execution must abort with authorizer error")
   283  			assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource},
   284  				authorizer.calls[0], "correct parameters must have been used on authorizer")
   285  		}
   286  	})
   287  }
   288  
   289  type authorizeCall struct {
   290  	principal *models.Principal
   291  	verb      string
   292  	resource  string
   293  }
   294  
   295  type authDenier struct {
   296  	calls []authorizeCall
   297  }
   298  
   299  func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
   300  	a.calls = append(a.calls, authorizeCall{principal, verb, resource})
   301  	return errors.New("just a test fake")
   302  }
   303  
   304  // inspired by https://stackoverflow.com/a/33008200
   305  func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
   306  	managerValue := reflect.ValueOf(manager)
   307  	m := managerValue.MethodByName(funcName)
   308  	if !m.IsValid() {
   309  		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
   310  	}
   311  	in := make([]reflect.Value, len(params))
   312  	for i, param := range params {
   313  		in[i] = reflect.ValueOf(param)
   314  	}
   315  	out = m.Call(in)
   316  	return
   317  }
   318  
   319  func allExportedMethods(subject interface{}) []string {
   320  	var methods []string
   321  	subjectType := reflect.TypeOf(subject)
   322  	for i := 0; i < subjectType.NumMethod(); i++ {
   323  		name := subjectType.Method(i).Name
   324  		if name[0] >= 'A' && name[0] <= 'Z' {
   325  			methods = append(methods, name)
   326  		}
   327  	}
   328  
   329  	return methods
   330  }