github.com/weaviate/weaviate@v1.24.6/usecases/traverser/authorization_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package traverser
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"fmt"
    18  	"reflect"
    19  	"testing"
    20  
    21  	"github.com/sirupsen/logrus/hooks/test"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/entities/aggregation"
    25  	"github.com/weaviate/weaviate/entities/dto"
    26  	"github.com/weaviate/weaviate/entities/models"
    27  	"github.com/weaviate/weaviate/usecases/config"
    28  )
    29  
// A component-test-like suite that makes sure every available use case (UC)
// of the Traverser is protected by the Authorization plugin.
    32  
    33  func Test_Traverser_Authorization(t *testing.T) {
    34  	type testCase struct {
    35  		methodName       string
    36  		additionalArgs   []interface{}
    37  		expectedVerb     string
    38  		expectedResource string
    39  	}
    40  
    41  	tests := []testCase{
    42  		{
    43  			methodName:       "GetClass",
    44  			additionalArgs:   []interface{}{dto.GetParams{}},
    45  			expectedVerb:     "get",
    46  			expectedResource: "traversal/*",
    47  		},
    48  
    49  		{
    50  			methodName:       "Aggregate",
    51  			additionalArgs:   []interface{}{&aggregation.Params{}},
    52  			expectedVerb:     "get",
    53  			expectedResource: "traversal/*",
    54  		},
    55  
    56  		{
    57  			methodName:       "Explore",
    58  			additionalArgs:   []interface{}{ExploreParams{}},
    59  			expectedVerb:     "get",
    60  			expectedResource: "traversal/*",
    61  		},
    62  	}
    63  
    64  	t.Run("verify that a test for every public method exists", func(t *testing.T) {
    65  		testedMethods := make([]string, len(tests))
    66  		for i, test := range tests {
    67  			testedMethods[i] = test.methodName
    68  		}
    69  
    70  		for _, method := range allExportedMethods(&Traverser{}) {
    71  			assert.Contains(t, testedMethods, method)
    72  		}
    73  	})
    74  
    75  	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
    76  		principal := &models.Principal{}
    77  		logger, _ := test.NewNullLogger()
    78  		for _, test := range tests {
    79  			locks := &fakeLocks{}
    80  			authorizer := &authDenier{}
    81  			vectorRepo := &fakeVectorRepo{}
    82  			explorer := &fakeExplorer{}
    83  			schemaGetter := &fakeSchemaGetter{}
    84  
    85  			manager := NewTraverser(&config.WeaviateConfig{}, locks, logger, authorizer,
    86  				vectorRepo, explorer, schemaGetter, nil, nil, -1)
    87  
    88  			args := append([]interface{}{context.Background(), principal}, test.additionalArgs...)
    89  			out, _ := callFuncByName(manager, test.methodName, args...)
    90  
    91  			require.Len(t, authorizer.calls, 1, "authorizer must be called")
    92  			assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
    93  				"execution must abort with authorizer error")
    94  			assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource},
    95  				authorizer.calls[0], "correct parameters must have been used on authorizer")
    96  		}
    97  	})
    98  }
    99  
   100  type authorizeCall struct {
   101  	principal *models.Principal
   102  	verb      string
   103  	resource  string
   104  }
   105  
   106  type authDenier struct {
   107  	calls []authorizeCall
   108  }
   109  
   110  func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
   111  	a.calls = append(a.calls, authorizeCall{principal, verb, resource})
   112  	return errors.New("just a test fake")
   113  }
   114  
   115  // inspired by https://stackoverflow.com/a/33008200
   116  func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
   117  	managerValue := reflect.ValueOf(manager)
   118  	m := managerValue.MethodByName(funcName)
   119  	if !m.IsValid() {
   120  		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
   121  	}
   122  	in := make([]reflect.Value, len(params))
   123  	for i, param := range params {
   124  		in[i] = reflect.ValueOf(param)
   125  	}
   126  	out = m.Call(in)
   127  	return
   128  }
   129  
   130  func allExportedMethods(subject interface{}) []string {
   131  	var methods []string
   132  	subjectType := reflect.TypeOf(subject)
   133  	for i := 0; i < subjectType.NumMethod(); i++ {
   134  		name := subjectType.Method(i).Name
   135  		if name[0] >= 'A' && name[0] <= 'Z' {
   136  			methods = append(methods, name)
   137  		}
   138  	}
   139  
   140  	return methods
   141  }