github.com/rsc/tmp@v0.0.0-20240517235954-6deaab19748b/palm/main.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Palm is an interactive client for [Google's PaLM API].
     6  //
     7  // Usage:
     8  //
     9  //	palm [-l] [-k keyfile] [prompt...]
    10  //
    11  // Palm concatenates its arguments, sends the result as a prompt
    12  // to the PaLM model, and prints the response.
    13  //
    14  // With no arguments, palm reads standard input until EOF
    15  // and uses that as the prompt.
    16  //
    17  // The -l flag runs palm in an interactive line-based mode:
    18  // it reads a single line of input and prints the PaLM response,
    19  // and repeats. The -l flag cannot be used with arguments.
    20  //
    21  // The -k flag specifies the name of a file containing the PaLM API key
    22  // (default $HOME/.palmkey).
    23  //
    24  // [Google's PaLM API]: https://developers.generativeai.google/
    25  package main
    26  
    27  import (
    28  	"bufio"
    29  	"bytes"
    30  	"encoding/json"
    31  	"flag"
    32  	"fmt"
    33  	"io"
    34  	"log"
    35  	"net/http"
    36  	"os"
    37  	"path/filepath"
    38  	"strings"
    39  )
    40  
    41  var (
    42  	home, _  = os.UserHomeDir()
    43  	key      string
    44  	lineMode = flag.Bool("l", false, "line at a time mode")
    45  	keyFile  = flag.String("k", filepath.Join(home, ".palmkey"), "read palm API key from `file`")
    46  )
    47  
    48  func usage() {
    49  	fmt.Fprintf(os.Stderr, "usage: palm [-l] [-k keyfile]\n")
    50  	os.Exit(2)
    51  }
    52  
    53  func main() {
    54  	log.SetFlags(0)
    55  	log.SetPrefix("palm: ")
    56  	flag.Usage = usage
    57  	flag.Parse()
    58  
    59  	data, err := os.ReadFile(*keyFile)
    60  	if err != nil {
    61  		log.Fatal(err)
    62  	}
    63  	key = strings.TrimSpace(string(data))
    64  
    65  	if *lineMode {
    66  		if flag.NArg() != 0 {
    67  			log.Fatalf("-l cannot be used with arguments")
    68  		}
    69  		scanner := bufio.NewScanner(os.Stdin)
    70  		for {
    71  			fmt.Fprintf(os.Stderr, "> ")
    72  			if !scanner.Scan() {
    73  				break
    74  			}
    75  			line := scanner.Text()
    76  			fmt.Fprintf(os.Stderr, "\n")
    77  			do(line)
    78  			fmt.Fprintf(os.Stderr, "\n")
    79  		}
    80  		return
    81  	}
    82  
    83  	if flag.NArg() != 0 {
    84  		do(strings.Join(flag.Args(), " "))
    85  	} else {
    86  		data, err := io.ReadAll(os.Stdin)
    87  		if err != nil {
    88  			log.Fatal(err)
    89  		}
    90  		do(string(data))
    91  	}
    92  }
    93  
    94  func do(prompt string) {
    95  	// curl \
    96  	// -H 'Content-Type: application/json' \
    97  	// -d '{ "prompt": { "text": "Write a story about a magic backpack"} }' \
    98  	// "https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText?key=YOUR_API_KEY"
    99  
   100  	js, err := json.Marshal(map[string]map[string]string{"prompt": {"text": prompt}})
   101  	if err != nil {
   102  		log.Fatal(err)
   103  	}
   104  	resp, err := http.Post("https://generativelanguage.googleapis.com/v1beta3/models/text-bison-001:generateText?key="+key, "application/json", bytes.NewReader(js))
   105  	if err != nil {
   106  		log.Fatal(err)
   107  	}
   108  	if err != nil {
   109  		log.Fatal(err)
   110  	}
   111  	data, err := io.ReadAll(resp.Body)
   112  	resp.Body.Close()
   113  	if resp.StatusCode != 200 {
   114  		log.Fatalf("%s:\n%s", resp.Status, data)
   115  	}
   116  	if err != nil {
   117  		log.Fatalf("reading body: %v", err)
   118  	}
   119  
   120  	var r Response
   121  	if err := json.Unmarshal(data, &r); err != nil {
   122  		log.Fatal(err)
   123  	}
   124  	if len(r.Candidates) == 0 {
   125  		fmt.Fprintf(os.Stderr, "no candidate answers")
   126  	}
   127  	for _, c := range r.Candidates {
   128  		fmt.Printf("%s\n", c.Output)
   129  		for _, rate := range c.SafetyRatings {
   130  			if rate.Probability != "NEGLIGIBLE" {
   131  				fmt.Printf("%s=%s\n", rate.Category, rate.Probability)
   132  			}
   133  		}
   134  	}
   135  }
   136  
   137  type Response struct {
   138  	Candidates []Candidate
   139  }
   140  
   141  type Candidate struct {
   142  	Output        string
   143  	SafetyRatings []SafetyRating
   144  }
   145  
   146  type SafetyRating struct {
   147  	Category    string
   148  	Probability string
   149  }