github.com/weaviate/weaviate@v1.24.6/usecases/classification/authorization_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  // import (
    15  // 	"context"
    16  // 	"errors"
    17  // 	"fmt"
    18  // 	"reflect"
    19  // 	"testing"
    20  
    21  // 	"github.com/go-openapi/strfmt"
    22  // 	"github.com/weaviate/weaviate/entities/models"
    23  // 	"github.com/stretchr/testify/assert"
    24  // 	"github.com/stretchr/testify/require"
    25  // )
    26  
    27  // // A component-test-like suite that makes sure every exported method (use case)
    28  // // of the Classifier is guarded by the Authorization plugin.
    29  
    30  // func Test_Classifier_Authorization(t *testing.T) {
    31  
    32  // 	type testCase struct {
    33  // 		methodName       string
    34  // 		additionalArgs   []interface{}
    35  // 		expectedVerb     string
    36  // 		expectedResource string
    37  // 	}
    38  
    39  // 	tests := []testCase{
    40  // 		testCase{
    41  // 			methodName:       "Get",
    42  // 			additionalArgs:   []interface{}{strfmt.UUID("")},
    43  // 			expectedVerb:     "get",
    44  // 			expectedResource: "classifications/*",
    45  // 		},
    46  // 		testCase{
    47  // 			methodName:       "Schedule",
    48  // 			additionalArgs:   []interface{}{models.Classification{}},
    49  // 			expectedVerb:     "create",
    50  // 			expectedResource: "classifications/*",
    51  // 		},
    52  // 	}
    53  
    54  // 	t.Run("verify that a test for every public method exists", func(t *testing.T) {
    55  // 		testedMethods := make([]string, len(tests), len(tests))
    56  // 		for i, test := range tests {
    57  // 			testedMethods[i] = test.methodName
    58  // 		}
    59  
    60  // 		for _, method := range allExportedMethods(&Classifier{}) {
    61  // 			assert.Contains(t, testedMethods, method)
    62  // 		}
    63  // 	})
    64  
    65  // 	t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) {
    66  // 		principal := &models.Principal{}
    67  // 		// logger, _ := test.NewNullLogger()
    68  // 		for _, test := range tests {
    69  // 			authorizer := &authDenier{}
    70  // 			repo := &fakeClassificationRepo{}
    71  // 			vectorRepo := &fakeVectorRepoKNN{}
    72  // 			schemaGetter := &fakeSchemaGetter{}
    73  
    74  // 			classifier := New(schemaGetter, repo, vectorRepo, authorizer)
    75  
    76  // 			args := append([]interface{}{context.Background(), principal}, test.additionalArgs...)
    77  // 			out, _ := callFuncByName(classifier, test.methodName, args...)
    78  
    79  // 			require.Len(t, authorizer.calls, 1, "authorizer must be called")
    80  // 			assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
    81  // 				"execution must abort with authorizer error")
    82  // 			assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource},
    83  // 				authorizer.calls[0], "correct parameters must have been used on authorizer")
    84  // 		}
    85  // 	})
    86  // }
    87  
    88  // type authorizeCall struct {
    89  // 	principal *models.Principal
    90  // 	verb      string
    91  // 	resource  string
    92  // }
    93  
    94  // type authDenier struct {
    95  // 	calls []authorizeCall
    96  // }
    97  
    98  // func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
    99  // 	a.calls = append(a.calls, authorizeCall{principal, verb, resource})
   100  // 	return errors.New("just a test fake")
   101  // }
   102  
   103  // // inspired by https://stackoverflow.com/a/33008200
   104  // func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
   105  // 	managerValue := reflect.ValueOf(manager)
   106  // 	m := managerValue.MethodByName(funcName)
   107  // 	if !m.IsValid() {
   108  // 		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
   109  // 	}
   110  // 	in := make([]reflect.Value, len(params))
   111  // 	for i, param := range params {
   112  // 		in[i] = reflect.ValueOf(param)
   113  // 	}
   114  // 	out = m.Call(in)
   115  // 	return
   116  // }
   117  
   118  // func allExportedMethods(subject interface{}) []string {
   119  // 	var methods []string
   120  // 	subjectType := reflect.TypeOf(subject)
   121  // 	for i := 0; i < subjectType.NumMethod(); i++ {
   122  // 		name := subjectType.Method(i).Name
   123  // 		if name[0] >= 'A' && name[0] <= 'Z' {
   124  // 			methods = append(methods, name)
   125  // 		}
   126  // 	}
   127  
   128  // 	return methods
   129  // }