gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/textgen_test.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package textgen_test runs ollama and generates some text with it.
    16  package textgen_test
    17  
    18  import (
    19  	"context"
    20  	_ "embed"
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"gvisor.dev/gvisor/pkg/test/dockerutil"
    28  	"gvisor.dev/gvisor/pkg/test/testutil"
    29  	"gvisor.dev/gvisor/test/gpu/ollama"
    30  )
    31  
// gVisorPNG holds the raw bytes of the gvisor.png file, embedded into the
// test binary at build time. It is attached to the prompt of the "ocr"
// subtest of TestLLM.
//
//go:embed gvisor.png
var gVisorPNG []byte
    34  
    35  // extractCode extracts code between two code block markers.
    36  func extractCode(response, codeBlockDelim string) (string, error) {
    37  	if !strings.Contains(response, codeBlockDelim) {
    38  		return "", fmt.Errorf("no marker string %q", codeBlockDelim)
    39  	}
    40  	var codeLines []string
    41  	isCodeBlock := false
    42  	for _, line := range strings.Split(response, "\n") {
    43  		if strings.HasPrefix(line, codeBlockDelim) {
    44  			isCodeBlock = !isCodeBlock
    45  		} else if isCodeBlock {
    46  			codeLines = append(codeLines, line)
    47  		}
    48  	}
    49  	if isCodeBlock {
    50  		return "", errors.New("non-terminated code block")
    51  	}
    52  	if len(codeLines) == 0 {
    53  		return "", errors.New("no or empty code block")
    54  	}
    55  	return strings.Join(codeLines, "\n") + "\n", nil
    56  }
    57  
    58  // runSandboxedPython runs the given Python code in a sandboxed container.
    59  func runSandboxedPython(ctx context.Context, logger testutil.Logger, code string) (string, error) {
    60  	return dockerutil.MakeContainer(ctx, logger).Run(ctx, dockerutil.RunOpts{
    61  		Image:       "basic/python",
    62  		NetworkMode: "none",
    63  		Entrypoint:  []string{"python3"},
    64  		Env:         []string{"PYTHONUTF8=1"},
    65  	}, "-c", code)
    66  }
    67  
    68  // TestLLM tests an LLM running in a sandboxed container.
    69  // It first asks it to translate "Hello World" to Chinese.
    70  // Then it asks it to write a unit test that verifies that
    71  // this text is a correct translation.
    72  func TestLLM(t *testing.T) {
    73  	ctx := context.Background()
    74  	// Run the LLM.
    75  	llmContainer := dockerutil.MakeContainer(ctx, t)
    76  	defer llmContainer.CleanUp(ctx)
    77  	startCtx, startCancel := context.WithTimeout(ctx, 3*time.Minute)
    78  	llm, err := ollama.NewDocker(startCtx, llmContainer, t)
    79  	startCancel()
    80  	if err != nil {
    81  		t.Fatalf("Failed to start ollama: %v", err)
    82  	}
    83  	if !llm.HasGPU {
    84  		t.Fatal("LLM is not using a GPU")
    85  	}
    86  
    87  	// Query it.
    88  	var translation string
    89  	t.Run("translate text", func(t *testing.T) {
    90  		prompt := ollama.Prompt{
    91  			Model: ollama.ZeroTemperatureModel("llama2-chinese:7b-chat"),
    92  			Query: `
    93  				Translate the following text from English to Chinese:
    94  				    "Hello World".
    95  			`,
    96  		}
    97  		promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute)
    98  		response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) {
    99  			defer prompt.Model.RaiseTemperature()
   100  			text := strings.TrimSpace(response.Text())
   101  			for _, unacceptable := range []rune{'"', '\'', '\\', '\n', '\r', '\t'} {
   102  				if strings.ContainsRune(text, unacceptable) {
   103  					return prompt, fmt.Errorf("response contains unacceptable character %q", unacceptable)
   104  				}
   105  			}
   106  			for _, acceptableWord := range []string{
   107  				"你好",
   108  				"世界",
   109  			} {
   110  				if strings.Contains(text, acceptableWord) {
   111  					return prompt, nil
   112  				}
   113  			}
   114  			return prompt, errors.New("text does not contain any of the expected words")
   115  		})
   116  		promptCancel()
   117  		if err != nil {
   118  			t.Fatalf("translation failed: %v", err)
   119  		}
   120  		translation = strings.TrimSpace(response.Text())
   121  		t.Logf("The Chinese translation of %q is: %q", "Hello World", translation)
   122  	})
   123  	if t.Failed() {
   124  		return
   125  	}
   126  	t.Run("generate test case", func(t *testing.T) {
   127  		const (
   128  			markerString   = "FOOBARBAZQUUX"
   129  			hello          = "你好"
   130  			world          = "世界"
   131  			codeBlockDelim = "```"
   132  		)
   133  		promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute)
   134  		prompt := ollama.Prompt{
   135  			Model: ollama.ZeroTemperatureModel("codellama:7b-instruct"),
   136  			Query: fmt.Sprintf(`
   137  				Generate a Python function that takes a string and verifies that it
   138  				is a valid Chinese translation of the English phrase "Hello World".
   139  				The function should first turn its input into lowercase in order to
   140  				match case-insensitively, and remove all spaces.
   141  				Then, the function should verify that the phrase contains at least
   142  				"你好" ("hello") or "世界" ("world").
   143  				If the verification succeeds, the function should return True.
   144  				After this function is defined, you should call this function with
   145  				the input string %q.
   146  				Then, the code should verify that the function call returned True.
   147  				If it did, the code should print "Verification succeeded";
   148  				otherwise, it should print "Verification failed".
   149  				You may use Python code comments, but do not otherwise explain how
   150  				the code works and do not provide usage examples.
   151  				Output a single block of Python code wrapped between %q marks.
   152  			`, markerString, codeBlockDelim),
   153  		}
   154  		response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) {
   155  			defer prompt.Model.RaiseTemperature()
   156  			pythonCode, err := extractCode(response.Text(), codeBlockDelim)
   157  			if err != nil {
   158  				return prompt, fmt.Errorf("code extraction failed: %w", err)
   159  			}
   160  			if !strings.Contains(pythonCode, markerString) {
   161  				return prompt, fmt.Errorf("marker string %q is not in a code block", markerString)
   162  			}
   163  			out, err := runSandboxedPython(ctx, t, pythonCode)
   164  			if err != nil {
   165  				return prompt, fmt.Errorf("execution with marker string failed: %w", err)
   166  			}
   167  			out = strings.TrimSpace(out)
   168  			if out == "" {
   169  				return prompt, fmt.Errorf("execution with marker string %q had no output", markerString)
   170  			}
   171  			if out == "Verification succeeded" {
   172  				return prompt, fmt.Errorf("verification did not fail for marker string %q (we expected it to fail for this string): got output %q", markerString, out)
   173  			}
   174  			if out != "Verification failed" {
   175  				return prompt, fmt.Errorf("verification program returned unexpected output %q for marker string %q", out, markerString)
   176  			}
   177  			for _, word := range []string{hello, world} {
   178  				codeWithRealText := strings.ReplaceAll(pythonCode, markerString, fmt.Sprintf("asdf %s fdsa", word))
   179  				out, err = runSandboxedPython(ctx, t, codeWithRealText)
   180  				if err != nil {
   181  					return prompt, fmt.Errorf("execution with word %q failed: %w", word, err)
   182  				}
   183  				out = strings.TrimSpace(out)
   184  				if out == "" {
   185  					return prompt, fmt.Errorf("execution with word %q had no output", word)
   186  				}
   187  				if out != "Verification succeeded" {
   188  					return prompt, fmt.Errorf("verification with word %q failed: got output %q", word, out)
   189  				}
   190  			}
   191  			return nil, nil
   192  		})
   193  		promptCancel()
   194  		if err != nil {
   195  			t.Fatalf("Code generation prompt failed: %v", err)
   196  		}
   197  		pythonCode, err := extractCode(response.Text(), codeBlockDelim)
   198  		if err != nil {
   199  			t.Fatalf("Code extraction failed: %v", err)
   200  		}
   201  		testCode := strings.ReplaceAll(pythonCode, markerString, translation)
   202  		out, err := runSandboxedPython(ctx, t, testCode)
   203  		if err != nil {
   204  			t.Fatalf("Translation verification with string %q failed: %v\nCode used:\n\n%s\n\n", translation, err, testCode)
   205  		}
   206  		out = strings.TrimSpace(out)
   207  		if out != "Verification succeeded" {
   208  			t.Fatalf("Translation verification with string %q failed: %q\nCode used:\n\n%s\n\n", translation, out, testCode)
   209  		}
   210  		t.Logf("Translation verification succeeded with code:\n\n%s\n\n", pythonCode)
   211  	})
   212  	t.Run("ocr", func(t *testing.T) {
   213  		const textInImage = "gVisor"
   214  		promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute)
   215  		prompt := ollama.Prompt{
   216  			Model: ollama.ZeroTemperatureModel("llava:7b-v1.6"),
   217  			Query: "What is the text written in this image?",
   218  		}
   219  		prompt.AddImage(gVisorPNG)
   220  		response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) {
   221  			defer prompt.Model.RaiseTemperature()
   222  			text := strings.TrimSpace(response.Text())
   223  			if !strings.Contains(strings.ToLower(text), strings.ToLower(textInImage)) {
   224  				return prompt, fmt.Errorf("text does not contain %q: %q", textInImage, text)
   225  			}
   226  			return prompt, nil
   227  		})
   228  		promptCancel()
   229  		if err != nil {
   230  			t.Fatalf("OCR failed: %v", err)
   231  		}
   232  		t.Logf("OCR response for gVisor logo: %q", response.Text())
   233  	})
   234  }