k8s.io/test-infra@v0.0.0-20240520184403-27c6b4c223d8/experiment/ml/analyze/client.go (about) 1 /* 2 Copyright 2022 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package main 18 19 import ( 20 "context" 21 "fmt" 22 "log" 23 "sync/atomic" 24 "time" 25 26 automl "cloud.google.com/go/automl/apiv1" 27 "cloud.google.com/go/automl/apiv1/automlpb" 28 "google.golang.org/api/option" 29 ) 30 31 func defaultPredictionClient(ctx context.Context) (*predictionClient, error) { 32 return newPredictionClient(ctx, *projectID, *location, *model, *quotaProjectID) 33 } 34 35 func newPredictionClient(ctx context.Context, projectID, location, model, quotaProject string) (*predictionClient, error) { 36 var opts []option.ClientOption 37 if quotaProject != "" { 38 opts = append(opts, option.WithQuotaProject(quotaProject)) 39 } 40 client, err := automl.NewPredictionClient(ctx, opts...) 41 if err != nil { 42 return nil, fmt.Errorf("create client: %w", err) 43 } 44 45 return &predictionClient{ 46 client: client, 47 modelName: modelName(projectID, location, model), 48 ch: throttle(ctx, *qps**burst, time.Second/time.Duration(*qps), *warmup), 49 }, nil 50 } 51 52 type predictionClient struct { 53 client *automl.PredictionClient 54 modelName string 55 requests int64 56 ch <-chan time.Time 57 } 58 59 func modelName(projectID, location, model string) string { 60 return fmt.Sprintf("projects/%s/locations/%s/models/%s", projectID, location, model) 61 } 62 63 func throttle(ctx context.Context, capacity int, wait time.Duration, warmup bool) <-chan time.Time { 64 ch := make(chan time.Time, capacity) 65 go func() { 66 tick := time.NewTicker(wait) 67 now := time.Now() 68 if warmup { 69 for len(ch) < cap(ch) { 70 select { 71 case <-ctx.Done(): 72 return 73 case ch <- now: 74 } 75 } 76 } 77 for { 78 select { 79 case <-ctx.Done(): 80 return 81 case now = <-tick.C: 82 } 83 select { 84 case <-ctx.Done(): 85 return 86 case ch <- now: 87 } 88 } 89 }() 90 return ch 91 } 92 93 func (pc *predictionClient) predict(ctx context.Context, sentence string) (map[string]float32, error) { 94 select { 95 case <-ctx.Done(): 96 return nil, ctx.Err() 97 case <-pc.ch: 98 } 99 n := atomic.AddInt64(&pc.requests, 1) 100 req := automlpb.PredictRequest{ 101 Name: pc.modelName, 102 Payload: &automlpb.ExamplePayload{ 103 Payload: &automlpb.ExamplePayload_TextSnippet{ 104 TextSnippet: &automlpb.TextSnippet{ 105 Content: sentence, 106 MimeType: "text/plain", 107 }, 108 }, 109 }, 110 } 111 112 resp, err := pc.client.Predict(ctx, &req) 113 if n%100 == 0 || err != nil { 114 log.Println("Prediction request", n, "remaining quota", len(pc.ch), err) 115 } 116 if err != nil { 117 return nil, err 118 } 119 payloads := resp.GetPayload() 120 out := make(map[string]float32, len(payloads)) 121 for _, payload := range payloads { 122 out[payload.GetDisplayName()] = payload.GetClassification().GetScore() 123 } 124 return out, nil 125 }