gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/ollama/ollama.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ollama provides an Ollama API client.
    16  package ollama
    17  
    18  import (
    19  	"bytes"
    20  	"context"
    21  	"encoding/base64"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"math"
    27  	"sort"
    28  	"strconv"
    29  	"strings"
    30  	"time"
    31  
    32  	"gvisor.dev/gvisor/pkg/test/dockerutil"
    33  	"gvisor.dev/gvisor/pkg/test/testutil"
    34  )
    35  
    36  const (
    37  	// Port is the port used by the ollama server.
    38  	Port = 11434
    39  
    40  	// curtQuery is a query that should result in a very curt response.
    41  	curtQuery = `Please reply with the single word: "Hello". Do not reply with any other word.`
    42  )
    43  
    44  // Ollama is an ollama client.
    45  type Ollama struct {
     46  	// server is used to perform requests against the ollama server.
    47  	server Server
    48  
    49  	// logger is used to log.
    50  	logger testutil.Logger
    51  
    52  	// ModelNames is the list of available model names.
    53  	ModelNames []string
    54  
    55  	// cheapModels is a list of models that are known to be cheap.
    56  	// A caller may set this to make forcefully unloading a model quicker.
    57  	cheapModels []*Model
    58  
    59  	// HasGPU is set depending on whether the LLM has GPU access.
    60  	// ollama supports running both on CPU and GPU, and detects this
    61  	// by spawning nvidia-smi.
    62  	HasGPU bool
    63  }
    64  
    65  // Server performs requests against an ollama server.
    66  type Server interface {
    67  	// InstrumentedRequest performs an instrumented HTTP request against the
    68  	// ollama server, using the `gpu/ollama_client` ollama image.
    69  	// `argvFn` takes in a `protocol://host:port` string and returns a
    70  	// command-line to use for making an instrumented HTTP request against the
    71  	// ollama server.
    72  	// InstrumentedRequest should return the logs from the request container.
    73  	InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error)
    74  
    75  	// Logs retrieves logs from the server.
    76  	Logs(ctx context.Context) (string, error)
    77  }
    78  
    79  // New starts a new Ollama server in the given container,
    80  // then waits for it to serve and returns the client.
    81  func New(ctx context.Context, server Server, logger testutil.Logger) (*Ollama, error) {
    82  	started := time.Now()
    83  	llm := &Ollama{
    84  		logger: logger,
    85  		server: server,
    86  	}
    87  
    88  	// Wait until serving.
    89  	if err := llm.WaitUntilServing(ctx); err != nil {
    90  		return nil, fmt.Errorf("ollama did not come up for serving: %w", err)
    91  	}
    92  	logger.Logf("Ollama serving API requests after %v", time.Since(started))
    93  
    94  	// Get list of model names.
    95  	modelNames, err := llm.listModelNames(ctx)
    96  	if err != nil {
    97  		return nil, fmt.Errorf("could not list model names: %w", err)
    98  	}
    99  	if len(modelNames) == 0 {
   100  		return nil, errors.New("no models available")
   101  	}
   102  	llm.ModelNames = modelNames
   103  	logger.Logf("Available ollama model names: %v (loaded %v since container start)", modelNames, time.Since(started))
   104  
   105  	// Load the first model.
   106  	// This is necessary to force ollama to load a model, without which
   107  	// we cannot detect if it is using the GPU or not.
    108  	// Loading the first model may take a while, as ollama has to load it
    109  	// into memory (and onto the GPU, if available) from a cold start.
   110  	_, err = llm.WarmModel(ctx, &Model{Name: llm.ModelNames[0]}, 1*time.Millisecond, false)
   111  	if err != nil {
   112  		return nil, fmt.Errorf("could not load first model %q: %w", llm.ModelNames[0], err)
   113  	}
   114  	logger.Logf("Loaded first ollama model %q (%v since container start)", llm.ModelNames[0], time.Since(started))
   115  
   116  	// Now go over the logs and check if the GPU was used.
   117  	logs, err := llm.server.Logs(ctx)
   118  	if err != nil {
   119  		return nil, fmt.Errorf("could not get logs: %w", err)
   120  	}
   121  	switch {
   122  	case strings.Contains(logs, "no GPU detected"):
   123  		llm.HasGPU = false
   124  	case strings.Contains(logs, "Nvidia GPU detected"):
   125  		llm.HasGPU = true
   126  	default:
   127  		return nil, fmt.Errorf("cannot determine whether ollama is using GPU from logs:\n%s", logs)
   128  	}
   129  	logger.Logf("Ollama successfully initialized in a total of %v", time.Since(started))
   130  	return llm, nil
   131  }
   132  
   133  // SetCheapModels can be used to inform this Ollama client as to the list of
   134  // models it can use that are known to be cheap.
    135  // This is useful when forcefully unloading a model by swapping it with
    136  // another one, as it ensures that the replacement model is small.
    137  // At least two models should be specified so that a replacement is always available.
   138  func (llm *Ollama) SetCheapModels(cheapModels []*Model) {
   139  	llm.cheapModels = cheapModels
   140  }
   141  
   142  // dockerServer implements `Server`. It interfaces with an ollama server
   143  // running in a local Docker container.
   144  type dockerServer struct {
   145  	container *dockerutil.Container
   146  	logger    testutil.Logger
   147  }
   148  
   149  // NewDocker returns a new Ollama client talking to an Ollama server that runs
   150  // in a local Docker container.
   151  func NewDocker(ctx context.Context, cont *dockerutil.Container, logger testutil.Logger) (*Ollama, error) {
   152  	opts := dockerutil.GPURunOpts()
   153  	opts.Image = "gpu/ollama"
   154  	started := time.Now()
   155  	if err := cont.Spawn(ctx, opts); err != nil {
   156  		return nil, fmt.Errorf("could not start ollama: %v", err)
   157  	}
   158  	logger.Logf("Ollama container started after %v", time.Since(started))
   159  	ds := &dockerServer{
   160  		container: cont,
   161  		logger:    logger,
   162  	}
   163  	return New(ctx, ds, logger)
   164  }
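
         // exampleNewDockerUsage is an illustrative sketch (not part of the original
         // file) of how a test might bring up the client and register cheap models;
         // the model names used here are assumptions, not values from this package.
         func exampleNewDockerUsage(ctx context.Context, cont *dockerutil.Container, logger testutil.Logger) (*Ollama, error) {
         	llm, err := NewDocker(ctx, cont, logger)
         	if err != nil {
         		return nil, err
         	}
         	// Registering known-cheap models makes forced unloads in WarmModel quicker.
         	llm.SetCheapModels([]*Model{
         		ZeroTemperatureModel("gemma:2b"),
         		ZeroTemperatureModel("tinyllama"),
         	})
         	return llm, nil
         }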
   165  
   166  // InstrumentedRequest implements `Server.InstrumentedRequest`.
   167  func (ds *dockerServer) InstrumentedRequest(ctx context.Context, argvFn func(hostPort string) []string) ([]byte, error) {
   168  	const ollamaHost = "llm"
   169  	cmd := argvFn(fmt.Sprintf("http://%s:%d", ollamaHost, Port))
   170  	out, err := dockerutil.MakeContainer(ctx, ds.logger).Run(ctx, dockerutil.RunOpts{
   171  		Image: "gpu/ollama/client",
   172  		Links: []string{ds.container.MakeLink(ollamaHost)},
   173  	}, cmd...)
   174  	if err != nil {
   175  		if out != "" {
   176  			return []byte(out), fmt.Errorf("command %q failed (%w): %v", strings.Join(cmd, " "), err, out)
   177  		}
   178  		return nil, fmt.Errorf("could not run command %q: %w", strings.Join(cmd, " "), err)
   179  	}
   180  	return []byte(out), nil
   181  }
   182  
   183  // Logs implements `Server.Logs`.
   184  func (ds *dockerServer) Logs(ctx context.Context) (string, error) {
   185  	return ds.container.Logs(ctx)
   186  }
   187  
   188  // ResponseMetrics are HTTP request metrics from an ollama API query.
    189  // This is the same JSON struct as the one defined in
   190  // `images/gpu/ollama/client/client.go`.
   191  type ResponseMetrics struct {
   192  	// ProgramStarted is the time when the program started.
   193  	ProgramStarted time.Time `json:"program_started"`
   194  	// RequestSent is the time when the HTTP request was sent.
   195  	RequestSent time.Time `json:"request_sent"`
   196  	// ResponseReceived is the time when the HTTP response headers were received.
   197  	ResponseReceived time.Time `json:"response_received"`
   198  	// FirstByteRead is the time when the first HTTP response body byte was read.
   199  	FirstByteRead time.Time `json:"first_byte_read"`
   200  	// LastByteRead is the time when the last HTTP response body byte was read.
   201  	LastByteRead time.Time `json:"last_byte_read"`
   202  }
   203  
   204  // apiResponse represents a JSON response from the ollama API.
   205  type apiResponse[T any] struct {
   206  	// Objects is the list of JSON objects in the response.
   207  	Objects []*T
   208  	// Metrics contains HTTP response metrics.
   209  	Metrics ResponseMetrics
   210  }
   211  
    212  // Obj returns the sole object in the response, or an error if the response
    213  // does not contain exactly one object.
   214  func (ar *apiResponse[T]) Obj() (*T, error) {
   215  	if len(ar.Objects) == 0 {
   216  		return nil, fmt.Errorf("no objects in response")
   217  	}
   218  	if len(ar.Objects) > 1 {
   219  		return nil, fmt.Errorf("multiple objects in response")
   220  	}
   221  	return ar.Objects[0], nil
   222  }
   223  
   224  // makeAPIResponse decodes a raw response from an instrumented HTTP request
   225  // into an `apiResponse` with deserialized JSON objects.
   226  func makeAPIResponse[T any](rawResponse []byte) (*apiResponse[T], error) {
   227  	var respBytes bytes.Buffer
   228  	var resp apiResponse[T]
   229  	for _, line := range strings.Split(string(rawResponse), "\n") {
   230  		line = strings.TrimSpace(line)
   231  		if line == "" {
   232  			continue
   233  		}
   234  		colonIndex := strings.Index(line, ":")
   235  		if colonIndex == -1 {
   236  			return nil, fmt.Errorf("malformed line: %q", line)
   237  		}
   238  		data := strings.TrimSpace(line[colonIndex+1:])
   239  		switch line[:colonIndex] {
   240  		case "FATAL":
   241  			return nil, fmt.Errorf("request failed: %s", data)
   242  		case "REQHEADER", "RESPHEADER":
   243  			// Do nothing with these.
   244  		case "BODY":
   245  			unquoted, err := strconv.Unquote(data)
   246  			if err != nil {
   247  				return nil, fmt.Errorf("malformed body line: %q", data)
   248  			}
   249  			respBytes.WriteString(unquoted)
   250  		case "STATS":
   251  			if err := json.Unmarshal([]byte(data), &resp.Metrics); err != nil {
   252  				return nil, fmt.Errorf("malformed stats line: %q", data)
   253  			}
   254  		default:
   255  			return nil, fmt.Errorf("malformed line: %q", line)
   256  		}
   257  	}
   258  	decoder := json.NewDecoder(&respBytes)
   259  	for {
   260  		var obj T
   261  		err := decoder.Decode(&obj)
   262  		if err == io.EOF {
   263  			break
   264  		}
   265  		if err != nil {
   266  			return nil, fmt.Errorf("malformed JSON response: %w", err)
   267  		}
   268  		resp.Objects = append(resp.Objects, &obj)
   269  	}
   270  	if len(resp.Objects) == 0 {
   271  		return nil, fmt.Errorf("response is empty")
   272  	}
   273  	leftoverBytes, err := io.ReadAll(decoder.Buffered())
   274  	if err != nil && err != io.EOF {
   275  		return nil, fmt.Errorf("could not read leftover bytes: %w", err)
   276  	}
   277  	if leftover := strings.TrimSpace(string(leftoverBytes)); leftover != "" {
   278  		return nil, fmt.Errorf("unprocessed bytes in response: %q", leftover)
   279  	}
   280  	return &resp, nil
   281  }
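
         // exampleRawResponse is an illustrative sketch (not part of the original
         // file) of the line-oriented log format that makeAPIResponse parses; the
         // sample payload below is made up.
         func exampleRawResponse() (*apiResponse[ResponseJSON], error) {
         	raw := strings.Join([]string{
         		`REQHEADER: Content-Type: application/json`,
         		`BODY: "{\"model\":\"m\",\"response\":\"Hello\",\"done\":true}"`,
         		`STATS: {"request_sent":"2024-01-01T00:00:01Z","last_byte_read":"2024-01-01T00:00:03Z"}`,
         	}, "\n")
         	return makeAPIResponse[ResponseJSON]([]byte(raw))
         }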
   282  
   283  // instrumentedRequest makes an HTTP request to the ollama API.
   284  // It returns the raw bytestream from the instrumented request logs.
   285  func (llm *Ollama) instrumentedRequest(ctx context.Context, method, endpoint string, data []byte) ([]byte, error) {
   286  	if endpoint != "" && !strings.HasPrefix(endpoint, "/") {
   287  		return nil, fmt.Errorf("endpoint must be empty or start with '/', got %q", endpoint)
   288  	}
   289  	argvFn := func(hostPort string) []string {
   290  		argv := []string{
   291  			"httpclient",
   292  			fmt.Sprintf("--method=%s", method),
   293  			fmt.Sprintf("--url=%s%s", hostPort, endpoint),
   294  		}
   295  		if data != nil {
   296  			argv = append(argv, fmt.Sprintf("--post_base64=%s", base64.StdEncoding.EncodeToString(data)))
   297  		}
   298  		if ctxDeadline, hasDeadline := ctx.Deadline(); hasDeadline {
   299  			argv = append(argv, fmt.Sprintf("--timeout=%v", time.Until(ctxDeadline)))
   300  		}
   301  		return argv
   302  	}
   303  	rawResponse, err := llm.server.InstrumentedRequest(ctx, argvFn)
   304  	if err != nil {
   305  		return nil, fmt.Errorf("%s: %w", endpoint, err)
   306  	}
   307  	return rawResponse, nil
   308  }
   309  
   310  // jsonGet performs a JSON HTTP GET request.
   311  func jsonGet[Out any](ctx context.Context, llm *Ollama, endpoint string) (*apiResponse[Out], error) {
   312  	out, err := llm.instrumentedRequest(ctx, "GET", endpoint, nil)
   313  	if err != nil {
   314  		return nil, fmt.Errorf("GET %q failed: %w", endpoint, err)
   315  	}
   316  	return makeAPIResponse[Out](out)
   317  }
   318  
   319  // jsonPost performs a JSON HTTP POST request.
   320  func jsonPost[In, Out any](ctx context.Context, llm *Ollama, endpoint string, input In) (*apiResponse[Out], error) {
   321  	query, err := json.Marshal(input)
   322  	if err != nil {
   323  		return nil, fmt.Errorf("could not marshal input %v: %w", input, err)
   324  	}
   325  	out, err := llm.instrumentedRequest(ctx, "POST", endpoint, query)
   326  	if err != nil {
   327  		return nil, fmt.Errorf("POST %q %v failed: %w", endpoint, string(query), err)
   328  	}
   329  	return makeAPIResponse[Out](out)
   330  }
   331  
   332  // listModelNames lists the available model names.
   333  func (llm *Ollama) listModelNames(ctx context.Context) ([]string, error) {
   334  	type model struct {
   335  		Name       string `json:"name"`
   336  		ModifiedAt string `json:"modified_at"`
   337  		Size       int    `json:"size"`
   338  	}
   339  	type modelsList struct {
   340  		Models []model `json:"models"`
   341  	}
   342  	modelsResp, err := jsonGet[modelsList](ctx, llm, "/api/tags")
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  	models, err := modelsResp.Obj()
   347  	if err != nil {
   348  		return nil, fmt.Errorf("malformed model tags response: %w", err)
   349  	}
   350  	modelNames := make([]string, len(models.Models))
   351  	for i, m := range models.Models {
   352  		modelNames[i] = m.Name
   353  	}
   354  	return modelNames, nil
   355  }
   356  
   357  // WaitUntilServing waits until ollama is serving, or the context expires.
   358  func (llm *Ollama) WaitUntilServing(ctx context.Context) error {
   359  	for ctx.Err() == nil {
   360  		out, err := llm.instrumentedRequest(ctx, "GET", "/", nil)
   361  		if err != nil {
   362  			continue
   363  		}
   364  		if strings.Contains(string(out), "Ollama is running") {
   365  			return nil
   366  		}
   367  	}
   368  	return fmt.Errorf("ollama did not respond: %w", ctx.Err())
   369  }
   370  
   371  // Model encodes a model and options for it.
   372  type Model struct {
   373  	// Name is the name of the ollama model, e.g. "codellama:7b".
   374  	Name string
   375  
   376  	// Options maps parameter names to JSON-compatible values.
   377  	Options map[string]any
   378  }
   379  
   380  // String returns the model's name.
   381  func (m *Model) String() string {
   382  	return m.Name
   383  }
   384  
    385  // modelTemperatureOption is the temperature option that most models have,
    386  // which controls how much they may deviate from their most-likely
    387  // token chain.
   388  const modelTemperatureOption = "temperature"
   389  
   390  // RaiseTemperature increases the "temperature" option of the model,
   391  // if any.
   392  func (m *Model) RaiseTemperature() {
   393  	temp, ok := m.Options[modelTemperatureOption]
   394  	if !ok {
   395  		temp = float64(0.0)
   396  	}
   397  	if m.Options == nil {
   398  		m.Options = map[string]any{}
   399  	}
   400  	m.Options[modelTemperatureOption] = min(1.0, temp.(float64)*2+.025)
   401  }
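
         // exampleTemperatureRamp is an illustrative sketch (not part of the original
         // file) of how repeated RaiseTemperature calls move the temperature option
         // from 0 towards its cap of 1.0 (0.025, 0.075, 0.175, 0.375, ...).
         func exampleTemperatureRamp() []float64 {
         	m := ZeroTemperatureModel("some-model")
         	temps := make([]float64, 0, 4)
         	for i := 0; i < 4; i++ {
         		m.RaiseTemperature()
         		temps = append(temps, m.Options[modelTemperatureOption].(float64))
         	}
         	return temps
         }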
   402  
   403  // Copy returns a copy of the model.
   404  func (m *Model) Copy() *Model {
   405  	modelCopy := *m
   406  	modelCopy.Options = make(map[string]any, len(m.Options))
   407  	for k, v := range m.Options {
   408  		modelCopy.Options[k] = v
   409  	}
   410  	return &modelCopy
   411  }
   412  
    413  // ZeroTemperatureModel returns a Model with the given name and an initial
    414  // temperature setting of zero. This makes responses more consistent across runs.
   415  func ZeroTemperatureModel(name string) *Model {
   416  	return &Model{
   417  		Name: name,
   418  		Options: map[string]any{
   419  			modelTemperatureOption: 0.0,
   420  		},
   421  	}
   422  }
   423  
   424  // Prompt is an ollama prompt.
   425  type Prompt struct {
   426  	// Model is the model to query.
   427  	Model *Model
   428  
   429  	// If set, keep the model alive in memory for the given duration after this
   430  	// prompt is answered. A zero duration will use the ollama default (a few
    431  	// minutes). Note that model unloading is asynchronous, so the model may
    432  	// not be fully unloaded until some time after `KeepModelAlive` has elapsed.
   433  	KeepModelAlive time.Duration
   434  
   435  	// Query is the prompt string.
   436  	// Common leading whitespace will be removed.
   437  	Query string
   438  
   439  	// images is a set of attached images.
   440  	// Use AddImage to add an image.
   441  	images [][]byte
   442  
   443  	// Context is the conversational context to follow up on, if any.
   444  	// This is returned from `Response`.
   445  	Context ConversationContext
   446  }
   447  
   448  // AddImage attaches an image to the prompt.
   449  // Returns itself for chainability.
   450  func (p *Prompt) AddImage(data []byte) *Prompt {
   451  	p.images = append(p.images, data)
   452  	return p
   453  }
   454  
    455  // CleanQuery removes the common leading whitespace from query lines, and all
    456  // leading/trailing whitespace-only lines.
    457  // It makes it possible to specify query strings as indented string literals
    458  // without breaking visual continuity in Go code.
   459  // For example (where dots are spaces):
   460  //
   461  // """\n
   462  // ..The Quick Brown Fox\n
   463  // ..Jumps Over\n
   464  // ....The Lazy Dog\n
   465  // ."""
   466  //
   467  // becomes:
   468  //
    469  // """The Quick Brown Fox\n
   470  // Jumps Over\n
   471  // ..The Lazy Dog"""
   472  func (p *Prompt) CleanQuery() string {
   473  	lines := strings.Split(p.Query, "\n")
   474  
   475  	// Trim lines at the beginning and end that are only whitespace.
   476  	trimmedLines := make([]string, 0, len(lines))
   477  	startedNonWhitespace := false
   478  	var block []string
   479  	for _, line := range lines {
   480  		trimmedLine := strings.TrimSpace(line)
   481  		if !startedNonWhitespace && trimmedLine != "" {
   482  			startedNonWhitespace = true
   483  		}
   484  		if startedNonWhitespace {
   485  			block = append(block, line)
   486  		}
   487  		if trimmedLine != "" {
   488  			trimmedLines = append(trimmedLines, block...)
   489  			block = block[:0]
   490  		}
   491  	}
   492  
   493  	// Find longest common whitespace prefix.
   494  	if len(trimmedLines) == 0 {
   495  		return ""
   496  	}
   497  	trimmedFirstLine := strings.TrimSpace(trimmedLines[0])
   498  	common := []rune(trimmedLines[0][:strings.Index(trimmedLines[0], trimmedFirstLine)])
   499  	for ; len(common) > 0; common = common[:len(common)-1] {
   500  		allMatch := true
   501  		for _, line := range trimmedLines[1:] {
   502  			if strings.TrimSpace(line) == "" {
   503  				continue // Ignore whitespace-only or empty lines.
   504  			}
   505  			if !strings.HasPrefix(line, string(common)) {
   506  				allMatch = false
   507  				break
   508  			}
   509  		}
   510  		if allMatch {
   511  			break
   512  		}
   513  	}
   514  
   515  	// Remove it.
   516  	if len(common) > 0 {
   517  		for i, line := range trimmedLines {
   518  			trimmedLines[i] = strings.TrimPrefix(line, string(common))
   519  		}
   520  	}
   521  
   522  	return strings.Join(trimmedLines, "\n")
   523  }
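
         // exampleCleanQuery is an illustrative sketch (not part of the original file)
         // of the effect of CleanQuery on an indented, multi-line query string.
         func exampleCleanQuery() string {
         	p := &Prompt{Query: "\n\t\tThe Quick Brown Fox\n\t\tJumps Over\n\t\t\tThe Lazy Dog\n\t"}
         	// Returns "The Quick Brown Fox\nJumps Over\n\tThe Lazy Dog".
         	return p.CleanQuery()
         }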
   524  
   525  // String returns a human-friendly string representing this prompt.
   526  func (p *Prompt) String() string {
   527  	return fmt.Sprintf("[%v] %s", p.Model, p.CleanQuery())
   528  }
   529  
   530  // WithHotterModel returns a copy of this prompt with the same model having
   531  // a higher temperature.
   532  func (p *Prompt) WithHotterModel() *Prompt {
   533  	promptCopy := *p
   534  	promptCopy.Model = p.Model.Copy()
   535  	promptCopy.Model.RaiseTemperature()
   536  	return &promptCopy
   537  }
   538  
   539  // PromptJSON encodes the JSON data for a query.
   540  type PromptJSON struct {
   541  	Model     string              `json:"model"`
   542  	Prompt    string              `json:"prompt,omitempty"`
   543  	Images    []string            `json:"images"`
   544  	Stream    bool                `json:"stream"`
   545  	Context   ConversationContext `json:"context"`
   546  	Options   map[string]any      `json:"options"`
   547  	KeepAlive string              `json:"keep_alive,omitempty"`
   548  }
   549  
   550  // json encodes this prompt to the JSON format expected by Ollama.
   551  func (p *Prompt) json() PromptJSON {
   552  	keepAlive := ""
   553  	if p.KeepModelAlive != 0 {
   554  		keepAlive = p.KeepModelAlive.String()
   555  	}
   556  	images := make([]string, len(p.images))
   557  	for i, image := range p.images {
   558  		images[i] = base64.StdEncoding.EncodeToString(image)
   559  	}
   560  	return PromptJSON{
   561  		Model:     p.Model.Name,
   562  		Prompt:    p.CleanQuery(),
   563  		Images:    images,
   564  		Stream:    true,
   565  		Context:   p.Context,
   566  		Options:   p.Model.Options,
   567  		KeepAlive: keepAlive,
   568  	}
   569  }
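
         // examplePromptWire is an illustrative sketch (not part of the original file)
         // of the JSON payload that Prompt.json produces for /api/generate; the model
         // name is an assumption.
         func examplePromptWire() ([]byte, error) {
         	p := &Prompt{
         		Model:          ZeroTemperatureModel("llama2:7b"),
         		Query:          "Why is the sky blue?",
         		KeepModelAlive: 30 * time.Second,
         	}
         	// Marshals to roughly:
         	// {"model":"llama2:7b","prompt":"Why is the sky blue?","images":[],
         	//  "stream":true,"context":null,"options":{"temperature":0},"keep_alive":"30s"}
         	return json.Marshal(p.json())
         }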
   570  
   571  // ResponseJSON is the JSON-format response from ollama about a prompt.
   572  // Note that in `streamed` mode, the `Response` field contains a single token.
   573  // To recover the whole response, all `Response` fields must be concatenated
   574  // until the last `ResponseJSON`, identified as such by the `Done` field.
   575  type ResponseJSON struct {
   576  	Model           string              `json:"model"`
   577  	CreatedAt       time.Time           `json:"created_at"`
   578  	Response        string              `json:"response"`
   579  	Done            bool                `json:"done"`
   580  	TotalNanos      int                 `json:"total_duration"`
   581  	LoadNanos       int                 `json:"load_duration"`
   582  	EvalCount       int                 `json:"eval_count"`
   583  	EvalNanos       int                 `json:"eval_duration"`
   584  	PromptEvalCount int                 `json:"prompt_eval_count"`
   585  	PromptEvalNanos int                 `json:"prompt_eval_duration"`
   586  	Context         ConversationContext `json:"context"`
   587  }
   588  
   589  // Response represents a response to a query from Ollama.
   590  type Response struct {
   591  	data    []*ResponseJSON
   592  	metrics ResponseMetrics
   593  }
   594  
   595  // Done returns whether the response was completely generated.
   596  func (r *Response) Done() bool {
   597  	if len(r.data) == 0 {
   598  		return false
   599  	}
   600  	return r.data[len(r.data)-1].Done
   601  }
   602  
   603  // NumTokens returns the number of tokens in the response.
   604  func (r *Response) NumTokens() int {
   605  	return len(r.data)
   606  }
   607  
   608  // String returns the response text, if it is done.
   609  func (r *Response) String() string {
   610  	if len(r.data) == 0 {
   611  		return "<EMPTY>"
   612  	}
   613  	var fullResponse strings.Builder
   614  	gotDone := false
   615  	for i, token := range r.data {
   616  		fullResponse.WriteString(token.Response)
   617  		if token.Done {
   618  			if i != len(r.data)-1 {
   619  				fullResponse.WriteString("<CORRUPT>")
   620  			}
   621  			gotDone = true
   622  			break
   623  		}
   624  	}
   625  	if !gotDone {
   626  		return "<NOT DONE>"
   627  	}
   628  	return fullResponse.String()
   629  }
   630  
   631  // Text returns the body of the response, if it is done.
   632  func (r *Response) Text() string {
   633  	if !r.Done() {
   634  		return ""
   635  	}
   636  	return r.String()
   637  }
   638  
   639  // TimeToFirstToken returns the time it took between the request starting
   640  // and the first token being received by the client.
   641  func (r *Response) TimeToFirstToken() time.Duration {
   642  	if !r.Done() {
   643  		return -1
   644  	}
   645  	return r.metrics.FirstByteRead.Sub(r.metrics.RequestSent)
   646  }
   647  
   648  // TimeToLastToken returns the time it took between the request starting
   649  // and the last token being received by the client.
   650  func (r *Response) TimeToLastToken() time.Duration {
   651  	if !r.Done() {
   652  		return -1
   653  	}
   654  	return r.metrics.LastByteRead.Sub(r.metrics.RequestSent)
   655  }
   656  
   657  // tokenIntervals returns the time between each token generation.
   658  func (r *Response) tokenIntervals() []time.Duration {
   659  	if !r.Done() || len(r.data) < 2 {
   660  		return nil
   661  	}
   662  	intervals := make([]time.Duration, len(r.data)-1)
   663  	for i := 0; i < len(r.data)-1; i++ {
   664  		intervals[i] = r.data[i+1].CreatedAt.Sub(r.data[i].CreatedAt)
   665  	}
   666  	return intervals
   667  }
   668  
   669  // OutputTokensPerSecond computes the average number of output tokens
   670  // generated per second.
   671  func (r *Response) OutputTokensPerSecond() float64 {
   672  	if !r.Done() || r.EvalDuration() == 0 {
   673  		return -1
   674  	}
   675  	return float64(r.data[len(r.data)-1].EvalCount) / float64(r.EvalDuration().Seconds())
   676  }
   677  
   678  // TimePerOutputTokenAverage computes the average time to generate an output
   679  // token.
   680  func (r *Response) TimePerOutputTokenAverage() time.Duration {
   681  	if !r.Done() {
   682  		return -1
   683  	}
   684  	intervals := r.tokenIntervals()
   685  	var sum time.Duration
   686  	for _, interval := range intervals {
   687  		sum += interval
   688  	}
   689  	return sum / time.Duration(len(intervals))
   690  }
   691  
   692  // TimePerOutputTokenQuantile computes a quantile of the time it takes to
   693  // generate an output token.
   694  func (r *Response) TimePerOutputTokenQuantile(quantile float64) time.Duration {
   695  	if quantile < 0.0 || quantile > 1.0 {
    696  		panic("quantile must be between 0.0 and 1.0 inclusive")
   697  	}
   698  	if !r.Done() || r.EvalDuration() == 0 {
   699  		return -1
   700  	}
   701  	intervals := r.tokenIntervals()
   702  	sort.Slice(intervals, func(i, j int) bool { return intervals[i] < intervals[j] })
   703  	return intervals[int(quantile*float64(len(intervals)-1))]
   704  }
   705  
   706  // TokenGenerationStdDev returns the standard deviation of the time between
   707  // token generations.
   708  func (r *Response) TokenGenerationStdDev() time.Duration {
   709  	intervals := r.tokenIntervals()
   710  	if len(intervals) == 0 {
   711  		return -1
   712  	}
   713  	if len(intervals) == 1 {
   714  		return 0
   715  	}
   716  
   717  	var sum time.Duration
   718  	for _, interval := range intervals {
   719  		sum += interval
   720  	}
   721  	mean := sum / time.Duration(len(intervals))
   722  	variance := 0.0
   723  	for _, interval := range intervals {
   724  		intervalMinusMean := float64((interval - mean).Nanoseconds())
   725  		variance += intervalMinusMean * intervalMinusMean
   726  	}
   727  	variance = variance / float64(len(intervals)-1)
   728  	return time.Duration(math.Sqrt(variance)) * time.Nanosecond
   729  }
   730  
   731  // TotalDuration returns the total response generation time.
   732  func (r *Response) TotalDuration() time.Duration {
   733  	if !r.Done() {
   734  		return time.Duration(0)
   735  	}
   736  	return time.Duration(r.data[len(r.data)-1].TotalNanos) * time.Nanosecond
   737  }
   738  
    739  // LoadDuration returns the model load duration as reported
    740  // by the ollama server.
   741  func (r *Response) LoadDuration() time.Duration {
   742  	if !r.Done() {
   743  		return time.Duration(0)
   744  	}
   745  	return time.Duration(r.data[len(r.data)-1].LoadNanos) * time.Nanosecond
   746  }
   747  
   748  // EvalDuration returns the response evaluation time.
   749  func (r *Response) EvalDuration() time.Duration {
   750  	if !r.Done() {
   751  		return time.Duration(0)
   752  	}
   753  	return time.Duration(r.data[len(r.data)-1].EvalNanos) * time.Nanosecond
   754  }
   755  
   756  // PromptEvalDuration returns the prompt evaluation time.
   757  func (r *Response) PromptEvalDuration() time.Duration {
   758  	if !r.Done() {
   759  		return time.Duration(0)
   760  	}
   761  	return time.Duration(r.data[len(r.data)-1].PromptEvalNanos) * time.Nanosecond
   762  }
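
         // exampleResponseMetrics is an illustrative sketch (not part of the original
         // file) of how a benchmark might report a few of the metrics above.
         func exampleResponseMetrics(r *Response) string {
         	if !r.Done() {
         		return "response not done"
         	}
         	return fmt.Sprintf("ttft=%v tokens/s=%.1f p50/token=%v",
         		r.TimeToFirstToken(), r.OutputTokensPerSecond(), r.TimePerOutputTokenQuantile(0.5))
         }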
   763  
   764  // ConversationContext represents a conversational context.
   765  // It is returned by a response and may be passed to a follow-up prompt.
   766  type ConversationContext []int
   767  
   768  // withServerLogsErr adds server logs to `err` if possible.
   769  func (llm *Ollama) withServerLogsErr(ctx context.Context, err error) error {
   770  	if err == nil {
   771  		return nil
   772  	}
   773  	if ctx.Err() != nil {
   774  		return fmt.Errorf("%w (+ context err: %v)", err, ctx.Err())
   775  	}
   776  	serverLogs, logsErr := llm.server.Logs(ctx)
   777  	if logsErr != nil {
   778  		return fmt.Errorf("%w (could not get server logs: %v)", err, logsErr)
   779  	}
   780  	if serverLogs != "" {
   781  		return fmt.Errorf("%w; ollama server logs:\n%v\n(end of ollama server logs)", err, serverLogs)
   782  	}
   783  	return fmt.Errorf("%w (server logs are empty)", err)
   784  }
   785  
   786  // getReplacementModel picks an available model other than `model`.
    787  // It tries to find one that is marked as cheap, if possible.
   788  func (llm *Ollama) getReplacementModel(model *Model) (*Model, error) {
   789  	for _, cheapModel := range llm.cheapModels {
   790  		if cheapModel.Name != model.Name {
   791  			return cheapModel, nil
   792  		}
   793  	}
   794  	for _, otherModelName := range llm.ModelNames {
   795  		if otherModelName != model.Name {
   796  			return ZeroTemperatureModel(otherModelName), nil
   797  		}
   798  	}
   799  	return nil, fmt.Errorf("cannot find a replacement model to load instead of %q (available: %v; cheap: %v)", model.Name, llm.ModelNames, llm.cheapModels)
   800  }
   801  
   802  // ModelLoadStats holds metrics about the model loading process.
   803  type ModelLoadStats struct {
   804  	// ClientReportedDuration is the duration to load the model as perceived
   805  	// by the client, measured by HTTP client metrics.
   806  	ClientReportedDuration time.Duration
   807  }
   808  
   809  // WarmModel pre-warms a model in memory and keeps it warm for `keepWarmFor`.
   810  // If `unloadFirst` is true, another model will be loaded before loading the
   811  // requested model. This ensures that the model was loaded from a cold state.
   812  func (llm *Ollama) WarmModel(ctx context.Context, model *Model, keepWarmFor time.Duration, unloadFirst bool) (*ModelLoadStats, error) {
   813  	if keepWarmFor <= 0 {
   814  		return nil, fmt.Errorf("keepWarmFor must be strictly positive, got %v", keepWarmFor)
   815  	}
   816  	if unloadFirst {
   817  		replacementModel, err := llm.getReplacementModel(model)
   818  		if err != nil {
   819  			return nil, fmt.Errorf("cannot find a replacement model to load instead of %q to forcefully unload it: %w", model.Name, err)
   820  		}
   821  		unloadCtx, unloadCancel := context.WithTimeout(ctx, 3*time.Minute)
   822  		_, err = llm.Prompt(unloadCtx, &Prompt{Model: replacementModel, KeepModelAlive: 1 * time.Millisecond})
   823  		unloadCancel()
   824  		if err != nil {
   825  			return nil, llm.withServerLogsErr(ctx, fmt.Errorf("unload prompt for replacement model %s failed: %w", replacementModel.Name, err))
   826  		}
    827  	select { // Wait for the model to get unloaded. Unfortunately there isn't a great way to do this other than to sleep.
   828  		case <-time.After(20 * time.Second):
   829  		case <-ctx.Done():
   830  		}
   831  	}
   832  	resp, err := llm.Prompt(ctx, &Prompt{Model: model, KeepModelAlive: keepWarmFor})
   833  	if err != nil {
   834  		return nil, llm.withServerLogsErr(ctx, fmt.Errorf("warmup prompt for model %s failed: %w", model.Name, err))
   835  	}
   836  	return &ModelLoadStats{
   837  		ClientReportedDuration: resp.metrics.LastByteRead.Sub(resp.metrics.RequestSent),
   838  	}, nil
   839  }
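
         // exampleWarmFromCold is an illustrative sketch (not part of the original
         // file): it forces a cold load of a model and reports the client-side load
         // duration. The model name is an assumption.
         func exampleWarmFromCold(ctx context.Context, llm *Ollama) (time.Duration, error) {
         	stats, err := llm.WarmModel(ctx, ZeroTemperatureModel("llama2:7b"), 5*time.Minute, true /* unloadFirst */)
         	if err != nil {
         		return 0, err
         	}
         	return stats.ClientReportedDuration, nil
         }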
   840  
   841  // Prompt returns the result of prompting the given `model` with `prompt`.
   842  func (llm *Ollama) Prompt(ctx context.Context, prompt *Prompt) (*Response, error) {
   843  	resp, err := jsonPost[PromptJSON, ResponseJSON](ctx, llm, "/api/generate", prompt.json())
   844  	if err != nil {
   845  		return nil, llm.withServerLogsErr(ctx, fmt.Errorf("prompt (%s %q) request failed: %w", prompt.Model.Name, prompt.CleanQuery(), err))
   846  	}
   847  	return &Response{data: resp.Objects, metrics: resp.Metrics}, nil
   848  }
   849  
   850  // PromptUntil repeatedly issues a prompt until `iterate` returns a nil error.
   851  // `iterate` may optionally return an updated `Prompt` which will be used to
   852  // follow up. This is useful to work around the flakiness of LLMs in tests.
   853  func (llm *Ollama) PromptUntil(ctx context.Context, prompt *Prompt, iterate func(*Prompt, *Response) (*Prompt, error)) (*Response, error) {
   854  	var lastResponse *Response
   855  	var lastError error
   856  	attempts := 0
   857  	for ctx.Err() == nil {
   858  		response, err := llm.Prompt(ctx, prompt)
   859  		if err != nil {
   860  			return nil, fmt.Errorf("prompt request failed: %w", err)
   861  		}
   862  		attempts++
   863  		newPrompt, err := iterate(prompt, response)
   864  		if err == nil {
   865  			return response, nil
   866  		}
   867  		if newPrompt != nil {
   868  			prompt = newPrompt
   869  		}
   870  		lastResponse = response
   871  		lastError = err
   872  	}
   873  	return nil, fmt.Errorf("response %q (attempt #%d with prompt %v) did not match predicate: %v", lastResponse, attempts, prompt, lastError)
   874  }
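
         // examplePromptUntil is an illustrative sketch (not part of the original
         // file): it retries the curt query until the model replies with "Hello",
         // raising the temperature on each failed attempt.
         func examplePromptUntil(ctx context.Context, llm *Ollama, model *Model) (*Response, error) {
         	prompt := &Prompt{Model: model, Query: curtQuery}
         	return llm.PromptUntil(ctx, prompt, func(p *Prompt, r *Response) (*Prompt, error) {
         		if strings.Contains(r.Text(), "Hello") {
         			return nil, nil
         		}
         		return p.WithHotterModel(), fmt.Errorf("unexpected response: %q", r.Text())
         	})
         }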