gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/ollama/client/client.go (about)

     1  // Copyright 2024 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // A simple `curl`-like HTTP client that prints metrics after the request.
    16  // All of its output is structured to be unambiguous even if stdout/stderr
    17  // is combined, as is the case for Kubernetes logs.
    18  // Useful for communicating with ollama.
    19  package main
    20  
    21  import (
    22  	"bytes"
    23  	"encoding/base64"
    24  	"encoding/json"
    25  	"flag"
    26  	"fmt"
    27  	"io"
    28  	"net/http"
    29  	"os"
    30  	"sort"
    31  	"time"
    32  )
    33  
    34  // Flags.
    35  var (
    36  	url            = flag.String("url", "", "HTTP request URL.")
    37  	method         = flag.String("method", "GET", "HTTP request method (GET or POST).")
    38  	postDataBase64 = flag.String("post_base64", "", "HTTP request POST data in base64 format; ignored for GET requests.")
    39  	timeout        = flag.Duration("timeout", 0, "HTTP request timeout; 0 for no timeout.")
    40  )
    41  
    42  // bufSize is the size of buffers used for HTTP requests and responses.
    43  const bufSize = 1024 * 1024 // 1MiB
    44  
    45  // fatalf crashes the program with a given error message.
    46  func fatalf(format string, values ...any) {
    47  	fmt.Fprintf(os.Stderr, "FATAL: "+format+"\n", values...)
    48  	os.Exit(1)
    49  }
    50  
    51  // Metrics contains the request metrics to export to JSON.
    52  // This is parsed by the ollama library at `test/gpu/ollama/ollama.go`.
    53  type Metrics struct {
    54  	// ProgramStarted is the time when the program started.
    55  	ProgramStarted time.Time `json:"program_started"`
    56  	// RequestSent is the time when the HTTP request was sent.
    57  	RequestSent time.Time `json:"request_sent"`
    58  	// ResponseReceived is the time when the HTTP response headers were received.
    59  	ResponseReceived time.Time `json:"response_received"`
    60  	// FirstByteRead is the time when the first HTTP response body byte was read.
    61  	FirstByteRead time.Time `json:"first_byte_read"`
    62  	// LastByteRead is the time when the last HTTP response body byte was read.
    63  	LastByteRead time.Time `json:"last_byte_read"`
    64  }
    65  
    66  func main() {
    67  	var metrics Metrics
    68  	metrics.ProgramStarted = time.Now()
    69  	flag.Parse()
    70  	if *url == "" {
    71  		fatalf("--url is required")
    72  	}
    73  	client := http.Client{
    74  		Transport: &http.Transport{
    75  			MaxIdleConns:    1,
    76  			IdleConnTimeout: *timeout,
    77  			ReadBufferSize:  bufSize,
    78  			WriteBufferSize: bufSize,
    79  		},
    80  		Timeout: *timeout,
    81  	}
    82  	var request *http.Request
    83  	var err error
    84  	switch *method {
    85  	case "GET":
    86  		request, err = http.NewRequest("GET", *url, nil)
    87  	case "POST":
    88  		postData, postDataErr := base64.StdEncoding.DecodeString(*postDataBase64)
    89  		if postDataErr != nil {
    90  			fatalf("cannot decode POST data: %v", postDataErr)
    91  		}
    92  		request, err = http.NewRequest("POST", *url, bytes.NewBuffer(postData))
    93  	default:
    94  		err = fmt.Errorf("unknown method %q", *method)
    95  	}
    96  	if err != nil {
    97  		fatalf("cannot create request: %v", err)
    98  	}
    99  	readBuf := make([]byte, bufSize)
   100  	orderedReqHeaders := make([]string, 0, len(request.Header))
   101  	for k := range request.Header {
   102  		orderedReqHeaders = append(orderedReqHeaders, k)
   103  	}
   104  	sort.Strings(orderedReqHeaders)
   105  	for _, k := range orderedReqHeaders {
   106  		for _, v := range request.Header[k] {
   107  			fmt.Fprintf(os.Stderr, "REQHEADER: %s: %s\n", k, v)
   108  		}
   109  	}
   110  	metrics.RequestSent = time.Now()
   111  	resp, err := client.Do(request)
   112  	metrics.ResponseReceived = time.Now()
   113  	if err != nil {
   114  		fatalf("cannot make request: %v", err)
   115  	}
   116  	gotFirstByte := false
   117  	for {
   118  		n, err := resp.Body.Read(readBuf)
   119  		if n > 0 {
   120  			if !gotFirstByte {
   121  				metrics.FirstByteRead = time.Now()
   122  				gotFirstByte = true
   123  			}
   124  			fmt.Printf("BODY: %q\n", string(readBuf[:n]))
   125  		}
   126  		if err == io.EOF {
   127  			metrics.LastByteRead = time.Now()
   128  			break
   129  		}
   130  		if err != nil {
   131  			fatalf("cannot read response body: %v", err)
   132  		}
   133  	}
   134  	if err := resp.Body.Close(); err != nil {
   135  		fatalf("cannot close response body: %v", err)
   136  	}
   137  	orderedRespHeaders := make([]string, 0, len(resp.Header))
   138  	for k := range resp.Header {
   139  		orderedRespHeaders = append(orderedRespHeaders, k)
   140  	}
   141  	sort.Strings(orderedRespHeaders)
   142  	for _, k := range orderedRespHeaders {
   143  		for _, v := range resp.Header[k] {
   144  			fmt.Fprintf(os.Stderr, "RESPHEADER: %s: %s\n", k, v)
   145  		}
   146  	}
   147  	metricsBytes, err := json.Marshal(&metrics)
   148  	if err != nil {
   149  		fatalf("cannot marshal metrics: %v", err)
   150  	}
   151  	fmt.Fprintf(os.Stderr, "STATS: %s\n", string(metricsBytes))
   152  }