github.com/instill-ai/component@v0.16.0-beta/pkg/connector/redis/v0/chat_history.go (about) 1 package redis 2 3 import ( 4 "context" 5 "encoding/json" 6 "sort" 7 "time" 8 9 goredis "github.com/redis/go-redis/v9" 10 ) 11 12 var ( 13 // DefaultLatestK is the default number of latest conversation turns to retrieve 14 DefaultLatestK = 5 15 ) 16 17 type Message struct { 18 Role string `json:"role"` 19 Content string `json:"content"` 20 Metadata *map[string]interface{} `json:"metadata,omitempty"` 21 } 22 23 type MultiModalMessage struct { 24 Role string `json:"role"` 25 Content []MultiModalContent `json:"content"` 26 Metadata *map[string]interface{} `json:"metadata,omitempty"` 27 } 28 29 type MultiModalContent struct { 30 Type string `json:"type"` 31 Text *string `json:"text,omitempty"` 32 ImageURL *struct { 33 URL string `json:"url"` 34 } `json:"image_url,omitempty"` 35 } 36 37 type MessageWithTime struct { 38 Message 39 Timestamp int64 `json:"timestamp"` 40 } 41 42 type MultiModalMessageWithTime struct { 43 MultiModalMessage 44 Timestamp int64 `json:"timestamp"` 45 } 46 47 type ChatMessageWriteInput struct { 48 SessionID string `json:"session_id"` 49 Message 50 } 51 52 type ChatMultiModalMessageWriteInput struct { 53 SessionID string `json:"session_id"` 54 MultiModalMessage 55 } 56 57 type ChatMessageWriteOutput struct { 58 Status bool `json:"status"` 59 } 60 61 type ChatHistoryRetrieveInput struct { 62 SessionID string `json:"session_id"` 63 LatestK *int `json:"latest_k,omitempty"` 64 IncludeSystemMessage bool `json:"include_system_message"` 65 } 66 67 // ChatHistoryReadOutput is a wrapper struct for the messages associated with a session ID 68 type ChatHistoryRetrieveOutput struct { 69 Messages []*MultiModalMessage `json:"messages"` 70 Status bool `json:"status"` 71 } 72 73 // WriteSystemMessage writes system message for a given session ID 74 func WriteSystemMessage(client *goredis.Client, sessionID string, message MultiModalMessageWithTime) error { 75 messageJSON, err := json.Marshal(message) 76 if err != nil { 77 return err 78 } 79 80 // Store in a hash with a unique SessionID 81 return client.HSet(context.Background(), "chat_history:system_messages", sessionID, messageJSON).Err() 82 } 83 84 func WriteNonSystemMessage(client *goredis.Client, sessionID string, message MultiModalMessageWithTime) error { 85 // Marshal the MessageWithTime struct to JSON 86 messageJSON, err := json.Marshal(message) 87 if err != nil { 88 return err 89 } 90 91 // Index by Timestamp: Add to the Sorted Set 92 return client.ZAdd(context.Background(), "chat_history:"+sessionID+":timestamps", goredis.Z{ 93 Score: float64(message.Timestamp), 94 Member: string(messageJSON), 95 }).Err() 96 } 97 98 // RetrieveSystemMessage gets system message based on a given session ID 99 func RetrieveSystemMessage(client *goredis.Client, sessionID string) (bool, *MultiModalMessageWithTime, error) { 100 serializedMessage, err := client.HGet(context.Background(), "chat_history:system_messages", sessionID).Result() 101 102 // Check if the messageID does not exist 103 if err == goredis.Nil { 104 // Handle the case where the message does not exist 105 return false, nil, nil 106 } else if err != nil { 107 // Handle other types of errors 108 return false, nil, err 109 } 110 111 var message MultiModalMessageWithTime 112 if err := json.Unmarshal([]byte(serializedMessage), &message); err != nil { 113 return false, nil, err 114 } 115 116 return true, &message, nil 117 } 118 119 func WriteMessage(client *goredis.Client, input ChatMessageWriteInput) ChatMessageWriteOutput { 120 // Current time 121 currTime := time.Now().Unix() 122 123 // Create a MessageWithTime struct with the provided input and timestamp 124 messageWithTime := MultiModalMessageWithTime{ 125 MultiModalMessage: MultiModalMessage{ 126 Role: input.Role, 127 Content: []MultiModalContent{ 128 { 129 Type: "text", 130 Text: &input.Content, 131 }, 132 }, 133 Metadata: input.Metadata, 134 }, 135 Timestamp: currTime, 136 } 137 138 // Treat system message differently 139 if input.Role == "system" { 140 err := WriteSystemMessage(client, input.SessionID, messageWithTime) 141 if err != nil { 142 return ChatMessageWriteOutput{Status: false} 143 } else { 144 return ChatMessageWriteOutput{Status: true} 145 } 146 } 147 148 err := WriteNonSystemMessage(client, input.SessionID, messageWithTime) 149 if err != nil { 150 return ChatMessageWriteOutput{Status: false} 151 } else { 152 return ChatMessageWriteOutput{Status: true} 153 } 154 } 155 156 func WriteMultiModelMessage(client *goredis.Client, input ChatMultiModalMessageWriteInput) ChatMessageWriteOutput { 157 // Current time 158 currTime := time.Now().Unix() 159 160 // Create a MessageWithTime struct with the provided input and timestamp 161 messageWithTime := MultiModalMessageWithTime{ 162 MultiModalMessage: MultiModalMessage{ 163 Role: input.Role, 164 Content: input.Content, 165 Metadata: input.Metadata, 166 }, 167 Timestamp: currTime, 168 } 169 170 // Treat system message differently 171 if input.Role == "system" { 172 err := WriteSystemMessage(client, input.SessionID, messageWithTime) 173 if err != nil { 174 return ChatMessageWriteOutput{Status: false} 175 } else { 176 return ChatMessageWriteOutput{Status: true} 177 } 178 } 179 180 err := WriteNonSystemMessage(client, input.SessionID, messageWithTime) 181 if err != nil { 182 return ChatMessageWriteOutput{Status: false} 183 } else { 184 return ChatMessageWriteOutput{Status: true} 185 } 186 } 187 188 // RetrieveSessionMessages retrieves the latest K conversation turns from the Redis list for the given session ID 189 func RetrieveSessionMessages(client *goredis.Client, input ChatHistoryRetrieveInput) ChatHistoryRetrieveOutput { 190 if input.LatestK == nil || *input.LatestK <= 0 { 191 input.LatestK = &DefaultLatestK 192 } 193 key := input.SessionID 194 195 messagesWithTime := []MultiModalMessageWithTime{} 196 messages := []*MultiModalMessage{} 197 ctx := context.Background() 198 199 // Retrieve the latest K conversation turns associated with the session ID by descending timestamp order 200 messagesNum := *input.LatestK * 2 201 timestampMessages, err := client.ZRevRange(ctx, "chat_history:"+key+":timestamps", 0, int64(messagesNum-1)).Result() 202 if err != nil { 203 return ChatHistoryRetrieveOutput{ 204 Messages: messages, 205 Status: false, 206 } 207 } 208 209 // Iterate through the members and deserialize them into MessageWithTime 210 for _, member := range timestampMessages { 211 var messageWithTime MultiModalMessageWithTime 212 if err := json.Unmarshal([]byte(member), &messageWithTime); err != nil { 213 return ChatHistoryRetrieveOutput{ 214 Messages: messages, 215 Status: false, 216 } 217 } 218 messagesWithTime = append(messagesWithTime, messageWithTime) 219 } 220 221 // Sort the messages by timestamp in ascending order (earliest first) 222 sort.SliceStable(messagesWithTime, func(i, j int) bool { 223 return messagesWithTime[i].Timestamp < messagesWithTime[j].Timestamp 224 }) 225 226 // Add System message if exist 227 if input.IncludeSystemMessage { 228 exist, sysMessage, err := RetrieveSystemMessage(client, input.SessionID) 229 if err != nil { 230 return ChatHistoryRetrieveOutput{ 231 Messages: messages, 232 Status: false, 233 } 234 } 235 if exist { 236 messages = append(messages, &MultiModalMessage{ 237 Role: sysMessage.Role, 238 Content: sysMessage.Content, 239 Metadata: sysMessage.Metadata, 240 }) 241 } 242 } 243 244 // Convert the MessageWithTime structs to Message structs 245 for _, m := range messagesWithTime { 246 messages = append(messages, &MultiModalMessage{ 247 Role: m.Role, 248 Content: m.Content, 249 Metadata: m.Metadata, 250 }) 251 } 252 return ChatHistoryRetrieveOutput{ 253 Messages: messages, 254 Status: true, 255 } 256 }