agones.dev/agones@v1.54.0/examples/simple-genai-server/main.go (about)

     1  // Copyright 2024 Google LLC All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"context"
    21  	"encoding/json"
    22  	"flag"
    23  	"fmt"
    24  	"io"
    25  	"log"
    26  	"math/rand"
    27  	"net"
    28  	"net/http"
    29  	"os"
    30  	"strconv"
    31  	"strings"
    32  	"sync"
    33  	"time"
    34  
    35  	"agones.dev/agones/pkg/util/signals"
    36  	sdk "agones.dev/agones/sdks/go"
    37  )
    38  
    39  // Main starts a server that serves as an example of how to integrate GenAI endpoints into your dedicated game server.
    40  func main() {
    41  	sigCtx, _ := signals.NewSigKillContext()
    42  
    43  	port := flag.String("port", "7654", "The port to listen to traffic on")
    44  	genAiEndpoint := flag.String("GenAiEndpoint", "", "The full base URL to send API requests to simulate computer (NPC) responses to user input")
    45  	genAiContext := flag.String("GenAiContext", "", "Context for the GenAI endpoint")
    46  	prompt := flag.String("Prompt", "", "The first prompt for the GenAI endpoint")
    47  	simEndpoint := flag.String("SimEndpoint", "", "The full base URL to send API requests to simulate user input")
    48  	simContext := flag.String("SimContext", "", "Context for the Sim endpoint")
    49  	stopPhrase := flag.String("StopPhrase", "Bye!", "In autonomous chat, if either side sends this, stop after the next turn.")
    50  	numChats := flag.Int("NumChats", 1, "Number of back and forth chats between the sim and genAI")
    51  	genAiNpc := flag.Bool("GenAiNpc", false, "Set to true if the GenAIEndpoint is the npc-chat-api endpoint")
    52  	simNpc := flag.Bool("SimNpc", false, "Set to true if the SimEndpoint is the npc-chat-api endpoint")
    53  	fromId := flag.Int("FromID", 2, "Entity sending messages to the npc-chat-api. Ignored when autonomous, which uses random FromID")
    54  	toId := flag.Int("ToID", 1, "Entity receiving messages on the npc-chat-api (the NPC's ID)")
    55  	concurrentPlayers := flag.Int("ConcurrentPlayers", 1, "Number of concurrent players.")
    56  
    57  	flag.Parse()
    58  	if ep := os.Getenv("PORT"); ep != "" {
    59  		port = &ep
    60  	}
    61  	if sc := os.Getenv("SIM_CONTEXT"); sc != "" {
    62  		simContext = &sc
    63  	}
    64  	if ss := os.Getenv("STOP_PHRASE"); ss != "" {
    65  		stopPhrase = &ss
    66  	}
    67  	if gac := os.Getenv("GEN_AI_CONTEXT"); gac != "" {
    68  		genAiContext = &gac
    69  	}
    70  	if p := os.Getenv("PROMPT"); p != "" {
    71  		prompt = &p
    72  	}
    73  	if se := os.Getenv("SIM_ENDPOINT"); se != "" {
    74  		simEndpoint = &se
    75  	}
    76  	if gae := os.Getenv("GEN_AI_ENDPOINT"); gae != "" {
    77  		genAiEndpoint = &gae
    78  	}
    79  	if nc := os.Getenv("NUM_CHATS"); nc != "" {
    80  		num, err := strconv.Atoi(nc)
    81  		if err != nil {
    82  			log.Fatalf("Could not parse NumChats: %v", err)
    83  		}
    84  		numChats = &num
    85  	}
    86  	if gan := os.Getenv("GEN_AI_NPC"); gan != "" {
    87  		gnpc, err := strconv.ParseBool(gan)
    88  		if err != nil {
    89  			log.Fatalf("Could parse GenAiNpc: %v", err)
    90  		}
    91  		genAiNpc = &gnpc
    92  	}
    93  	if sn := os.Getenv("SIM_NPC"); sn != "" {
    94  		snpc, err := strconv.ParseBool(sn)
    95  		if err != nil {
    96  			log.Fatalf("Could parse GenAiNpc: %v", err)
    97  		}
    98  		simNpc = &snpc
    99  	}
   100  	if fid := os.Getenv("FROM_ID"); fid != "" {
   101  		num, err := strconv.Atoi(fid)
   102  		if err != nil {
   103  			log.Fatalf("Could not parse FromId: %v", err)
   104  		}
   105  		fromId = &num
   106  	}
   107  	if tid := os.Getenv("TO_ID"); tid != "" {
   108  		num, err := strconv.Atoi(tid)
   109  		if err != nil {
   110  			log.Fatalf("Could not parse ToId: %v", err)
   111  		}
   112  		toId = &num
   113  	}
   114  	if cp := os.Getenv("CONCURRENT_PLAYERS"); cp != "" {
   115  		num, err := strconv.Atoi(cp)
   116  		if err != nil {
   117  			log.Fatalf("Could not parse ToID: %v", err)
   118  		}
   119  		concurrentPlayers = &num
   120  	}
   121  
   122  	log.Print("Creating SDK instance")
   123  	s, err := sdk.NewSDK()
   124  	if err != nil {
   125  		log.Fatalf("Could not connect to sdk: %v", err)
   126  	}
   127  
   128  	log.Print("Starting Health Ping")
   129  	go doHealth(s, sigCtx)
   130  
   131  	log.Print("Marking this server as ready")
   132  	if err := s.Ready(); err != nil {
   133  		log.Fatalf("Could not send ready message")
   134  	}
   135  
   136  	if *genAiEndpoint == "" {
   137  		log.Fatalf("GenAiEndpoint must be specified")
   138  	}
   139  
   140  	// Start up TCP listener so the user can interact with the GenAI endpoint manually
   141  	if *simEndpoint == "" {
   142  		log.Printf("Creating GenAI Client at endpoint %s (from_id=%d, to_id=%d)", *genAiEndpoint, *fromId, *toId)
   143  		genAiConn := initClient(*genAiEndpoint, *genAiContext, "GenAI", *genAiNpc, *fromId, *toId)
   144  		go tcpListener(*port, genAiConn)
   145  		<-sigCtx.Done()
   146  	} else {
   147  		var wg sync.WaitGroup
   148  
   149  		for slot := 0; slot < *concurrentPlayers; slot++ {
   150  			wg.Add(1)
   151  			go func() {
   152  				defer wg.Done()
   153  				for {
   154  					// Create a random from_id and name
   155  					fid := int(rand.Int31())
   156  					name := fmt.Sprintf("Sim%08x", fid)
   157  					log.Printf("=== New player %s (id %d) ===", name, fid)
   158  
   159  					log.Printf("Creating GenAI Client at endpoint %s (from_id=%d, to_id=%d)", *genAiEndpoint, fid, *toId)
   160  					genAiConn := initClient(*genAiEndpoint, *genAiContext, "GenAI", *genAiNpc, fid, *toId)
   161  
   162  					log.Printf("%s: Creating client at endpoint %s, sending prompt: %s", name, *simEndpoint, *prompt)
   163  					simConn := initClient(*simEndpoint, *simContext, name, *simNpc, *toId, *toId)
   164  
   165  					chatHistory := []Message{{Author: simConn.name, Content: *prompt}}
   166  					autonomousChat(*prompt, genAiConn, simConn, *numChats, *stopPhrase, chatHistory)
   167  				}
   168  			}()
   169  		}
   170  		wg.Wait()
   171  	}
   172  
   173  	log.Printf("Shutting down the Game Server.")
   174  	shutdownErr := s.Shutdown()
   175  	if shutdownErr != nil {
   176  		log.Printf("Could not shutdown")
   177  	}
   178  	os.Exit(0)
   179  }
   180  
   181  func initClient(endpoint string, context string, name string, npc bool, fromID int, toID int) *connection {
   182  	// TODO: create option for a client certificate
   183  	client := &http.Client{}
   184  	return &connection{client: client, endpoint: endpoint, context: context, name: name, npc: npc, fromId: fromID, toId: toID}
   185  }
   186  
   187  type connection struct {
   188  	client   *http.Client
   189  	endpoint string // Full base URL for API requests
   190  	context  string
   191  	name     string // Human readable name for the connection
   192  	npc      bool   // True if the endpoint is the NPC API
   193  	fromId   int    // For use with NPC API, sender ID
   194  	toId     int    // For use with NPC API, receiver ID
   195  	// TODO: create options for routes off the base URL
   196  }
   197  
// GenAIRequest is the JSON request body for use with Vertex APIs. It is what
// handleGenAIRequest sends when the connection is not an NPC API endpoint.
// Every field is omitted from the JSON when empty.
type GenAIRequest struct {
	Context     string    `json:"context,omitempty"` // Optional
	Prompt      string    `json:"prompt,omitempty"`
	ChatHistory []Message `json:"messages,omitempty"` // Optional, stores chat history for use with Vertex Chat API
}
   204  
// NPCRequest is the JSON request body for use with the NPC API. It is what
// handleGenAIRequest sends when the connection is an NPC API endpoint.
//
// NOTE(review): because of omitempty, an ID of 0 is left out of the JSON
// entirely — confirm the NPC API never uses 0 as a valid entity ID.
type NPCRequest struct {
	Msg    string `json:"message,omitempty"` // Message text sent to the NPC
	FromId int    `json:"from_id,omitempty"` // Sender ID
	ToId   int    `json:"to_id,omitempty"`   // Receiver ID
}
   211  
// NPCResponse is the expected format for the NPC endpoint response:
// a JSON object whose "response" key holds the reply text.
type NPCResponse struct {
	Response string `json:"response"`
}
   216  
// Message is one entry of the conversation history provided to the model in a
// structured alternate-author form.
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat
type Message struct {
	Author  string `json:"author"`  // Who sent the message
	Content string `json:"content"` // Text of the message
}
   223  
   224  func handleGenAIRequest(prompt string, clientConn *connection, chatHistory []Message) (string, error) {
   225  	var jsonStr []byte
   226  	var err error
   227  	// If the endpoint is the NPC API, use the json request format specific to that API
   228  	if clientConn.npc {
   229  		npcRequest := NPCRequest{
   230  			Msg:    prompt,
   231  			FromId: clientConn.fromId,
   232  			ToId:   clientConn.toId,
   233  		}
   234  		jsonStr, err = json.Marshal(npcRequest)
   235  	} else {
   236  		// Vertex expects the author to be "user" for user generated messages and "bot" for messages it previously sent.
   237  		// Translate the chat history we have using the connection names.
   238  		//
   239  		// You can think of `prompt` as the message that "user" is sending to "bot", meaning chatHistory should always
   240  		// end with "bot".
   241  		var ch []Message
   242  		for _, chat := range chatHistory {
   243  			newChat := Message{Content: chat.Content}
   244  			if chat.Author == clientConn.name {
   245  				newChat.Author = "user"
   246  			} else {
   247  				newChat.Author = "bot"
   248  			}
   249  			ch = append(ch, newChat)
   250  		}
   251  		if len(ch) > 0 && ch[len(ch)-1].Author != "bot" {
   252  			log.Fatalf("Chat history does not end in 'bot': %#v", ch)
   253  		}
   254  
   255  		genAIRequest := GenAIRequest{
   256  			Context:     clientConn.context,
   257  			Prompt:      prompt,
   258  			ChatHistory: ch,
   259  		}
   260  		jsonStr, err = json.Marshal(genAIRequest)
   261  	}
   262  	if err != nil {
   263  		return "", fmt.Errorf("unable to marshal json request: %v", err)
   264  	}
   265  
   266  	req, err := http.NewRequest("POST", clientConn.endpoint, bytes.NewBuffer(jsonStr))
   267  	if err != nil {
   268  		return "", fmt.Errorf("unable create http POST request: %v", err)
   269  	}
   270  	req.Header.Set("accept", "application/json")
   271  	req.Header.Set("Content-Type", "application/json")
   272  
   273  	resp, err := clientConn.client.Do(req)
   274  	if err != nil {
   275  		return "", fmt.Errorf("unable to post request: %v", err)
   276  	}
   277  
   278  	responseBody, err := io.ReadAll(resp.Body)
   279  	if err != nil {
   280  		return "", fmt.Errorf("unable to read response body: %v", err)
   281  	}
   282  	defer resp.Body.Close()
   283  	body := string(responseBody)
   284  
   285  	if resp.StatusCode != 200 {
   286  		err = fmt.Errorf("Status: %s, Body: %s", resp.Status, body)
   287  	}
   288  	return string(responseBody) + "\n", err
   289  }
   290  
   291  // Two AIs (connection endpoints) talking to each other
   292  func autonomousChat(prompt string, conn1 *connection, conn2 *connection, numChats int, stopPhase string, chatHistory []Message) {
   293  	if numChats <= 0 {
   294  		return
   295  	}
   296  
   297  	startTime := time.Now()
   298  	response, err := handleGenAIRequest(prompt, conn1, chatHistory)
   299  	latency := time.Now().Sub(startTime)
   300  	if err != nil {
   301  		log.Printf("ERROR: Could not send request (stopping this chat): %v", err)
   302  		return
   303  	}
   304  	// If we sent the request to the NPC endpoint we need to parse the json response {response: "response"}
   305  	if conn1.npc {
   306  		npcResponse := NPCResponse{}
   307  		err = json.Unmarshal([]byte(response), &npcResponse)
   308  		if err != nil {
   309  			log.Fatalf("FATAL ERROR: Unable to unmarshal NPC endpoint response: %v", err)
   310  		}
   311  		response = npcResponse.Response
   312  	}
   313  	log.Printf("%s->%s [%d turns left]: %s\n", conn1.name, conn2.name, numChats, response)
   314  	log.Printf("%s PREDICTION RATE: %0.2f b/s", conn1.name, float64(len(response))/latency.Seconds())
   315  
   316  	chat := Message{Author: conn1.name, Content: response}
   317  	chatHistory = append(chatHistory, chat)
   318  
   319  	numChats -= 1
   320  
   321  	if strings.Contains(response, stopPhase) {
   322  		if numChats > 1 {
   323  			numChats = 1
   324  		}
   325  		log.Printf("%s stop received, final turn\n", conn1.name)
   326  	}
   327  
   328  	// Flip between the connection that the response is sent to.
   329  	autonomousChat(response, conn2, conn1, numChats, stopPhase, chatHistory)
   330  }
   331  
   332  // Manually interact via TCP with the GenAI endpoint
   333  func tcpListener(port string, genAiConn *connection) {
   334  	log.Printf("Starting TCP server, listening on port %s", port)
   335  	ln, err := net.Listen("tcp", ":"+port)
   336  	if err != nil {
   337  		log.Fatalf("Could not start TCP server: %v", err)
   338  	}
   339  	defer ln.Close() // nolint: errcheck
   340  
   341  	for {
   342  		conn, err := ln.Accept()
   343  		if err != nil {
   344  			log.Fatalf("Unable to accept incoming TCP connection: %v", err)
   345  		}
   346  		go tcpHandleConnection(conn, genAiConn)
   347  	}
   348  }
   349  
   350  // handleConnection services a single tcp connection to the GenAI endpoint
   351  func tcpHandleConnection(conn net.Conn, genAiConn *connection) {
   352  	log.Printf("TCP Client %s connected", conn.RemoteAddr().String())
   353  
   354  	scanner := bufio.NewScanner(conn)
   355  	for scanner.Scan() {
   356  		txt := scanner.Text()
   357  		log.Printf("TCP txt: %v", txt)
   358  
   359  		// TODO: update with chathistroy
   360  		response, err := handleGenAIRequest(txt, genAiConn, nil)
   361  		if err != nil {
   362  			response = "ERROR: " + err.Error() + "\n"
   363  		}
   364  
   365  		if _, err := conn.Write([]byte(response)); err != nil {
   366  			log.Fatalf("Could not write to TCP stream: %v", err)
   367  		}
   368  	}
   369  
   370  	log.Printf("TCP Client %s disconnected", conn.RemoteAddr().String())
   371  }
   372  
   373  // doHealth sends the regular Health Pings
   374  func doHealth(sdk *sdk.SDK, ctx context.Context) {
   375  	tick := time.Tick(2 * time.Second)
   376  	for {
   377  		err := sdk.Health()
   378  		if err != nil {
   379  			log.Fatalf("Could not send health ping, %v", err)
   380  		}
   381  		select {
   382  		case <-ctx.Done():
   383  			log.Print("Stopped health pings")
   384  			return
   385  		case <-tick:
   386  		}
   387  	}
   388  }