k8s.io/test-infra@v0.0.0-20240520184403-27c6b4c223d8/experiment/ml/analyze/client.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"log"
    23  	"sync/atomic"
    24  	"time"
    25  
    26  	automl "cloud.google.com/go/automl/apiv1"
    27  	"cloud.google.com/go/automl/apiv1/automlpb"
    28  	"google.golang.org/api/option"
    29  )
    30  
    31  func defaultPredictionClient(ctx context.Context) (*predictionClient, error) {
    32  	return newPredictionClient(ctx, *projectID, *location, *model, *quotaProjectID)
    33  }
    34  
    35  func newPredictionClient(ctx context.Context, projectID, location, model, quotaProject string) (*predictionClient, error) {
    36  	var opts []option.ClientOption
    37  	if quotaProject != "" {
    38  		opts = append(opts, option.WithQuotaProject(quotaProject))
    39  	}
    40  	client, err := automl.NewPredictionClient(ctx, opts...)
    41  	if err != nil {
    42  		return nil, fmt.Errorf("create client: %w", err)
    43  	}
    44  
    45  	return &predictionClient{
    46  		client:    client,
    47  		modelName: modelName(projectID, location, model),
    48  		ch:        throttle(ctx, *qps**burst, time.Second/time.Duration(*qps), *warmup),
    49  	}, nil
    50  }
    51  
    52  type predictionClient struct {
    53  	client    *automl.PredictionClient
    54  	modelName string
    55  	requests  int64
    56  	ch        <-chan time.Time
    57  }
    58  
    59  func modelName(projectID, location, model string) string {
    60  	return fmt.Sprintf("projects/%s/locations/%s/models/%s", projectID, location, model)
    61  }
    62  
    63  func throttle(ctx context.Context, capacity int, wait time.Duration, warmup bool) <-chan time.Time {
    64  	ch := make(chan time.Time, capacity)
    65  	go func() {
    66  		tick := time.NewTicker(wait)
    67  		now := time.Now()
    68  		if warmup {
    69  			for len(ch) < cap(ch) {
    70  				select {
    71  				case <-ctx.Done():
    72  					return
    73  				case ch <- now:
    74  				}
    75  			}
    76  		}
    77  		for {
    78  			select {
    79  			case <-ctx.Done():
    80  				return
    81  			case now = <-tick.C:
    82  			}
    83  			select {
    84  			case <-ctx.Done():
    85  				return
    86  			case ch <- now:
    87  			}
    88  		}
    89  	}()
    90  	return ch
    91  }
    92  
    93  func (pc *predictionClient) predict(ctx context.Context, sentence string) (map[string]float32, error) {
    94  	select {
    95  	case <-ctx.Done():
    96  		return nil, ctx.Err()
    97  	case <-pc.ch:
    98  	}
    99  	n := atomic.AddInt64(&pc.requests, 1)
   100  	req := automlpb.PredictRequest{
   101  		Name: pc.modelName,
   102  		Payload: &automlpb.ExamplePayload{
   103  			Payload: &automlpb.ExamplePayload_TextSnippet{
   104  				TextSnippet: &automlpb.TextSnippet{
   105  					Content:  sentence,
   106  					MimeType: "text/plain",
   107  				},
   108  			},
   109  		},
   110  	}
   111  
   112  	resp, err := pc.client.Predict(ctx, &req)
   113  	if n%100 == 0 || err != nil {
   114  		log.Println("Prediction request", n, "remaining quota", len(pc.ch), err)
   115  	}
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	payloads := resp.GetPayload()
   120  	out := make(map[string]float32, len(payloads))
   121  	for _, payload := range payloads {
   122  		out[payload.GetDisplayName()] = payload.GetClassification().GetScore()
   123  	}
   124  	return out, nil
   125  }