gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/ollama/ollama.go

// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package ollama provides an Ollama API client.
package ollama

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"math"
	"sort"
	"strconv"
	"strings"
	"time"

	"gvisor.dev/gvisor/pkg/test/dockerutil"
	"gvisor.dev/gvisor/pkg/test/testutil"
)

const (
	// Port is the port used by the ollama server.
	Port = 11434

	// curtQuery is a query that should result in a very curt response.
	curtQuery = `Please reply with the single word: "Hello". Do not reply with any other word.`
)

// Ollama is an ollama client.
type Ollama struct {
	// server is used to perform requests against the server.
	server Server

	// logger is used to log.
	logger testutil.Logger

	// ModelNames is the list of available model names.
	ModelNames []string

	// cheapModels is a list of models that are known to be cheap.
	// A caller may set this to make forcefully unloading a model quicker.
	cheapModels []*Model

	// HasGPU is set depending on whether the LLM has GPU access.
	// ollama supports running both on CPU and GPU, and detects this
	// by spawning nvidia-smi.
	HasGPU bool
}

// Server performs requests against an ollama server.
type Server interface {
	// InstrumentedRequest performs an instrumented HTTP request against the
	// ollama server, using the `gpu/ollama_client` ollama image.
	// `argvFn` takes in a `protocol://host:port` string and returns a
	// command-line to use for making an instrumented HTTP request against the
	// ollama server.
	// InstrumentedRequest should return the logs from the request container.
	InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error)

	// Logs retrieves logs from the server.
	Logs(ctx context.Context) (string, error)
}

// New starts a new Ollama server in the given container,
// then waits for it to serve and returns the client.
func New(ctx context.Context, server Server, logger testutil.Logger) (*Ollama, error) {
	started := time.Now()
	llm := &Ollama{
		logger: logger,
		server: server,
	}

	// Wait until serving.
	if err := llm.WaitUntilServing(ctx); err != nil {
		return nil, fmt.Errorf("ollama did not come up for serving: %w", err)
	}
	logger.Logf("Ollama serving API requests after %v", time.Since(started))

	// Get list of model names.
	modelNames, err := llm.listModelNames(ctx)
	if err != nil {
		return nil, fmt.Errorf("could not list model names: %w", err)
	}
	if len(modelNames) == 0 {
		return nil, errors.New("no models available")
	}
	llm.ModelNames = modelNames
	logger.Logf("Available ollama model names: %v (loaded %v since container start)", modelNames, time.Since(started))

	// Load the first model.
	// This is necessary to force ollama to load a model, without which
	// we cannot detect if it is using the GPU or not.
	// This may fail during the process of loading the first model, so we keep
	// iterating for a while.
	_, err = llm.WarmModel(ctx, &Model{Name: llm.ModelNames[0]}, 1*time.Millisecond, false)
	if err != nil {
		return nil, fmt.Errorf("could not load first model %q: %w", llm.ModelNames[0], err)
	}
	logger.Logf("Loaded first ollama model %q (%v since container start)", llm.ModelNames[0], time.Since(started))

	// Now go over the logs and check if the GPU was used.
	logs, err := llm.server.Logs(ctx)
	if err != nil {
		return nil, fmt.Errorf("could not get logs: %w", err)
	}
	switch {
	case strings.Contains(logs, "no GPU detected"):
		llm.HasGPU = false
	case strings.Contains(logs, "Nvidia GPU detected"):
		llm.HasGPU = true
	default:
		return nil, fmt.Errorf("cannot determine whether ollama is using GPU from logs:\n%s", logs)
	}
	logger.Logf("Ollama successfully initialized in a total of %v", time.Since(started))
	return llm, nil
}

// SetCheapModels can be used to inform this Ollama client as to the list of
// models it can use that are known to be cheap.
// This is useful when forcefully unloading models by swapping them with
// another one, to ensure that the one it is being swapped with is small.
// Therefore, there should be at least two models specified here.
func (llm *Ollama) SetCheapModels(cheapModels []*Model) {
	llm.cheapModels = cheapModels
}

// dockerServer implements `Server`. It interfaces with an ollama server
// running in a local Docker container.
type dockerServer struct {
	container *dockerutil.Container
	logger    testutil.Logger
}

// NewDocker returns a new Ollama client talking to an Ollama server that runs
// in a local Docker container.
func NewDocker(ctx context.Context, cont *dockerutil.Container, logger testutil.Logger) (*Ollama, error) {
	opts := dockerutil.GPURunOpts()
	opts.Image = "gpu/ollama"
	started := time.Now()
	if err := cont.Spawn(ctx, opts); err != nil {
		return nil, fmt.Errorf("could not start ollama: %v", err)
	}
	logger.Logf("Ollama container started after %v", time.Since(started))
	ds := &dockerServer{
		container: cont,
		logger:    logger,
	}
	return New(ctx, ds, logger)
}

// InstrumentedRequest implements `Server.InstrumentedRequest`.
func (ds *dockerServer) InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error) {
	const ollamaHost = "llm"
	cmd := argvFn(fmt.Sprintf("http://%s:%d", ollamaHost, Port))
	out, err := dockerutil.MakeContainer(ctx, ds.logger).Run(ctx, dockerutil.RunOpts{
		Image: "gpu/ollama/client",
		Links: []string{ds.container.MakeLink(ollamaHost)},
	}, cmd...)
	if err != nil {
		if out != "" {
			return []byte(out), fmt.Errorf("command %q failed (%w): %v", strings.Join(cmd, " "), err, out)
		}
		return nil, fmt.Errorf("could not run command %q: %w", strings.Join(cmd, " "), err)
	}
	return []byte(out), nil
}

// Logs implements `Server.Logs`.
func (ds *dockerServer) Logs(ctx context.Context) (string, error) {
	return ds.container.Logs(ctx)
}
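// exampleSimplePrompt is an illustrative sketch of how a test might construct
// a client against a Docker-hosted ollama server and issue a single prompt.
// The model name "some-model" is a placeholder; use whichever model the test
// image actually ships.
func exampleSimplePrompt(ctx context.Context, cont *dockerutil.Container, logger testutil.Logger) error {
	llm, err := NewDocker(ctx, cont, logger)
	if err != nil {
		return fmt.Errorf("could not start ollama client: %v", err)
	}
	// A zero-temperature model keeps repeated runs consistent.
	resp, err := llm.Prompt(ctx, &Prompt{
		Model: ZeroTemperatureModel("some-model"),
		Query: curtQuery,
	})
	if err != nil {
		return fmt.Errorf("prompt failed: %v", err)
	}
	logger.Logf("Model replied %q (GPU in use: %t)", resp.Text(), llm.HasGPU)
	return nil
}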
// ResponseMetrics are HTTP request metrics from an ollama API query.
// This is the same JSON struct as defined in
// `images/gpu/ollama/client/client.go`.
type ResponseMetrics struct {
	// ProgramStarted is the time when the program started.
	ProgramStarted time.Time `json:"program_started"`
	// RequestSent is the time when the HTTP request was sent.
	RequestSent time.Time `json:"request_sent"`
	// ResponseReceived is the time when the HTTP response headers were received.
	ResponseReceived time.Time `json:"response_received"`
	// FirstByteRead is the time when the first HTTP response body byte was read.
	FirstByteRead time.Time `json:"first_byte_read"`
	// LastByteRead is the time when the last HTTP response body byte was read.
	LastByteRead time.Time `json:"last_byte_read"`
}

// apiResponse represents a JSON response from the ollama API.
type apiResponse[T any] struct {
	// Objects is the list of JSON objects in the response.
	Objects []*T
	// Metrics contains HTTP response metrics.
	Metrics ResponseMetrics
}

// Obj returns the first object in the response, if there is a singular
// object in the response.
func (ar *apiResponse[T]) Obj() (*T, error) {
	if len(ar.Objects) == 0 {
		return nil, fmt.Errorf("no objects in response")
	}
	if len(ar.Objects) > 1 {
		return nil, fmt.Errorf("multiple objects in response")
	}
	return ar.Objects[0], nil
}

// makeAPIResponse decodes a raw response from an instrumented HTTP request
// into an `apiResponse` with deserialized JSON objects.
func makeAPIResponse[T any](rawResponse []byte) (*apiResponse[T], error) {
	var respBytes bytes.Buffer
	var resp apiResponse[T]
	for _, line := range strings.Split(string(rawResponse), "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		colonIndex := strings.Index(line, ":")
		if colonIndex == -1 {
			return nil, fmt.Errorf("malformed line: %q", line)
		}
		data := strings.TrimSpace(line[colonIndex+1:])
		switch line[:colonIndex] {
		case "FATAL":
			return nil, fmt.Errorf("request failed: %s", data)
		case "REQHEADER", "RESPHEADER":
			// Do nothing with these.
		case "BODY":
			unquoted, err := strconv.Unquote(data)
			if err != nil {
				return nil, fmt.Errorf("malformed body line: %q", data)
			}
			respBytes.WriteString(unquoted)
		case "STATS":
			if err := json.Unmarshal([]byte(data), &resp.Metrics); err != nil {
				return nil, fmt.Errorf("malformed stats line: %q", data)
			}
		default:
			return nil, fmt.Errorf("malformed line: %q", line)
		}
	}
	decoder := json.NewDecoder(&respBytes)
	for {
		var obj T
		err := decoder.Decode(&obj)
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("malformed JSON response: %w", err)
		}
		resp.Objects = append(resp.Objects, &obj)
	}
	if len(resp.Objects) == 0 {
		return nil, fmt.Errorf("response is empty")
	}
	leftoverBytes, err := io.ReadAll(decoder.Buffered())
	if err != nil && err != io.EOF {
		return nil, fmt.Errorf("could not read leftover bytes: %w", err)
	}
	if leftover := strings.TrimSpace(string(leftoverBytes)); leftover != "" {
		return nil, fmt.Errorf("unprocessed bytes in response: %q", leftover)
	}
	return &resp, nil
}
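// For orientation, a rough sketch of the instrumented-request log format that
// makeAPIResponse consumes. The header names and field values below are
// invented for illustration; only the line prefixes and how they are handled
// come from the parser above:
//
//	REQHEADER: Content-Type: application/json
//	RESPHEADER: Content-Type: application/x-ndjson
//	BODY: "{\"response\":\"Hello\",\"done\":true}"
//	STATS: {"program_started":"2024-01-01T00:00:00Z","last_byte_read":"2024-01-01T00:00:02Z"}
//
// BODY lines carry Go-quoted chunks of the HTTP response body; they are
// unquoted and concatenated, then decoded as a stream of JSON objects.
// STATS carries the ResponseMetrics JSON. REQHEADER/RESPHEADER lines are
// ignored, and a FATAL line aborts parsing with an error.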
// instrumentedRequest makes an HTTP request to the ollama API.
// It returns the raw bytestream from the instrumented request logs.
func (llm *Ollama) instrumentedRequest(ctx context.Context, method, endpoint string, data []byte) ([]byte, error) {
	if endpoint != "" && !strings.HasPrefix(endpoint, "/") {
		return nil, fmt.Errorf("endpoint must be empty or start with '/', got %q", endpoint)
	}
	argvFn := func(hostPort string) []string {
		argv := []string{
			"httpclient",
			fmt.Sprintf("--method=%s", method),
			fmt.Sprintf("--url=%s%s", hostPort, endpoint),
		}
		if data != nil {
			argv = append(argv, fmt.Sprintf("--post_base64=%s", base64.StdEncoding.EncodeToString(data)))
		}
		if ctxDeadline, hasDeadline := ctx.Deadline(); hasDeadline {
			argv = append(argv, fmt.Sprintf("--timeout=%v", time.Until(ctxDeadline)))
		}
		return argv
	}
	rawResponse, err := llm.server.InstrumentedRequest(ctx, argvFn)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", endpoint, err)
	}
	return rawResponse, nil
}

// jsonGet performs a JSON HTTP GET request.
func jsonGet[Out any](ctx context.Context, llm *Ollama, endpoint string) (*apiResponse[Out], error) {
	out, err := llm.instrumentedRequest(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("GET %q failed: %w", endpoint, err)
	}
	return makeAPIResponse[Out](out)
}

// jsonPost performs a JSON HTTP POST request.
func jsonPost[In, Out any](ctx context.Context, llm *Ollama, endpoint string, input In) (*apiResponse[Out], error) {
	query, err := json.Marshal(input)
	if err != nil {
		return nil, fmt.Errorf("could not marshal input %v: %w", input, err)
	}
	out, err := llm.instrumentedRequest(ctx, "POST", endpoint, query)
	if err != nil {
		return nil, fmt.Errorf("POST %q %v failed: %w", endpoint, string(query), err)
	}
	return makeAPIResponse[Out](out)
}

// listModelNames lists the available model names.
func (llm *Ollama) listModelNames(ctx context.Context) ([]string, error) {
	type model struct {
		Name       string `json:"name"`
		ModifiedAt string `json:"modified_at"`
		Size       int    `json:"size"`
	}
	type modelsList struct {
		Models []model `json:"models"`
	}
	modelsResp, err := jsonGet[modelsList](ctx, llm, "/api/tags")
	if err != nil {
		return nil, err
	}
	models, err := modelsResp.Obj()
	if err != nil {
		return nil, fmt.Errorf("malformed model tags response: %w", err)
	}
	modelNames := make([]string, len(models.Models))
	for i, m := range models.Models {
		modelNames[i] = m.Name
	}
	return modelNames, nil
}

// WaitUntilServing waits until ollama is serving, or the context expires.
func (llm *Ollama) WaitUntilServing(ctx context.Context) error {
	for ctx.Err() == nil {
		out, err := llm.instrumentedRequest(ctx, "GET", "/", nil)
		if err != nil {
			continue
		}
		if strings.Contains(string(out), "Ollama is running") {
			return nil
		}
	}
	return fmt.Errorf("ollama did not respond: %w", ctx.Err())
}

// Model encodes a model and options for it.
type Model struct {
	// Name is the name of the ollama model, e.g. "codellama:7b".
	Name string

	// Options maps parameter names to JSON-compatible values.
	Options map[string]any
}

// String returns the model's name.
func (m *Model) String() string {
	return m.Name
}

// modelTemperatureOption is the temperature option that most models have,
// which controls how free they are to deviate from their most-likely
// token chain.
const modelTemperatureOption = "temperature"

// RaiseTemperature increases the "temperature" option of the model,
// if any.
func (m *Model) RaiseTemperature() {
	temp, ok := m.Options[modelTemperatureOption]
	if !ok {
		temp = float64(0.0)
	}
	if m.Options == nil {
		m.Options = map[string]any{}
	}
	m.Options[modelTemperatureOption] = min(1.0, temp.(float64)*2+.025)
}

// Copy returns a copy of the model.
func (m *Model) Copy() *Model {
	modelCopy := *m
	modelCopy.Options = make(map[string]any, len(m.Options))
	for k, v := range m.Options {
		modelCopy.Options[k] = v
	}
	return &modelCopy
}

// ZeroTemperatureModel returns a Model with the given name and an initial
// temperature setting of zero. This setting allows for consistent,
// reproducible responses across runs.
func ZeroTemperatureModel(name string) *Model {
	return &Model{
		Name: name,
		Options: map[string]any{
			modelTemperatureOption: 0.0,
		},
	}
}
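// To illustrate the effect of RaiseTemperature: starting from a
// zero-temperature model, each call applies t -> min(1, 2t+0.025), so the
// temperature follows 0 -> 0.025 -> 0.075 -> 0.175 -> 0.375 -> 0.775 -> 1.0,
// heating up gradually and saturating at 1.0.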
// Prompt is an ollama prompt.
type Prompt struct {
	// Model is the model to query.
	Model *Model

	// If set, keep the model alive in memory for the given duration after this
	// prompt is answered. A zero duration will use the ollama default (a few
	// minutes). Note that model unloading is asynchronous, so the model will
	// not be fully unloaded after only `KeepModelAlive` beyond prompt response.
	KeepModelAlive time.Duration

	// Query is the prompt string.
	// Common leading whitespace will be removed.
	Query string

	// images is a set of attached images.
	// Use AddImage to add an image.
	images [][]byte

	// Context is the conversational context to follow up on, if any.
	// This is returned from `Response`.
	Context ConversationContext
}

// AddImage attaches an image to the prompt.
// Returns itself for chainability.
func (p *Prompt) AddImage(data []byte) *Prompt {
	p.images = append(p.images, data)
	return p
}

// CleanQuery removes common whitespace from query lines, and all
// leading/ending whitespace-only lines.
// It makes it possible to specify the query as an indented string
// without breaking visual continuity in Go code.
// For example (where dots are spaces):
//
//	"""\n
//	..The Quick Brown Fox\n
//	..Jumps Over\n
//	....The Lazy Dog\n
//	."""
//
// becomes:
//
//	"""The Quick Brown Fox\n
//	Jumps Over\n
//	..The Lazy Dog"""
func (p *Prompt) CleanQuery() string {
	lines := strings.Split(p.Query, "\n")

	// Trim lines at the beginning and end that are only whitespace.
	trimmedLines := make([]string, 0, len(lines))
	startedNonWhitespace := false
	var block []string
	for _, line := range lines {
		trimmedLine := strings.TrimSpace(line)
		if !startedNonWhitespace && trimmedLine != "" {
			startedNonWhitespace = true
		}
		if startedNonWhitespace {
			block = append(block, line)
		}
		if trimmedLine != "" {
			trimmedLines = append(trimmedLines, block...)
			block = block[:0]
		}
	}

	// Find longest common whitespace prefix.
	if len(trimmedLines) == 0 {
		return ""
	}
	trimmedFirstLine := strings.TrimSpace(trimmedLines[0])
	common := []rune(trimmedLines[0][:strings.Index(trimmedLines[0], trimmedFirstLine)])
	for ; len(common) > 0; common = common[:len(common)-1] {
		allMatch := true
		for _, line := range trimmedLines[1:] {
			if strings.TrimSpace(line) == "" {
				continue // Ignore whitespace-only or empty lines.
			}
			if !strings.HasPrefix(line, string(common)) {
				allMatch = false
				break
			}
		}
		if allMatch {
			break
		}
	}

	// Remove it.
	if len(common) > 0 {
		for i, line := range trimmedLines {
			trimmedLines[i] = strings.TrimPrefix(line, string(common))
		}
	}

	return strings.Join(trimmedLines, "\n")
}
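// examplePromptWithIndentedQuery is an illustrative sketch: thanks to
// CleanQuery, a test can embed its query as an indented raw string literal.
// The query below is made up; whatever the indentation depth, the model
// receives "What is 2+2?\nAnswer with a single number.".
func examplePromptWithIndentedQuery(model *Model) *Prompt {
	return &Prompt{
		Model: model,
		Query: `
			What is 2+2?
			Answer with a single number.
		`,
	}
}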
// String returns a human-friendly string representing this prompt.
func (p *Prompt) String() string {
	return fmt.Sprintf("[%v] %s", p.Model, p.CleanQuery())
}

// WithHotterModel returns a copy of this prompt with the same model having
// a higher temperature.
func (p *Prompt) WithHotterModel() *Prompt {
	promptCopy := *p
	promptCopy.Model = p.Model.Copy()
	promptCopy.Model.RaiseTemperature()
	return &promptCopy
}

// PromptJSON encodes the JSON data for a query.
type PromptJSON struct {
	Model     string              `json:"model"`
	Prompt    string              `json:"prompt,omitempty"`
	Images    []string            `json:"images"`
	Stream    bool                `json:"stream"`
	Context   ConversationContext `json:"context"`
	Options   map[string]any      `json:"options"`
	KeepAlive string              `json:"keep_alive,omitempty"`
}

// json encodes this prompt to the JSON format expected by Ollama.
func (p *Prompt) json() PromptJSON {
	keepAlive := ""
	if p.KeepModelAlive != 0 {
		keepAlive = p.KeepModelAlive.String()
	}
	images := make([]string, len(p.images))
	for i, image := range p.images {
		images[i] = base64.StdEncoding.EncodeToString(image)
	}
	return PromptJSON{
		Model:     p.Model.Name,
		Prompt:    p.CleanQuery(),
		Images:    images,
		Stream:    true,
		Context:   p.Context,
		Options:   p.Model.Options,
		KeepAlive: keepAlive,
	}
}

// ResponseJSON is the JSON-format response from ollama about a prompt.
// Note that in `streamed` mode, the `Response` field contains a single token.
// To recover the whole response, all `Response` fields must be concatenated
// until the last `ResponseJSON`, identified as such by the `Done` field.
type ResponseJSON struct {
	Model           string              `json:"model"`
	CreatedAt       time.Time           `json:"created_at"`
	Response        string              `json:"response"`
	Done            bool                `json:"done"`
	TotalNanos      int                 `json:"total_duration"`
	LoadNanos       int                 `json:"load_duration"`
	EvalCount       int                 `json:"eval_count"`
	EvalNanos       int                 `json:"eval_duration"`
	PromptEvalCount int                 `json:"prompt_eval_count"`
	PromptEvalNanos int                 `json:"prompt_eval_duration"`
	Context         ConversationContext `json:"context"`
}

// Response represents a response to a query from Ollama.
type Response struct {
	data    []*ResponseJSON
	metrics ResponseMetrics
}

// Done returns whether the response was completely generated.
func (r *Response) Done() bool {
	if len(r.data) == 0 {
		return false
	}
	return r.data[len(r.data)-1].Done
}

// NumTokens returns the number of tokens in the response.
func (r *Response) NumTokens() int {
	return len(r.data)
}

// String returns the response text, if it is done.
func (r *Response) String() string {
	if len(r.data) == 0 {
		return "<EMPTY>"
	}
	var fullResponse strings.Builder
	gotDone := false
	for i, token := range r.data {
		fullResponse.WriteString(token.Response)
		if token.Done {
			if i != len(r.data)-1 {
				fullResponse.WriteString("<CORRUPT>")
			}
			gotDone = true
			break
		}
	}
	if !gotDone {
		return "<NOT DONE>"
	}
	return fullResponse.String()
}

// Text returns the body of the response, if it is done.
func (r *Response) Text() string {
	if !r.Done() {
		return ""
	}
	return r.String()
}

// TimeToFirstToken returns the time it took between the request starting
// and the first token being received by the client.
func (r *Response) TimeToFirstToken() time.Duration {
	if !r.Done() {
		return -1
	}
	return r.metrics.FirstByteRead.Sub(r.metrics.RequestSent)
}

// TimeToLastToken returns the time it took between the request starting
// and the last token being received by the client.
func (r *Response) TimeToLastToken() time.Duration {
	if !r.Done() {
		return -1
	}
	return r.metrics.LastByteRead.Sub(r.metrics.RequestSent)
}

// tokenIntervals returns the time between each token generation.
func (r *Response) tokenIntervals() []time.Duration {
	if !r.Done() || len(r.data) < 2 {
		return nil
	}
	intervals := make([]time.Duration, len(r.data)-1)
	for i := 0; i < len(r.data)-1; i++ {
		intervals[i] = r.data[i+1].CreatedAt.Sub(r.data[i].CreatedAt)
	}
	return intervals
}

// OutputTokensPerSecond computes the average number of output tokens
// generated per second.
func (r *Response) OutputTokensPerSecond() float64 {
	if !r.Done() || r.EvalDuration() == 0 {
		return -1
	}
	return float64(r.data[len(r.data)-1].EvalCount) / float64(r.EvalDuration().Seconds())
}

// TimePerOutputTokenAverage computes the average time to generate an output
// token.
func (r *Response) TimePerOutputTokenAverage() time.Duration {
	if !r.Done() {
		return -1
	}
	intervals := r.tokenIntervals()
	var sum time.Duration
	for _, interval := range intervals {
		sum += interval
	}
	return sum / time.Duration(len(intervals))
}

// TimePerOutputTokenQuantile computes a quantile of the time it takes to
// generate an output token.
func (r *Response) TimePerOutputTokenQuantile(quantile float64) time.Duration {
	if quantile < 0.0 || quantile > 1.0 {
		panic("quantile must be between 0.0 and 1.0 inclusively")
	}
	if !r.Done() || r.EvalDuration() == 0 {
		return -1
	}
	intervals := r.tokenIntervals()
	sort.Slice(intervals, func(i, j int) bool { return intervals[i] < intervals[j] })
	return intervals[int(quantile*float64(len(intervals)-1))]
}

// TokenGenerationStdDev returns the standard deviation of the time between
// token generations.
func (r *Response) TokenGenerationStdDev() time.Duration {
	intervals := r.tokenIntervals()
	if len(intervals) == 0 {
		return -1
	}
	if len(intervals) == 1 {
		return 0
	}

	var sum time.Duration
	for _, interval := range intervals {
		sum += interval
	}
	mean := sum / time.Duration(len(intervals))
	variance := 0.0
	for _, interval := range intervals {
		intervalMinusMean := float64((interval - mean).Nanoseconds())
		variance += intervalMinusMean * intervalMinusMean
	}
	variance = variance / float64(len(intervals)-1)
	return time.Duration(math.Sqrt(variance)) * time.Nanosecond
}

// TotalDuration returns the total response generation time.
func (r *Response) TotalDuration() time.Duration {
	if !r.Done() {
		return time.Duration(0)
	}
	return time.Duration(r.data[len(r.data)-1].TotalNanos) * time.Nanosecond
}

// LoadDuration returns the model load duration as reported by the
// ollama server.
func (r *Response) LoadDuration() time.Duration {
	if !r.Done() {
		return time.Duration(0)
	}
	return time.Duration(r.data[len(r.data)-1].LoadNanos) * time.Nanosecond
}

// EvalDuration returns the response evaluation time.
func (r *Response) EvalDuration() time.Duration {
	if !r.Done() {
		return time.Duration(0)
	}
	return time.Duration(r.data[len(r.data)-1].EvalNanos) * time.Nanosecond
}

// PromptEvalDuration returns the prompt evaluation time.
func (r *Response) PromptEvalDuration() time.Duration {
	if !r.Done() {
		return time.Duration(0)
	}
	return time.Duration(r.data[len(r.data)-1].PromptEvalNanos) * time.Nanosecond
}
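// exampleReportMetrics is an illustrative sketch of how a benchmark might log
// a Response's client-side and server-side metrics. All accessors used here
// are defined above; the logger is whatever testutil.Logger the benchmark
// already holds.
func exampleReportMetrics(logger testutil.Logger, resp *Response) {
	if !resp.Done() {
		logger.Logf("response is incomplete; metrics are not meaningful")
		return
	}
	logger.Logf("time to first token: %v", resp.TimeToFirstToken())
	logger.Logf("time to last token: %v", resp.TimeToLastToken())
	logger.Logf("output tokens per second: %.2f", resp.OutputTokensPerSecond())
	logger.Logf("median time per output token: %v", resp.TimePerOutputTokenQuantile(0.5))
	logger.Logf("token interval standard deviation: %v", resp.TokenGenerationStdDev())
	logger.Logf("server-side: total=%v load=%v prompt_eval=%v eval=%v",
		resp.TotalDuration(), resp.LoadDuration(), resp.PromptEvalDuration(), resp.EvalDuration())
}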
// ConversationContext represents a conversational context.
// It is returned by a response and may be passed to a follow-up prompt.
type ConversationContext []int

// withServerLogsErr adds server logs to `err` if possible.
func (llm *Ollama) withServerLogsErr(ctx context.Context, err error) error {
	if err == nil {
		return nil
	}
	if ctx.Err() != nil {
		return fmt.Errorf("%w (+ context err: %v)", err, ctx.Err())
	}
	serverLogs, logsErr := llm.server.Logs(ctx)
	if logsErr != nil {
		return fmt.Errorf("%w (could not get server logs: %v)", err, logsErr)
	}
	if serverLogs != "" {
		return fmt.Errorf("%w; ollama server logs:\n%v\n(end of ollama server logs)", err, serverLogs)
	}
	return fmt.Errorf("%w (server logs are empty)", err)
}

// getReplacementModel picks an available model other than `model`.
// It tries to find one that is marked cheap if possible.
func (llm *Ollama) getReplacementModel(model *Model) (*Model, error) {
	for _, cheapModel := range llm.cheapModels {
		if cheapModel.Name != model.Name {
			return cheapModel, nil
		}
	}
	for _, otherModelName := range llm.ModelNames {
		if otherModelName != model.Name {
			return ZeroTemperatureModel(otherModelName), nil
		}
	}
	return nil, fmt.Errorf("cannot find a replacement model to load instead of %q (available: %v; cheap: %v)", model.Name, llm.ModelNames, llm.cheapModels)
}

// ModelLoadStats holds metrics about the model loading process.
type ModelLoadStats struct {
	// ClientReportedDuration is the duration to load the model as perceived
	// by the client, measured by HTTP client metrics.
	ClientReportedDuration time.Duration
}

// WarmModel pre-warms a model in memory and keeps it warm for `keepWarmFor`.
// If `unloadFirst` is true, another model will be loaded before loading the
// requested model. This ensures that the model was loaded from a cold state.
func (llm *Ollama) WarmModel(ctx context.Context, model *Model, keepWarmFor time.Duration, unloadFirst bool) (*ModelLoadStats, error) {
	if keepWarmFor <= 0 {
		return nil, fmt.Errorf("keepWarmFor must be strictly positive, got %v", keepWarmFor)
	}
	if unloadFirst {
		replacementModel, err := llm.getReplacementModel(model)
		if err != nil {
			return nil, fmt.Errorf("cannot find a replacement model to load instead of %q to forcefully unload it: %w", model.Name, err)
		}
		unloadCtx, unloadCancel := context.WithTimeout(ctx, 3*time.Minute)
		_, err = llm.Prompt(unloadCtx, &Prompt{Model: replacementModel, KeepModelAlive: 1 * time.Millisecond})
		unloadCancel()
		if err != nil {
			return nil, llm.withServerLogsErr(ctx, fmt.Errorf("unload prompt for replacement model %s failed: %w", replacementModel.Name, err))
		}
		select { // Wait for the model to get unloaded. Unfortunately there isn't a great way to do this other than to sleep.
		case <-time.After(20 * time.Second):
		case <-ctx.Done():
		}
	}
	resp, err := llm.Prompt(ctx, &Prompt{Model: model, KeepModelAlive: keepWarmFor})
	if err != nil {
		return nil, llm.withServerLogsErr(ctx, fmt.Errorf("warmup prompt for model %s failed: %w", model.Name, err))
	}
	return &ModelLoadStats{
		ClientReportedDuration: resp.metrics.LastByteRead.Sub(resp.metrics.RequestSent),
	}, nil
}
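// exampleColdLoadMeasurement is an illustrative sketch of measuring a cold
// model load. The model names are placeholders; SetCheapModels tells the
// client which small models it may swap in to force the model under test out
// of memory.
func exampleColdLoadMeasurement(ctx context.Context, llm *Ollama, logger testutil.Logger) error {
	llm.SetCheapModels([]*Model{
		ZeroTemperatureModel("cheap-model-a"),
		ZeroTemperatureModel("cheap-model-b"),
	})
	// unloadFirst=true loads a cheap model first so that the measured load
	// starts from a cold state.
	stats, err := llm.WarmModel(ctx, ZeroTemperatureModel("model-under-test"), time.Minute, true /* unloadFirst */)
	if err != nil {
		return fmt.Errorf("cold load failed: %v", err)
	}
	logger.Logf("Cold model load took %v as seen by the client", stats.ClientReportedDuration)
	return nil
}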
// Prompt returns the result of prompting the given `model` with `prompt`.
func (llm *Ollama) Prompt(ctx context.Context, prompt *Prompt) (*Response, error) {
	resp, err := jsonPost[PromptJSON, ResponseJSON](ctx, llm, "/api/generate", prompt.json())
	if err != nil {
		return nil, llm.withServerLogsErr(ctx, fmt.Errorf("prompt (%s %q) request failed: %w", prompt.Model.Name, prompt.CleanQuery(), err))
	}
	return &Response{data: resp.Objects, metrics: resp.Metrics}, nil
}

// PromptUntil repeatedly issues a prompt until `iterate` returns a nil error.
// `iterate` may optionally return an updated `Prompt` which will be used to
// follow up. This is useful to work around the flakiness of LLMs in tests.
func (llm *Ollama) PromptUntil(ctx context.Context, prompt *Prompt, iterate func(*Prompt, *Response) (*Prompt, error)) (*Response, error) {
	var lastResponse *Response
	var lastError error
	attempts := 0
	for ctx.Err() == nil {
		response, err := llm.Prompt(ctx, prompt)
		if err != nil {
			return nil, fmt.Errorf("prompt request failed: %w", err)
		}
		attempts++
		newPrompt, err := iterate(prompt, response)
		if err == nil {
			return response, nil
		}
		if newPrompt != nil {
			prompt = newPrompt
		}
		lastResponse = response
		lastError = err
	}
	return nil, fmt.Errorf("response %q (attempt #%d with prompt %v) did not match predicate: %v", lastResponse, attempts, prompt, lastError)
}
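// examplePromptUntilHello is an illustrative sketch of using PromptUntil to
// tolerate LLM flakiness: it retries curtQuery until the model actually
// answers "Hello", raising the temperature on each retry so the model does
// not keep repeating the same wrong answer. The model name is a placeholder.
func examplePromptUntilHello(ctx context.Context, llm *Ollama) (*Response, error) {
	prompt := &Prompt{
		Model: ZeroTemperatureModel("some-model"),
		Query: curtQuery,
	}
	return llm.PromptUntil(ctx, prompt, func(p *Prompt, resp *Response) (*Prompt, error) {
		if strings.Contains(strings.ToLower(resp.Text()), "hello") {
			return nil, nil // Accept this response.
		}
		// Retry with a hotter model in the hope of a different answer.
		return p.WithHotterModel(), fmt.Errorf("unexpected response: %q", resp.Text())
	})
}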