github.com/weaviate/weaviate@v1.24.6/test/helper/modules/modules_helper.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package moduleshelper 13 14 import ( 15 "context" 16 "encoding/json" 17 "fmt" 18 "net/http" 19 "os" 20 "path/filepath" 21 "testing" 22 "time" 23 24 "cloud.google.com/go/storage" 25 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/test/helper" 29 graphqlhelper "github.com/weaviate/weaviate/test/helper/graphql" 30 "google.golang.org/api/googleapi" 31 "google.golang.org/api/option" 32 ) 33 34 func EnsureClassExists(t *testing.T, className string, tenant string) { 35 query := fmt.Sprintf("{Aggregate{%s", className) 36 if tenant != "" { 37 query += fmt.Sprintf("(tenant:%q)", tenant) 38 } 39 query += " { meta { count}}}}" 40 resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) 41 42 class := resp.Get("Aggregate", className).Result.([]interface{}) 43 require.Len(t, class, 1) 44 } 45 46 func EnsureCompressedVectorsRestored(t *testing.T, className string) { 47 query := fmt.Sprintf("{Get{%s(limit:1){_additional{vector}}}}", className) 48 resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) 49 50 class := resp.Get("Get", className).Result.([]interface{}) 51 require.Len(t, class, 1) 52 vecResp := class[0].(map[string]interface{})["_additional"].(map[string]interface{})["vector"].([]interface{}) 53 54 searchVec := graphqlhelper.Vec2String(graphqlhelper.ParseVec(t, vecResp)) 55 56 limit := 10 57 query = fmt.Sprintf( 58 "{Get{%s(nearVector:{vector:%s} limit:%d){_additional{vector}}}}", 59 className, searchVec, limit) 60 resp = graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) 61 class = resp.Get("Get", className).Result.([]interface{}) 62 require.Len(t, class, limit) 63 } 64 65 func GetClassCount(t *testing.T, className string, tenant string) int64 { 66 query := fmt.Sprintf("{Aggregate{%s", className) 67 if tenant != "" { 68 query += fmt.Sprintf("(tenant:%q)", tenant) 69 } 70 query += " { meta { count}}}}" 71 resp := graphqlhelper.AssertGraphQL(t, helper.RootAuth, query) 72 73 class := resp.Get("Aggregate", className).Result.([]interface{}) 74 require.Len(t, class, 1) 75 76 meta := class[0].(map[string]interface{})["meta"].(map[string]interface{}) 77 78 countPayload := meta["count"].(json.Number) 79 80 count, err := countPayload.Int64() 81 require.Nil(t, err) 82 83 return count 84 } 85 86 func CreateTestFiles(t *testing.T, dirPath string) []string { 87 count := 5 88 filePaths := make([]string, count) 89 var fileName string 90 91 for i := 0; i < count; i += 1 { 92 fileName = fmt.Sprintf("file_%d.db", i) 93 filePaths[i] = filepath.Join(dirPath, fileName) 94 file, err := os.Create(filePaths[i]) 95 if err != nil { 96 t.Fatalf("failed to create test file '%s': %s", fileName, err) 97 } 98 fmt.Fprintf(file, "This is content of db file named %s", fileName) 99 file.Close() 100 } 101 return filePaths 102 } 103 104 func CreateGCSBucket(ctx context.Context, t *testing.T, projectID, bucketName string) { 105 assert.EventuallyWithT(t, func(collect *assert.CollectT) { 106 client, err := storage.NewClient(ctx, option.WithoutAuthentication()) 107 require.Nil(t, err) 108 err = client.Bucket(bucketName).Create(ctx, projectID, nil) 109 gcsErr, ok := err.(*googleapi.Error) 110 if ok { 111 // the bucket persists from the previous test. 112 // if the bucket already exists, we can proceed 113 if gcsErr.Code == http.StatusConflict { 114 return 115 } 116 } 117 require.Nil(t, err) 118 }, 5*time.Second, 500*time.Millisecond) 119 } 120 121 func CreateAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { 122 assert.EventuallyWithT(t, func(collect *assert.CollectT) { 123 connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" 124 client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) 125 require.Nil(t, err) 126 127 _, err = client.CreateContainer(ctx, containerName, nil) 128 require.Nil(t, err) 129 }, 5*time.Second, 500*time.Millisecond) 130 } 131 132 func DeleteAzureContainer(ctx context.Context, t *testing.T, endpoint, containerName string) { 133 assert.EventuallyWithT(t, func(collect *assert.CollectT) { 134 connectionString := "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://%s/devstoreaccount1;" 135 client, err := azblob.NewClientFromConnectionString(fmt.Sprintf(connectionString, endpoint), nil) 136 require.Nil(t, err) 137 138 _, err = client.DeleteContainer(ctx, containerName, nil) 139 require.Nil(t, err) 140 }, 5*time.Second, 500*time.Millisecond) 141 }