gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/ollama/client/client.go (about) 1 // Copyright 2024 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // A simple `curl`-like HTTP client that prints metrics after the request. 16 // All of its output is structured to be unambiguous even if stdout/stderr 17 // is combined, as is the case for Kubernetes logs. 18 // Useful for communicating with ollama. 19 package main 20 21 import ( 22 "bytes" 23 "encoding/base64" 24 "encoding/json" 25 "flag" 26 "fmt" 27 "io" 28 "net/http" 29 "os" 30 "sort" 31 "time" 32 ) 33 34 // Flags. 35 var ( 36 url = flag.String("url", "", "HTTP request URL.") 37 method = flag.String("method", "GET", "HTTP request method (GET or POST).") 38 postDataBase64 = flag.String("post_base64", "", "HTTP request POST data in base64 format; ignored for GET requests.") 39 timeout = flag.Duration("timeout", 0, "HTTP request timeout; 0 for no timeout.") 40 ) 41 42 // bufSize is the size of buffers used for HTTP requests and responses. 43 const bufSize = 1024 * 1024 // 1MiB 44 45 // fatalf crashes the program with a given error message. 46 func fatalf(format string, values ...any) { 47 fmt.Fprintf(os.Stderr, "FATAL: "+format+"\n", values...) 48 os.Exit(1) 49 } 50 51 // Metrics contains the request metrics to export to JSON. 52 // This is parsed by the ollama library at `test/gpu/ollama/ollama.go`. 53 type Metrics struct { 54 // ProgramStarted is the time when the program started. 55 ProgramStarted time.Time `json:"program_started"` 56 // RequestSent is the time when the HTTP request was sent. 57 RequestSent time.Time `json:"request_sent"` 58 // ResponseReceived is the time when the HTTP response headers were received. 59 ResponseReceived time.Time `json:"response_received"` 60 // FirstByteRead is the time when the first HTTP response body byte was read. 61 FirstByteRead time.Time `json:"first_byte_read"` 62 // LastByteRead is the time when the last HTTP response body byte was read. 63 LastByteRead time.Time `json:"last_byte_read"` 64 } 65 66 func main() { 67 var metrics Metrics 68 metrics.ProgramStarted = time.Now() 69 flag.Parse() 70 if *url == "" { 71 fatalf("--url is required") 72 } 73 client := http.Client{ 74 Transport: &http.Transport{ 75 MaxIdleConns: 1, 76 IdleConnTimeout: *timeout, 77 ReadBufferSize: bufSize, 78 WriteBufferSize: bufSize, 79 }, 80 Timeout: *timeout, 81 } 82 var request *http.Request 83 var err error 84 switch *method { 85 case "GET": 86 request, err = http.NewRequest("GET", *url, nil) 87 case "POST": 88 postData, postDataErr := base64.StdEncoding.DecodeString(*postDataBase64) 89 if postDataErr != nil { 90 fatalf("cannot decode POST data: %v", postDataErr) 91 } 92 request, err = http.NewRequest("POST", *url, bytes.NewBuffer(postData)) 93 default: 94 err = fmt.Errorf("unknown method %q", *method) 95 } 96 if err != nil { 97 fatalf("cannot create request: %v", err) 98 } 99 readBuf := make([]byte, bufSize) 100 orderedReqHeaders := make([]string, 0, len(request.Header)) 101 for k := range request.Header { 102 orderedReqHeaders = append(orderedReqHeaders, k) 103 } 104 sort.Strings(orderedReqHeaders) 105 for _, k := range orderedReqHeaders { 106 for _, v := range request.Header[k] { 107 fmt.Fprintf(os.Stderr, "REQHEADER: %s: %s\n", k, v) 108 } 109 } 110 metrics.RequestSent = time.Now() 111 resp, err := client.Do(request) 112 metrics.ResponseReceived = time.Now() 113 if err != nil { 114 fatalf("cannot make request: %v", err) 115 } 116 gotFirstByte := false 117 for { 118 n, err := resp.Body.Read(readBuf) 119 if n > 0 { 120 if !gotFirstByte { 121 metrics.FirstByteRead = time.Now() 122 gotFirstByte = true 123 } 124 fmt.Printf("BODY: %q\n", string(readBuf[:n])) 125 } 126 if err == io.EOF { 127 metrics.LastByteRead = time.Now() 128 break 129 } 130 if err != nil { 131 fatalf("cannot read response body: %v", err) 132 } 133 } 134 if err := resp.Body.Close(); err != nil { 135 fatalf("cannot close response body: %v", err) 136 } 137 orderedRespHeaders := make([]string, 0, len(resp.Header)) 138 for k := range resp.Header { 139 orderedRespHeaders = append(orderedRespHeaders, k) 140 } 141 sort.Strings(orderedRespHeaders) 142 for _, k := range orderedRespHeaders { 143 for _, v := range resp.Header[k] { 144 fmt.Fprintf(os.Stderr, "RESPHEADER: %s: %s\n", k, v) 145 } 146 } 147 metricsBytes, err := json.Marshal(&metrics) 148 if err != nil { 149 fatalf("cannot marshal metrics: %v", err) 150 } 151 fmt.Fprintf(os.Stderr, "STATS: %s\n", string(metricsBytes)) 152 }