github.com/instill-ai/component@v0.16.0-beta/pkg/connector/redis/v0/chat_history.go (about)

     1  package redis
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"sort"
     7  	"time"
     8  
     9  	goredis "github.com/redis/go-redis/v9"
    10  )
    11  
    12  var (
    13  	// DefaultLatestK is the default number of latest conversation turns to retrieve
    14  	DefaultLatestK = 5
    15  )
    16  
    17  type Message struct {
    18  	Role     string                  `json:"role"`
    19  	Content  string                  `json:"content"`
    20  	Metadata *map[string]interface{} `json:"metadata,omitempty"`
    21  }
    22  
    23  type MultiModalMessage struct {
    24  	Role     string                  `json:"role"`
    25  	Content  []MultiModalContent     `json:"content"`
    26  	Metadata *map[string]interface{} `json:"metadata,omitempty"`
    27  }
    28  
    29  type MultiModalContent struct {
    30  	Type     string  `json:"type"`
    31  	Text     *string `json:"text,omitempty"`
    32  	ImageURL *struct {
    33  		URL string `json:"url"`
    34  	} `json:"image_url,omitempty"`
    35  }
    36  
    37  type MessageWithTime struct {
    38  	Message
    39  	Timestamp int64 `json:"timestamp"`
    40  }
    41  
    42  type MultiModalMessageWithTime struct {
    43  	MultiModalMessage
    44  	Timestamp int64 `json:"timestamp"`
    45  }
    46  
    47  type ChatMessageWriteInput struct {
    48  	SessionID string `json:"session_id"`
    49  	Message
    50  }
    51  
    52  type ChatMultiModalMessageWriteInput struct {
    53  	SessionID string `json:"session_id"`
    54  	MultiModalMessage
    55  }
    56  
    57  type ChatMessageWriteOutput struct {
    58  	Status bool `json:"status"`
    59  }
    60  
    61  type ChatHistoryRetrieveInput struct {
    62  	SessionID            string `json:"session_id"`
    63  	LatestK              *int   `json:"latest_k,omitempty"`
    64  	IncludeSystemMessage bool   `json:"include_system_message"`
    65  }
    66  
    67  // ChatHistoryReadOutput is a wrapper struct for the messages associated with a session ID
    68  type ChatHistoryRetrieveOutput struct {
    69  	Messages []*MultiModalMessage `json:"messages"`
    70  	Status   bool                 `json:"status"`
    71  }
    72  
    73  // WriteSystemMessage writes system message for a given session ID
    74  func WriteSystemMessage(client *goredis.Client, sessionID string, message MultiModalMessageWithTime) error {
    75  	messageJSON, err := json.Marshal(message)
    76  	if err != nil {
    77  		return err
    78  	}
    79  
    80  	// Store in a hash with a unique SessionID
    81  	return client.HSet(context.Background(), "chat_history:system_messages", sessionID, messageJSON).Err()
    82  }
    83  
    84  func WriteNonSystemMessage(client *goredis.Client, sessionID string, message MultiModalMessageWithTime) error {
    85  	// Marshal the MessageWithTime struct to JSON
    86  	messageJSON, err := json.Marshal(message)
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	// Index by Timestamp: Add to the Sorted Set
    92  	return client.ZAdd(context.Background(), "chat_history:"+sessionID+":timestamps", goredis.Z{
    93  		Score:  float64(message.Timestamp),
    94  		Member: string(messageJSON),
    95  	}).Err()
    96  }
    97  
    98  // RetrieveSystemMessage gets system message based on a given session ID
    99  func RetrieveSystemMessage(client *goredis.Client, sessionID string) (bool, *MultiModalMessageWithTime, error) {
   100  	serializedMessage, err := client.HGet(context.Background(), "chat_history:system_messages", sessionID).Result()
   101  
   102  	// Check if the messageID does not exist
   103  	if err == goredis.Nil {
   104  		// Handle the case where the message does not exist
   105  		return false, nil, nil
   106  	} else if err != nil {
   107  		// Handle other types of errors
   108  		return false, nil, err
   109  	}
   110  
   111  	var message MultiModalMessageWithTime
   112  	if err := json.Unmarshal([]byte(serializedMessage), &message); err != nil {
   113  		return false, nil, err
   114  	}
   115  
   116  	return true, &message, nil
   117  }
   118  
   119  func WriteMessage(client *goredis.Client, input ChatMessageWriteInput) ChatMessageWriteOutput {
   120  	// Current time
   121  	currTime := time.Now().Unix()
   122  
   123  	// Create a MessageWithTime struct with the provided input and timestamp
   124  	messageWithTime := MultiModalMessageWithTime{
   125  		MultiModalMessage: MultiModalMessage{
   126  			Role: input.Role,
   127  			Content: []MultiModalContent{
   128  				{
   129  					Type: "text",
   130  					Text: &input.Content,
   131  				},
   132  			},
   133  			Metadata: input.Metadata,
   134  		},
   135  		Timestamp: currTime,
   136  	}
   137  
   138  	// Treat system message differently
   139  	if input.Role == "system" {
   140  		err := WriteSystemMessage(client, input.SessionID, messageWithTime)
   141  		if err != nil {
   142  			return ChatMessageWriteOutput{Status: false}
   143  		} else {
   144  			return ChatMessageWriteOutput{Status: true}
   145  		}
   146  	}
   147  
   148  	err := WriteNonSystemMessage(client, input.SessionID, messageWithTime)
   149  	if err != nil {
   150  		return ChatMessageWriteOutput{Status: false}
   151  	} else {
   152  		return ChatMessageWriteOutput{Status: true}
   153  	}
   154  }
   155  
   156  func WriteMultiModelMessage(client *goredis.Client, input ChatMultiModalMessageWriteInput) ChatMessageWriteOutput {
   157  	// Current time
   158  	currTime := time.Now().Unix()
   159  
   160  	// Create a MessageWithTime struct with the provided input and timestamp
   161  	messageWithTime := MultiModalMessageWithTime{
   162  		MultiModalMessage: MultiModalMessage{
   163  			Role:     input.Role,
   164  			Content:  input.Content,
   165  			Metadata: input.Metadata,
   166  		},
   167  		Timestamp: currTime,
   168  	}
   169  
   170  	// Treat system message differently
   171  	if input.Role == "system" {
   172  		err := WriteSystemMessage(client, input.SessionID, messageWithTime)
   173  		if err != nil {
   174  			return ChatMessageWriteOutput{Status: false}
   175  		} else {
   176  			return ChatMessageWriteOutput{Status: true}
   177  		}
   178  	}
   179  
   180  	err := WriteNonSystemMessage(client, input.SessionID, messageWithTime)
   181  	if err != nil {
   182  		return ChatMessageWriteOutput{Status: false}
   183  	} else {
   184  		return ChatMessageWriteOutput{Status: true}
   185  	}
   186  }
   187  
   188  // RetrieveSessionMessages retrieves the latest K conversation turns from the Redis list for the given session ID
   189  func RetrieveSessionMessages(client *goredis.Client, input ChatHistoryRetrieveInput) ChatHistoryRetrieveOutput {
   190  	if input.LatestK == nil || *input.LatestK <= 0 {
   191  		input.LatestK = &DefaultLatestK
   192  	}
   193  	key := input.SessionID
   194  
   195  	messagesWithTime := []MultiModalMessageWithTime{}
   196  	messages := []*MultiModalMessage{}
   197  	ctx := context.Background()
   198  
   199  	// Retrieve the latest K conversation turns associated with the session ID by descending timestamp order
   200  	messagesNum := *input.LatestK * 2
   201  	timestampMessages, err := client.ZRevRange(ctx, "chat_history:"+key+":timestamps", 0, int64(messagesNum-1)).Result()
   202  	if err != nil {
   203  		return ChatHistoryRetrieveOutput{
   204  			Messages: messages,
   205  			Status:   false,
   206  		}
   207  	}
   208  
   209  	// Iterate through the members and deserialize them into MessageWithTime
   210  	for _, member := range timestampMessages {
   211  		var messageWithTime MultiModalMessageWithTime
   212  		if err := json.Unmarshal([]byte(member), &messageWithTime); err != nil {
   213  			return ChatHistoryRetrieveOutput{
   214  				Messages: messages,
   215  				Status:   false,
   216  			}
   217  		}
   218  		messagesWithTime = append(messagesWithTime, messageWithTime)
   219  	}
   220  
   221  	// Sort the messages by timestamp in ascending order (earliest first)
   222  	sort.SliceStable(messagesWithTime, func(i, j int) bool {
   223  		return messagesWithTime[i].Timestamp < messagesWithTime[j].Timestamp
   224  	})
   225  
   226  	// Add System message if exist
   227  	if input.IncludeSystemMessage {
   228  		exist, sysMessage, err := RetrieveSystemMessage(client, input.SessionID)
   229  		if err != nil {
   230  			return ChatHistoryRetrieveOutput{
   231  				Messages: messages,
   232  				Status:   false,
   233  			}
   234  		}
   235  		if exist {
   236  			messages = append(messages, &MultiModalMessage{
   237  				Role:     sysMessage.Role,
   238  				Content:  sysMessage.Content,
   239  				Metadata: sysMessage.Metadata,
   240  			})
   241  		}
   242  	}
   243  
   244  	// Convert the MessageWithTime structs to Message structs
   245  	for _, m := range messagesWithTime {
   246  		messages = append(messages, &MultiModalMessage{
   247  			Role:     m.Role,
   248  			Content:  m.Content,
   249  			Metadata: m.Metadata,
   250  		})
   251  	}
   252  	return ChatHistoryRetrieveOutput{
   253  		Messages: messages,
   254  		Status:   true,
   255  	}
   256  }