agones.dev/agones@v1.54.0/examples/simple-genai-server/main.go (about) 1 // Copyright 2024 Google LLC All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "bufio" 19 "bytes" 20 "context" 21 "encoding/json" 22 "flag" 23 "fmt" 24 "io" 25 "log" 26 "math/rand" 27 "net" 28 "net/http" 29 "os" 30 "strconv" 31 "strings" 32 "sync" 33 "time" 34 35 "agones.dev/agones/pkg/util/signals" 36 sdk "agones.dev/agones/sdks/go" 37 ) 38 39 // Main starts a server that serves as an example of how to integrate GenAI endpoints into your dedicated game server. 40 func main() { 41 sigCtx, _ := signals.NewSigKillContext() 42 43 port := flag.String("port", "7654", "The port to listen to traffic on") 44 genAiEndpoint := flag.String("GenAiEndpoint", "", "The full base URL to send API requests to simulate computer (NPC) responses to user input") 45 genAiContext := flag.String("GenAiContext", "", "Context for the GenAI endpoint") 46 prompt := flag.String("Prompt", "", "The first prompt for the GenAI endpoint") 47 simEndpoint := flag.String("SimEndpoint", "", "The full base URL to send API requests to simulate user input") 48 simContext := flag.String("SimContext", "", "Context for the Sim endpoint") 49 stopPhrase := flag.String("StopPhrase", "Bye!", "In autonomous chat, if either side sends this, stop after the next turn.") 50 numChats := flag.Int("NumChats", 1, "Number of back and forth chats between the sim and genAI") 51 genAiNpc := flag.Bool("GenAiNpc", false, "Set to true if the GenAIEndpoint is the npc-chat-api endpoint") 52 simNpc := flag.Bool("SimNpc", false, "Set to true if the SimEndpoint is the npc-chat-api endpoint") 53 fromId := flag.Int("FromID", 2, "Entity sending messages to the npc-chat-api. Ignored when autonomous, which uses random FromID") 54 toId := flag.Int("ToID", 1, "Entity receiving messages on the npc-chat-api (the NPC's ID)") 55 concurrentPlayers := flag.Int("ConcurrentPlayers", 1, "Number of concurrent players.") 56 57 flag.Parse() 58 if ep := os.Getenv("PORT"); ep != "" { 59 port = &ep 60 } 61 if sc := os.Getenv("SIM_CONTEXT"); sc != "" { 62 simContext = &sc 63 } 64 if ss := os.Getenv("STOP_PHRASE"); ss != "" { 65 stopPhrase = &ss 66 } 67 if gac := os.Getenv("GEN_AI_CONTEXT"); gac != "" { 68 genAiContext = &gac 69 } 70 if p := os.Getenv("PROMPT"); p != "" { 71 prompt = &p 72 } 73 if se := os.Getenv("SIM_ENDPOINT"); se != "" { 74 simEndpoint = &se 75 } 76 if gae := os.Getenv("GEN_AI_ENDPOINT"); gae != "" { 77 genAiEndpoint = &gae 78 } 79 if nc := os.Getenv("NUM_CHATS"); nc != "" { 80 num, err := strconv.Atoi(nc) 81 if err != nil { 82 log.Fatalf("Could not parse NumChats: %v", err) 83 } 84 numChats = &num 85 } 86 if gan := os.Getenv("GEN_AI_NPC"); gan != "" { 87 gnpc, err := strconv.ParseBool(gan) 88 if err != nil { 89 log.Fatalf("Could parse GenAiNpc: %v", err) 90 } 91 genAiNpc = &gnpc 92 } 93 if sn := os.Getenv("SIM_NPC"); sn != "" { 94 snpc, err := strconv.ParseBool(sn) 95 if err != nil { 96 log.Fatalf("Could parse GenAiNpc: %v", err) 97 } 98 simNpc = &snpc 99 } 100 if fid := os.Getenv("FROM_ID"); fid != "" { 101 num, err := strconv.Atoi(fid) 102 if err != nil { 103 log.Fatalf("Could not parse FromId: %v", err) 104 } 105 fromId = &num 106 } 107 if tid := os.Getenv("TO_ID"); tid != "" { 108 num, err := strconv.Atoi(tid) 109 if err != nil { 110 log.Fatalf("Could not parse ToId: %v", err) 111 } 112 toId = &num 113 } 114 if cp := os.Getenv("CONCURRENT_PLAYERS"); cp != "" { 115 num, err := strconv.Atoi(cp) 116 if err != nil { 117 log.Fatalf("Could not parse ToID: %v", err) 118 } 119 concurrentPlayers = &num 120 } 121 122 log.Print("Creating SDK instance") 123 s, err := sdk.NewSDK() 124 if err != nil { 125 log.Fatalf("Could not connect to sdk: %v", err) 126 } 127 128 log.Print("Starting Health Ping") 129 go doHealth(s, sigCtx) 130 131 log.Print("Marking this server as ready") 132 if err := s.Ready(); err != nil { 133 log.Fatalf("Could not send ready message") 134 } 135 136 if *genAiEndpoint == "" { 137 log.Fatalf("GenAiEndpoint must be specified") 138 } 139 140 // Start up TCP listener so the user can interact with the GenAI endpoint manually 141 if *simEndpoint == "" { 142 log.Printf("Creating GenAI Client at endpoint %s (from_id=%d, to_id=%d)", *genAiEndpoint, *fromId, *toId) 143 genAiConn := initClient(*genAiEndpoint, *genAiContext, "GenAI", *genAiNpc, *fromId, *toId) 144 go tcpListener(*port, genAiConn) 145 <-sigCtx.Done() 146 } else { 147 var wg sync.WaitGroup 148 149 for slot := 0; slot < *concurrentPlayers; slot++ { 150 wg.Add(1) 151 go func() { 152 defer wg.Done() 153 for { 154 // Create a random from_id and name 155 fid := int(rand.Int31()) 156 name := fmt.Sprintf("Sim%08x", fid) 157 log.Printf("=== New player %s (id %d) ===", name, fid) 158 159 log.Printf("Creating GenAI Client at endpoint %s (from_id=%d, to_id=%d)", *genAiEndpoint, fid, *toId) 160 genAiConn := initClient(*genAiEndpoint, *genAiContext, "GenAI", *genAiNpc, fid, *toId) 161 162 log.Printf("%s: Creating client at endpoint %s, sending prompt: %s", name, *simEndpoint, *prompt) 163 simConn := initClient(*simEndpoint, *simContext, name, *simNpc, *toId, *toId) 164 165 chatHistory := []Message{{Author: simConn.name, Content: *prompt}} 166 autonomousChat(*prompt, genAiConn, simConn, *numChats, *stopPhrase, chatHistory) 167 } 168 }() 169 } 170 wg.Wait() 171 } 172 173 log.Printf("Shutting down the Game Server.") 174 shutdownErr := s.Shutdown() 175 if shutdownErr != nil { 176 log.Printf("Could not shutdown") 177 } 178 os.Exit(0) 179 } 180 181 func initClient(endpoint string, context string, name string, npc bool, fromID int, toID int) *connection { 182 // TODO: create option for a client certificate 183 client := &http.Client{} 184 return &connection{client: client, endpoint: endpoint, context: context, name: name, npc: npc, fromId: fromID, toId: toID} 185 } 186 187 type connection struct { 188 client *http.Client 189 endpoint string // Full base URL for API requests 190 context string 191 name string // Human readable name for the connection 192 npc bool // True if the endpoint is the NPC API 193 fromId int // For use with NPC API, sender ID 194 toId int // For use with NPC API, receiver ID 195 // TODO: create options for routes off the base URL 196 } 197 198 // For use with Vertex APIs 199 type GenAIRequest struct { 200 Context string `json:"context,omitempty"` // Optional 201 Prompt string `json:"prompt,omitempty"` 202 ChatHistory []Message `json:"messages,omitempty"` // Optional, stores chat history for use with Vertex Chat API 203 } 204 205 // For use with NPC API 206 type NPCRequest struct { 207 Msg string `json:"message,omitempty"` 208 FromId int `json:"from_id,omitempty"` 209 ToId int `json:"to_id,omitempty"` 210 } 211 212 // Expected format for the NPC endpoint response 213 type NPCResponse struct { 214 Response string `json:"response"` 215 } 216 217 // Conversation history provided to the model in a structured alternate-author form. 218 // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-chat 219 type Message struct { 220 Author string `json:"author"` 221 Content string `json:"content"` 222 } 223 224 func handleGenAIRequest(prompt string, clientConn *connection, chatHistory []Message) (string, error) { 225 var jsonStr []byte 226 var err error 227 // If the endpoint is the NPC API, use the json request format specific to that API 228 if clientConn.npc { 229 npcRequest := NPCRequest{ 230 Msg: prompt, 231 FromId: clientConn.fromId, 232 ToId: clientConn.toId, 233 } 234 jsonStr, err = json.Marshal(npcRequest) 235 } else { 236 // Vertex expects the author to be "user" for user generated messages and "bot" for messages it previously sent. 237 // Translate the chat history we have using the connection names. 238 // 239 // You can think of `prompt` as the message that "user" is sending to "bot", meaning chatHistory should always 240 // end with "bot". 241 var ch []Message 242 for _, chat := range chatHistory { 243 newChat := Message{Content: chat.Content} 244 if chat.Author == clientConn.name { 245 newChat.Author = "user" 246 } else { 247 newChat.Author = "bot" 248 } 249 ch = append(ch, newChat) 250 } 251 if len(ch) > 0 && ch[len(ch)-1].Author != "bot" { 252 log.Fatalf("Chat history does not end in 'bot': %#v", ch) 253 } 254 255 genAIRequest := GenAIRequest{ 256 Context: clientConn.context, 257 Prompt: prompt, 258 ChatHistory: ch, 259 } 260 jsonStr, err = json.Marshal(genAIRequest) 261 } 262 if err != nil { 263 return "", fmt.Errorf("unable to marshal json request: %v", err) 264 } 265 266 req, err := http.NewRequest("POST", clientConn.endpoint, bytes.NewBuffer(jsonStr)) 267 if err != nil { 268 return "", fmt.Errorf("unable create http POST request: %v", err) 269 } 270 req.Header.Set("accept", "application/json") 271 req.Header.Set("Content-Type", "application/json") 272 273 resp, err := clientConn.client.Do(req) 274 if err != nil { 275 return "", fmt.Errorf("unable to post request: %v", err) 276 } 277 278 responseBody, err := io.ReadAll(resp.Body) 279 if err != nil { 280 return "", fmt.Errorf("unable to read response body: %v", err) 281 } 282 defer resp.Body.Close() 283 body := string(responseBody) 284 285 if resp.StatusCode != 200 { 286 err = fmt.Errorf("Status: %s, Body: %s", resp.Status, body) 287 } 288 return string(responseBody) + "\n", err 289 } 290 291 // Two AIs (connection endpoints) talking to each other 292 func autonomousChat(prompt string, conn1 *connection, conn2 *connection, numChats int, stopPhase string, chatHistory []Message) { 293 if numChats <= 0 { 294 return 295 } 296 297 startTime := time.Now() 298 response, err := handleGenAIRequest(prompt, conn1, chatHistory) 299 latency := time.Now().Sub(startTime) 300 if err != nil { 301 log.Printf("ERROR: Could not send request (stopping this chat): %v", err) 302 return 303 } 304 // If we sent the request to the NPC endpoint we need to parse the json response {response: "response"} 305 if conn1.npc { 306 npcResponse := NPCResponse{} 307 err = json.Unmarshal([]byte(response), &npcResponse) 308 if err != nil { 309 log.Fatalf("FATAL ERROR: Unable to unmarshal NPC endpoint response: %v", err) 310 } 311 response = npcResponse.Response 312 } 313 log.Printf("%s->%s [%d turns left]: %s\n", conn1.name, conn2.name, numChats, response) 314 log.Printf("%s PREDICTION RATE: %0.2f b/s", conn1.name, float64(len(response))/latency.Seconds()) 315 316 chat := Message{Author: conn1.name, Content: response} 317 chatHistory = append(chatHistory, chat) 318 319 numChats -= 1 320 321 if strings.Contains(response, stopPhase) { 322 if numChats > 1 { 323 numChats = 1 324 } 325 log.Printf("%s stop received, final turn\n", conn1.name) 326 } 327 328 // Flip between the connection that the response is sent to. 329 autonomousChat(response, conn2, conn1, numChats, stopPhase, chatHistory) 330 } 331 332 // Manually interact via TCP with the GenAI endpoint 333 func tcpListener(port string, genAiConn *connection) { 334 log.Printf("Starting TCP server, listening on port %s", port) 335 ln, err := net.Listen("tcp", ":"+port) 336 if err != nil { 337 log.Fatalf("Could not start TCP server: %v", err) 338 } 339 defer ln.Close() // nolint: errcheck 340 341 for { 342 conn, err := ln.Accept() 343 if err != nil { 344 log.Fatalf("Unable to accept incoming TCP connection: %v", err) 345 } 346 go tcpHandleConnection(conn, genAiConn) 347 } 348 } 349 350 // handleConnection services a single tcp connection to the GenAI endpoint 351 func tcpHandleConnection(conn net.Conn, genAiConn *connection) { 352 log.Printf("TCP Client %s connected", conn.RemoteAddr().String()) 353 354 scanner := bufio.NewScanner(conn) 355 for scanner.Scan() { 356 txt := scanner.Text() 357 log.Printf("TCP txt: %v", txt) 358 359 // TODO: update with chathistroy 360 response, err := handleGenAIRequest(txt, genAiConn, nil) 361 if err != nil { 362 response = "ERROR: " + err.Error() + "\n" 363 } 364 365 if _, err := conn.Write([]byte(response)); err != nil { 366 log.Fatalf("Could not write to TCP stream: %v", err) 367 } 368 } 369 370 log.Printf("TCP Client %s disconnected", conn.RemoteAddr().String()) 371 } 372 373 // doHealth sends the regular Health Pings 374 func doHealth(sdk *sdk.SDK, ctx context.Context) { 375 tick := time.Tick(2 * time.Second) 376 for { 377 err := sdk.Health() 378 if err != nil { 379 log.Fatalf("Could not send health ping, %v", err) 380 } 381 select { 382 case <-ctx.Done(): 383 log.Print("Stopped health pings") 384 return 385 case <-tick: 386 } 387 } 388 }