gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/textgen_test.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package textgen_test runs ollama and generates some text with it. 16 package textgen_test 17 18 import ( 19 "context" 20 _ "embed" 21 "errors" 22 "fmt" 23 "strings" 24 "testing" 25 "time" 26 27 "gvisor.dev/gvisor/pkg/test/dockerutil" 28 "gvisor.dev/gvisor/pkg/test/testutil" 29 "gvisor.dev/gvisor/test/gpu/ollama" 30 ) 31 32 //go:embed gvisor.png 33 var gVisorPNG []byte 34 35 // extractCode extracts code between two code block markers. 36 func extractCode(response, codeBlockDelim string) (string, error) { 37 if !strings.Contains(response, codeBlockDelim) { 38 return "", fmt.Errorf("no marker string %q", codeBlockDelim) 39 } 40 var codeLines []string 41 isCodeBlock := false 42 for _, line := range strings.Split(response, "\n") { 43 if strings.HasPrefix(line, codeBlockDelim) { 44 isCodeBlock = !isCodeBlock 45 } else if isCodeBlock { 46 codeLines = append(codeLines, line) 47 } 48 } 49 if isCodeBlock { 50 return "", errors.New("non-terminated code block") 51 } 52 if len(codeLines) == 0 { 53 return "", errors.New("no or empty code block") 54 } 55 return strings.Join(codeLines, "\n") + "\n", nil 56 } 57 58 // runSandboxedPython runs the given Python code in a sandboxed container. 59 func runSandboxedPython(ctx context.Context, logger testutil.Logger, code string) (string, error) { 60 return dockerutil.MakeContainer(ctx, logger).Run(ctx, dockerutil.RunOpts{ 61 Image: "basic/python", 62 NetworkMode: "none", 63 Entrypoint: []string{"python3"}, 64 Env: []string{"PYTHONUTF8=1"}, 65 }, "-c", code) 66 } 67 68 // TestLLM tests an LLM running in a sandboxed container. 69 // It first asks it to translate "Hello World" to Chinese. 70 // Then it asks it to write a unit test that verifies that 71 // this text is a correct translation. 72 func TestLLM(t *testing.T) { 73 ctx := context.Background() 74 // Run the LLM. 75 llmContainer := dockerutil.MakeContainer(ctx, t) 76 defer llmContainer.CleanUp(ctx) 77 startCtx, startCancel := context.WithTimeout(ctx, 3*time.Minute) 78 llm, err := ollama.NewDocker(startCtx, llmContainer, t) 79 startCancel() 80 if err != nil { 81 t.Fatalf("Failed to start ollama: %v", err) 82 } 83 if !llm.HasGPU { 84 t.Fatal("LLM is not using a GPU") 85 } 86 87 // Query it. 88 var translation string 89 t.Run("translate text", func(t *testing.T) { 90 prompt := ollama.Prompt{ 91 Model: ollama.ZeroTemperatureModel("llama2-chinese:7b-chat"), 92 Query: ` 93 Translate the following text from English to Chinese: 94 "Hello World". 95 `, 96 } 97 promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute) 98 response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) { 99 defer prompt.Model.RaiseTemperature() 100 text := strings.TrimSpace(response.Text()) 101 for _, unacceptable := range []rune{'"', '\'', '\\', '\n', '\r', '\t'} { 102 if strings.ContainsRune(text, unacceptable) { 103 return prompt, fmt.Errorf("response contains unacceptable character %q", unacceptable) 104 } 105 } 106 for _, acceptableWord := range []string{ 107 "你好", 108 "世界", 109 } { 110 if strings.Contains(text, acceptableWord) { 111 return prompt, nil 112 } 113 } 114 return prompt, errors.New("text does not contain any of the expected words") 115 }) 116 promptCancel() 117 if err != nil { 118 t.Fatalf("translation failed: %v", err) 119 } 120 translation = strings.TrimSpace(response.Text()) 121 t.Logf("The Chinese translation of %q is: %q", "Hello World", translation) 122 }) 123 if t.Failed() { 124 return 125 } 126 t.Run("generate test case", func(t *testing.T) { 127 const ( 128 markerString = "FOOBARBAZQUUX" 129 hello = "你好" 130 world = "世界" 131 codeBlockDelim = "```" 132 ) 133 promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute) 134 prompt := ollama.Prompt{ 135 Model: ollama.ZeroTemperatureModel("codellama:7b-instruct"), 136 Query: fmt.Sprintf(` 137 Generate a Python function that takes a string and verifies that it 138 is a valid Chinese translation of the English phrase "Hello World". 139 The function should first turn its input into lowercase in order to 140 match case-insensitively, and remove all spaces. 141 Then, the function should verify that the phrase contains at least 142 "你好" ("hello") or "世界" ("world"). 143 If the verification succeeds, the function should return True. 144 After this function is defined, you should call this function with 145 the input string %q. 146 Then, the code should verify that the function call returned True. 147 If it did, the code should print "Verification succeeded"; 148 otherwise, it should print "Verification failed". 149 You may use Python code comments, but do not otherwise explain how 150 the code works and do not provide usage examples. 151 Output a single block of Python code wrapped between %q marks. 152 `, markerString, codeBlockDelim), 153 } 154 response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) { 155 defer prompt.Model.RaiseTemperature() 156 pythonCode, err := extractCode(response.Text(), codeBlockDelim) 157 if err != nil { 158 return prompt, fmt.Errorf("code extraction failed: %w", err) 159 } 160 if !strings.Contains(pythonCode, markerString) { 161 return prompt, fmt.Errorf("marker string %q is not in a code block", markerString) 162 } 163 out, err := runSandboxedPython(ctx, t, pythonCode) 164 if err != nil { 165 return prompt, fmt.Errorf("execution with marker string failed: %w", err) 166 } 167 out = strings.TrimSpace(out) 168 if out == "" { 169 return prompt, fmt.Errorf("execution with marker string %q had no output", markerString) 170 } 171 if out == "Verification succeeded" { 172 return prompt, fmt.Errorf("verification did not fail for marker string %q (we expected it to fail for this string): got output %q", markerString, out) 173 } 174 if out != "Verification failed" { 175 return prompt, fmt.Errorf("verification program returned unexpected output %q for marker string %q", out, markerString) 176 } 177 for _, word := range []string{hello, world} { 178 codeWithRealText := strings.ReplaceAll(pythonCode, markerString, fmt.Sprintf("asdf %s fdsa", word)) 179 out, err = runSandboxedPython(ctx, t, codeWithRealText) 180 if err != nil { 181 return prompt, fmt.Errorf("execution with word %q failed: %w", word, err) 182 } 183 out = strings.TrimSpace(out) 184 if out == "" { 185 return prompt, fmt.Errorf("execution with word %q had no output", word) 186 } 187 if out != "Verification succeeded" { 188 return prompt, fmt.Errorf("verification with word %q failed: got output %q", word, out) 189 } 190 } 191 return nil, nil 192 }) 193 promptCancel() 194 if err != nil { 195 t.Fatalf("Code generation prompt failed: %v", err) 196 } 197 pythonCode, err := extractCode(response.Text(), codeBlockDelim) 198 if err != nil { 199 t.Fatalf("Code extraction failed: %v", err) 200 } 201 testCode := strings.ReplaceAll(pythonCode, markerString, translation) 202 out, err := runSandboxedPython(ctx, t, testCode) 203 if err != nil { 204 t.Fatalf("Translation verification with string %q failed: %v\nCode used:\n\n%s\n\n", translation, err, testCode) 205 } 206 out = strings.TrimSpace(out) 207 if out != "Verification succeeded" { 208 t.Fatalf("Translation verification with string %q failed: %q\nCode used:\n\n%s\n\n", translation, out, testCode) 209 } 210 t.Logf("Translation verification succeeded with code:\n\n%s\n\n", pythonCode) 211 }) 212 t.Run("ocr", func(t *testing.T) { 213 const textInImage = "gVisor" 214 promptCtx, promptCancel := context.WithTimeout(ctx, 3*time.Minute) 215 prompt := ollama.Prompt{ 216 Model: ollama.ZeroTemperatureModel("llava:7b-v1.6"), 217 Query: "What is the text written in this image?", 218 } 219 prompt.AddImage(gVisorPNG) 220 response, err := llm.PromptUntil(promptCtx, &prompt, func(prompt *ollama.Prompt, response *ollama.Response) (*ollama.Prompt, error) { 221 defer prompt.Model.RaiseTemperature() 222 text := strings.TrimSpace(response.Text()) 223 if !strings.Contains(strings.ToLower(text), strings.ToLower(textInImage)) { 224 return prompt, fmt.Errorf("text does not contain %q: %q", textInImage, text) 225 } 226 return prompt, nil 227 }) 228 promptCancel() 229 if err != nil { 230 t.Fatalf("OCR failed: %v", err) 231 } 232 t.Logf("OCR response for gVisor logo: %q", response.Text()) 233 }) 234 }