github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/tools/syz-gemini-seed/gemini-seed.go (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // syz-gemini-seed generates program seeds based on existing programs in the corpus using Gemini API.
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"flag"
    11  	"fmt"
    12  	"runtime"
    13  
    14  	"github.com/google/generative-ai-go/genai"
    15  	"github.com/google/syzkaller/pkg/db"
    16  	"github.com/google/syzkaller/pkg/tool"
    17  	"github.com/google/syzkaller/prog"
    18  	_ "github.com/google/syzkaller/sys"
    19  	"google.golang.org/api/option"
    20  )
    21  
    22  func main() {
    23  	var (
    24  		flagOS     = flag.String("os", runtime.GOOS, "target OS")
    25  		flagArch   = flag.String("arch", runtime.GOARCH, "target arch")
    26  		flagCorpus = flag.String("corpus", "", "wxisting corpus.db file to use as examples")
    27  		flagCount  = flag.Int("count", 1, "number of programs to generate")
    28  		flagAPIKey = flag.String("key", "", "gemini API key to use")
    29  	)
    30  	tool.Init()
    31  
    32  	target, err := prog.GetTarget(*flagOS, *flagArch)
    33  	if err != nil {
    34  		tool.Failf("failed to find target: %v", err)
    35  	}
    36  
    37  	db, err := db.Open(*flagCorpus, false)
    38  	if err != nil {
    39  		tool.Failf("failed to open database: %v", err)
    40  	}
    41  
    42  	ctx := context.Background()
    43  	client, err := genai.NewClient(ctx, option.WithAPIKey(*flagAPIKey))
    44  	if err != nil {
    45  		tool.Fail(err)
    46  	}
    47  	defer client.Close()
    48  
    49  	for i := 0; i < *flagCount; i++ {
    50  		model := client.GenerativeModel("gemini-1.5-pro")
    51  		model.SetTemperature(0.9)
    52  		// This does not work (fails with "Only one candidate can be specified").
    53  		// model.SetCandidateCount(3)
    54  		// TODO: tune TopP/TopK.
    55  		// model.SetTopP(0.5)
    56  		// model.SetTopK(20)
    57  		// TODO: do we need any system instructions?
    58  		// model.SystemInstruction = &genai.Content{
    59  		//	Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
    60  		// }
    61  
    62  		// In some cases it thinks it generates unsafe content, so disable safety.
    63  		// TODO: this fails with some cryptic error.
    64  		if false {
    65  			for cat := genai.HarmCategoryDerogatory; cat <= genai.HarmCategoryDangerousContent; cat++ {
    66  				model.SafetySettings = append(model.SafetySettings, &genai.SafetySetting{
    67  					Category:  cat,
    68  					Threshold: genai.HarmBlockNone,
    69  				})
    70  			}
    71  		}
    72  
    73  		prompt := new(bytes.Buffer)
    74  		prompt.WriteString("Below are examples of test programs in a special notation.\n\n")
    75  		// TODO: select a subset of related programs (using the same syscall).
    76  		n := 0
    77  		for _, rec := range db.Records {
    78  			prompt.WriteString("\n\nHere is an example:\n\n")
    79  			prompt.Write(rec.Val)
    80  			n++
    81  			if len(prompt.Bytes()) > 50<<10 || n >= 20 {
    82  				break
    83  			}
    84  		}
    85  		prompt.WriteString("\n\nPlease generate a similar but different test program with 5 lines.\n")
    86  		prompt.WriteString("Output just the program.\n")
    87  		resp, err := model.GenerateContent(ctx, genai.Text(prompt.String()))
    88  		if err != nil {
    89  			tool.Fail(err)
    90  		}
    91  
    92  		for _, cand := range resp.Candidates {
    93  			reply := new(bytes.Buffer)
    94  			if cand.Content != nil {
    95  				for _, part := range cand.Content.Parts {
    96  					if text, ok := part.(genai.Text); ok {
    97  						reply.WriteString(string(text))
    98  					}
    99  				}
   100  			}
   101  			fmt.Printf("REPLY:\n%s\n\n", reply)
   102  			p, err := target.Deserialize(reply.Bytes(), prog.NonStrict)
   103  			if err != nil {
   104  				fmt.Printf("failed to parse: %v\n\n", err)
   105  			} else {
   106  				fmt.Printf("PARSED:\n%s\n\n", p.Serialize())
   107  			}
   108  		}
   109  	}
   110  }