// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Palm is an interactive client for [Google's PaLM API].
//
// Usage:
//
//	palm [-l] [-k keyfile] [prompt...]
//
// Palm concatenates its arguments, sends the result as a prompt
// to the PaLM model, and prints the response.
//
// With no arguments, palm reads standard input until EOF
// and uses that as the prompt.
//
// The -l flag runs palm in an interactive line-based mode:
// it reads a single line of input and prints the PaLM response,
// and repeats. The -l flag cannot be used with arguments.
//
// The -k flag specifies the name of a file containing the PaLM API key
// (default $HOME/.palmkey).
//
// [Google's PaLM API]: https://developers.generativeai.google/
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
)

// generateTextURL is the PaLM text generation endpoint, without the API key.
// The key is appended as the "key" query parameter when a request is sent,
// so this constant is safe to include in error messages.
const generateTextURL = "https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText"

var (
	// home is the user's home directory. The error is deliberately ignored:
	// if it cannot be determined, home is empty and the default key path
	// below becomes the relative path ".palmkey".
	home, _ = os.UserHomeDir()

	// key is the PaLM API key, read from *keyFile by main.
	key string

	lineMode = flag.Bool("l", false, "line at a time mode")
	keyFile  = flag.String("k", filepath.Join(home, ".palmkey"), "read palm API key from `file`")
)

// usage prints the command-line synopsis to standard error and exits with
// status 2.
func usage() {
	fmt.Fprintf(os.Stderr, "usage: palm [-l] [-k keyfile] [prompt...]\n")
	os.Exit(2)
}

// main loads the API key, then sends either each line of standard input
// (-l mode), the space-joined arguments, or all of standard input as a
// prompt.
func main() {
	log.SetFlags(0)
	log.SetPrefix("palm: ")
	flag.Usage = usage
	flag.Parse()

	data, err := os.ReadFile(*keyFile)
	if err != nil {
		log.Fatal(err)
	}
	key = strings.TrimSpace(string(data))

	if *lineMode {
		if flag.NArg() != 0 {
			log.Fatalf("-l cannot be used with arguments")
		}
		scanner := bufio.NewScanner(os.Stdin)
		for {
			fmt.Fprintf(os.Stderr, "> ")
			if !scanner.Scan() {
				break
			}
			line := scanner.Text()
			fmt.Fprintf(os.Stderr, "\n")
			do(line)
			fmt.Fprintf(os.Stderr, "\n")
		}
		// Scan returns false both at EOF and on error (including a line
		// longer than the scanner's buffer); only the latter is a failure.
		if err := scanner.Err(); err != nil {
			log.Fatalf("reading standard input: %v", err)
		}
		return
	}

	if flag.NArg() != 0 {
		do(strings.Join(flag.Args(), " "))
		return
	}
	data, err = io.ReadAll(os.Stdin)
	if err != nil {
		log.Fatal(err)
	}
	do(string(data))
}

// do sends prompt to the PaLM generateText endpoint and prints the
// candidate answers to standard output (see printCandidates). Any failure
// (encoding, network, non-200 status, unreadable or undecodable body) is
// fatal.
func do(prompt string) {
	// Equivalent request:
	// curl \
	//   -H 'Content-Type: application/json' \
	//   -d '{ "prompt": { "text": "Write a story about a magic backpack"} }' \
	//   "https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText?key=YOUR_API_KEY"

	js, err := json.Marshal(map[string]map[string]string{"prompt": {"text": prompt}})
	if err != nil {
		log.Fatal(err)
	}
	resp, err := http.Post(generateTextURL+"?key="+url.QueryEscape(key), "application/json", bytes.NewReader(js))
	if err != nil {
		// The *url.Error returned by http.Post embeds the full request URL,
		// which contains the API key. Report it without the key.
		var uerr *url.Error
		if errors.As(err, &uerr) {
			log.Fatalf("%s %q: %v", uerr.Op, generateTextURL, uerr.Err)
		}
		log.Fatal(err)
	}
	data, err := io.ReadAll(resp.Body)
	resp.Body.Close()
	// Report a bad status before a read error: the status (plus whatever
	// body was read) is the more useful diagnostic.
	if resp.StatusCode != http.StatusOK {
		log.Fatalf("%s:\n%s", resp.Status, data)
	}
	if err != nil {
		log.Fatalf("reading body: %v", err)
	}

	var r Response
	if err := json.Unmarshal(data, &r); err != nil {
		log.Fatal(err)
	}
	if !printCandidates(os.Stdout, &r) {
		fmt.Fprintf(os.Stderr, "no candidate answers\n")
	}
}

// printCandidates writes, for each candidate in r, its output followed by a
// newline, then one CATEGORY=PROBABILITY line for each of its safety
// ratings whose probability is not "NEGLIGIBLE". It reports whether r held
// at least one candidate.
func printCandidates(w io.Writer, r *Response) bool {
	for _, c := range r.Candidates {
		fmt.Fprintf(w, "%s\n", c.Output)
		for _, rate := range c.SafetyRatings {
			if rate.Probability != "NEGLIGIBLE" {
				fmt.Fprintf(w, "%s=%s\n", rate.Category, rate.Probability)
			}
		}
	}
	return len(r.Candidates) > 0
}

// Response is the subset of the generateText response body used by palm.
type Response struct {
	Candidates []Candidate `json:"candidates"`
}

// Candidate is one generated answer and its safety ratings.
type Candidate struct {
	Output        string         `json:"output"`
	SafetyRatings []SafetyRating `json:"safetyRatings"`
}

// SafetyRating is the model's assessment of a candidate for one harm
// category.
type SafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}