go.temporal.io/server@v1.23.0/common/persistence/cassandra/queue_v2_store.go (about) 1 // The MIT License 2 // 3 // Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. 4 // 5 // Copyright (c) 2020 Uber Technologies, Inc. 6 // 7 // Permission is hereby granted, free of charge, to any person obtaining a copy 8 // of this software and associated documentation files (the "Software"), to deal 9 // in the Software without restriction, including without limitation the rights 10 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 // copies of the Software, and to permit persons to whom the Software is 12 // furnished to do so, subject to the following conditions: 13 // 14 // The above copyright notice and this permission notice shall be included in 15 // all copies or substantial portions of the Software. 16 // 17 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 // THE SOFTWARE. 24 25 package cassandra 26 27 import ( 28 "context" 29 "fmt" 30 31 commonpb "go.temporal.io/api/common/v1" 32 "go.temporal.io/api/enums/v1" 33 persistencespb "go.temporal.io/server/api/persistence/v1" 34 "go.temporal.io/server/common/log" 35 "go.temporal.io/server/common/persistence" 36 "go.temporal.io/server/common/persistence/nosql/nosqlplugin/cassandra/gocql" 37 "go.temporal.io/server/common/persistence/serialization" 38 ) 39 40 type ( 41 // queueV2Store contains the SQL queries and serialization/deserialization functions to interact with the queues and 42 // queue_messages tables that implement the QueueV2 interface. The schema is located at: 43 // schema/cassandra/temporal/versioned/v1.9/queues.cql 44 queueV2Store struct { 45 session gocql.Session 46 logger log.Logger 47 } 48 49 Queue struct { 50 Metadata *persistencespb.Queue 51 Version int64 52 } 53 ) 54 55 const ( 56 TemplateEnqueueMessageQuery = `INSERT INTO queue_messages (queue_type, queue_name, queue_partition, message_id, message_payload, message_encoding) VALUES (?, ?, ?, ?, ?, ?) IF NOT EXISTS` 57 TemplateGetMessagesQuery = `SELECT message_id, message_payload, message_encoding FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? ORDER BY message_id ASC LIMIT ?` 58 TemplateGetMaxMessageIDQuery = `SELECT message_id FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? ORDER BY message_id DESC LIMIT 1` 59 TemplateCreateQueueQuery = `INSERT INTO queues (queue_type, queue_name, metadata_payload, metadata_encoding, version) VALUES (?, ?, ?, ?, ?) IF NOT EXISTS` 60 TemplateGetQueueQuery = `SELECT metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? AND queue_name = ?` 61 TemplateRangeDeleteMessagesQuery = `DELETE FROM queue_messages WHERE queue_type = ? AND queue_name = ? AND queue_partition = ? AND message_id >= ? AND message_id <= ?` 62 TemplateUpdateQueueMetadataQuery = `UPDATE queues SET metadata_payload = ?, metadata_encoding = ?, version = ? WHERE queue_type = ? AND queue_name = ? IF version = ?` 63 // We will have to ALLOW FILTERING for this query since partition key consists of both queue_type and queue_name. 64 templateGetQueueNamesQuery = `SELECT queue_name, metadata_payload, metadata_encoding, version FROM queues WHERE queue_type = ? ALLOW FILTERING` 65 ) 66 67 var ( 68 // ErrEnqueueMessageConflict is returned when a message with the same ID already exists in the queue. This is 69 // possible when there are concurrent writes to the queue because we enqueue a message using two queries: 70 // 71 // 1. SELECT MAX(ID) to get the next message ID (for a given queue partition) 72 // 2. INSERT (ID, message) with IF NOT EXISTS 73 // 74 // See the following example: 75 // 76 // Client A Client B Cassandra DB 77 // | | | 78 // |--1. SELECT MAX(ID) FROM queue_messages----------------------->| 79 // | | | 80 // |<-2. Return X--------------------------------------------------| 81 // | | | 82 // | |--3. SELECT MAX(ID) FROM queue_messages---->| 83 // | | | 84 // | |<-4. Return X-------------------------------| 85 // | | | 86 // |--5. INSERT INTO queue_messages (ID = X)---------------------->| 87 // | | | 88 // |<-6. Acknowledge-----------------------------------------------| 89 // | | | 90 // | |--7. INSERT INTO queue_messages (ID = X)--->| 91 // | | | 92 // | |<-8. Conflict/Error-------------------------| 93 // | | | 94 ErrEnqueueMessageConflict = &persistence.ConditionFailedError{ 95 Msg: "conflict inserting queue message, likely due to concurrent writes", 96 } 97 // ErrUpdateQueueConflict is returned when a queue is updated with the wrong version. This happens when there are 98 // concurrent writes to the queue because we update a queue using two queries, similar to the enqueue message query. 99 // 100 // 1. SELECT (queue, version) FROM queues 101 // 2. UPDATE queue, version IF version = version from step 1 102 // 103 // See the following example: 104 // 105 // Client A Client B Cassandra DB 106 // | | | 107 // |--1. SELECT (queue, version) FROM queues---------------------->| 108 // | | | 109 // |<-2. Return (queue, v1)----------------------------------------| 110 // | | | 111 // | |--3. SELECT (queue, version) FROM queues--->| 112 // | | | 113 // | |<-4. Return (queue, v1)---------------------| 114 // | | | 115 // |--5. UPDATE queue, version IF version = v1-------------------->| 116 // | | | 117 // |<-6. Acknowledge-----------------------------------------------| 118 // | | | 119 // | |--7. UPDATE queue, version IF version = v1->| 120 // | | | 121 // | |<-8. Conflict/Error-------------------------| 122 // | | | 123 ErrUpdateQueueConflict = &persistence.ConditionFailedError{ 124 Msg: "conflict updating queue, likely due to concurrent writes", 125 } 126 ) 127 128 func NewQueueV2Store(session gocql.Session, logger log.Logger) persistence.QueueV2 { 129 return &queueV2Store{ 130 session: session, 131 logger: logger, 132 } 133 } 134 135 func (s *queueV2Store) EnqueueMessage( 136 ctx context.Context, 137 request *persistence.InternalEnqueueMessageRequest, 138 ) (*persistence.InternalEnqueueMessageResponse, error) { 139 // TODO: add concurrency control around this method to avoid things like QueueMessageIDConflict. 140 // TODO: cache the queue in memory to avoid querying the database every time. 141 _, err := s.getQueue(ctx, request.QueueType, request.QueueName) 142 if err != nil { 143 return nil, err 144 } 145 messageID, err := s.getNextMessageID(ctx, request.QueueType, request.QueueName) 146 if err != nil { 147 return nil, err 148 } 149 err = s.tryInsert(ctx, request.QueueType, request.QueueName, request.Blob, messageID) 150 if err != nil { 151 return nil, err 152 } 153 return &persistence.InternalEnqueueMessageResponse{ 154 Metadata: persistence.MessageMetadata{ID: messageID}, 155 }, nil 156 } 157 158 func (s *queueV2Store) ReadMessages( 159 ctx context.Context, 160 request *persistence.InternalReadMessagesRequest, 161 ) (*persistence.InternalReadMessagesResponse, error) { 162 q, err := s.getQueue(ctx, request.QueueType, request.QueueName) 163 if err != nil { 164 return nil, err 165 } 166 if request.PageSize <= 0 { 167 return nil, persistence.ErrNonPositiveReadQueueMessagesPageSize 168 } 169 minMessageID, err := persistence.GetMinMessageIDToReadForQueueV2(request.QueueType, request.QueueName, request.NextPageToken, q.Metadata) 170 if err != nil { 171 return nil, err 172 } 173 174 iter := s.session.Query( 175 TemplateGetMessagesQuery, 176 request.QueueType, 177 request.QueueName, 178 0, 179 int(minMessageID), 180 request.PageSize, 181 ).WithContext(ctx).Iter() 182 183 var ( 184 messages []persistence.QueueV2Message 185 // messageID is the ID of the last message returned by the query. 186 messageID int64 187 ) 188 189 for { 190 var ( 191 messagePayload []byte 192 messageEncoding string 193 ) 194 if !iter.Scan(&messageID, &messagePayload, &messageEncoding) { 195 break 196 } 197 encoding, err := enums.EncodingTypeFromString(messageEncoding) 198 if err != nil { 199 return nil, serialization.NewUnknownEncodingTypeError(messageEncoding) 200 } 201 202 encodingType := enums.EncodingType(encoding) 203 204 message := persistence.QueueV2Message{ 205 MetaData: persistence.MessageMetadata{ID: messageID}, 206 Data: &commonpb.DataBlob{ 207 EncodingType: encodingType, 208 Data: messagePayload, 209 }, 210 } 211 messages = append(messages, message) 212 } 213 214 if err := iter.Close(); err != nil { 215 return nil, gocql.ConvertError("QueueV2ReadMessages", err) 216 } 217 218 nextPageToken := persistence.GetNextPageTokenForReadMessages(messages) 219 return &persistence.InternalReadMessagesResponse{ 220 Messages: messages, 221 NextPageToken: nextPageToken, 222 }, nil 223 } 224 225 func (s *queueV2Store) CreateQueue( 226 ctx context.Context, 227 request *persistence.InternalCreateQueueRequest, 228 ) (*persistence.InternalCreateQueueResponse, error) { 229 queueType := request.QueueType 230 queueName := request.QueueName 231 q := persistencespb.Queue{ 232 Partitions: map[int32]*persistencespb.QueuePartition{ 233 0: { 234 MinMessageId: persistence.FirstQueueMessageID, 235 }, 236 }, 237 } 238 bytes, _ := q.Marshal() 239 applied, err := s.session.Query( 240 TemplateCreateQueueQuery, 241 queueType, 242 queueName, 243 bytes, 244 enums.ENCODING_TYPE_PROTO3.String(), 245 0, 246 ).WithContext(ctx).MapScanCAS(make(map[string]interface{})) 247 if err != nil { 248 return nil, gocql.ConvertError("QueueV2CreateQueue", err) 249 } 250 251 if !applied { 252 return nil, fmt.Errorf( 253 "%w: queue type %v and name %v", 254 persistence.ErrQueueAlreadyExists, 255 queueType, 256 queueName, 257 ) 258 } 259 return &persistence.InternalCreateQueueResponse{}, nil 260 } 261 262 func (s *queueV2Store) RangeDeleteMessages( 263 ctx context.Context, 264 request *persistence.InternalRangeDeleteMessagesRequest, 265 ) (*persistence.InternalRangeDeleteMessagesResponse, error) { 266 if request.InclusiveMaxMessageMetadata.ID < persistence.FirstQueueMessageID { 267 return nil, fmt.Errorf( 268 "%w: id is %d but must be >= %d", 269 persistence.ErrInvalidQueueRangeDeleteMaxMessageID, 270 request.InclusiveMaxMessageMetadata.ID, 271 persistence.FirstQueueMessageID, 272 ) 273 } 274 queueType := request.QueueType 275 queueName := request.QueueName 276 q, err := s.getQueue(ctx, queueType, queueName) 277 if err != nil { 278 return nil, err 279 } 280 partition, err := persistence.GetPartitionForQueueV2(queueType, queueName, q.Metadata) 281 if err != nil { 282 return nil, err 283 } 284 maxMessageID, ok, err := s.getMaxMessageID(ctx, queueType, queueName) 285 if err != nil { 286 return nil, err 287 } 288 if !ok { 289 // Nothing in the queue to delete. 290 return &persistence.InternalRangeDeleteMessagesResponse{}, nil 291 } 292 deleteRange, ok := persistence.GetDeleteRange(persistence.DeleteRequest{ 293 LastIDToDeleteInclusive: request.InclusiveMaxMessageMetadata.ID, 294 ExistingMessageRange: persistence.InclusiveMessageRange{ 295 MinMessageID: partition.MinMessageId, 296 MaxMessageID: maxMessageID, 297 }, 298 }) 299 if !ok { 300 return &persistence.InternalRangeDeleteMessagesResponse{}, nil 301 } 302 err = s.session.Query( 303 TemplateRangeDeleteMessagesQuery, 304 queueType, 305 queueName, 306 0, // partition 307 deleteRange.MinMessageID, 308 deleteRange.MaxMessageID, 309 ).WithContext(ctx).Exec() 310 if err != nil { 311 return nil, gocql.ConvertError("QueueV2RangeDeleteMessages", err) 312 } 313 partition.MinMessageId = deleteRange.NewMinMessageID 314 err = s.updateQueue(ctx, q, queueType, queueName) 315 if err != nil { 316 return nil, err 317 } 318 return &persistence.InternalRangeDeleteMessagesResponse{ 319 MessagesDeleted: deleteRange.MessagesToDelete, 320 }, nil 321 } 322 323 func (s *queueV2Store) updateQueue( 324 ctx context.Context, 325 q *Queue, 326 queueType persistence.QueueV2Type, 327 queueName string, 328 ) error { 329 bytes, _ := q.Metadata.Marshal() 330 version := q.Version 331 nextVersion := version + 1 332 q.Version = nextVersion 333 applied, err := s.session.Query( 334 TemplateUpdateQueueMetadataQuery, 335 bytes, 336 enums.ENCODING_TYPE_PROTO3.String(), 337 nextVersion, 338 queueType, 339 queueName, 340 version, 341 ).WithContext(ctx).MapScanCAS(make(map[string]interface{})) 342 if err != nil { 343 return gocql.ConvertError("QueueV2UpdateQueueMetadata", err) 344 } 345 if !applied { 346 return fmt.Errorf( 347 "%w: queue type %v and name %v", 348 ErrUpdateQueueConflict, 349 queueType, 350 queueName, 351 ) 352 } 353 return nil 354 } 355 356 func (s *queueV2Store) tryInsert( 357 ctx context.Context, 358 queueType persistence.QueueV2Type, 359 queueName string, 360 blob *commonpb.DataBlob, 361 messageID int64, 362 ) error { 363 applied, err := s.session.Query( 364 TemplateEnqueueMessageQuery, 365 queueType, 366 queueName, 367 0, 368 messageID, 369 blob.Data, 370 blob.EncodingType.String(), 371 ).WithContext(ctx).MapScanCAS(make(map[string]interface{})) 372 if err != nil { 373 return gocql.ConvertError("QueueV2EnqueueMessage", err) 374 } 375 if !applied { 376 return fmt.Errorf( 377 "%w: queue type %v and name %v already has a message with ID %v", 378 ErrEnqueueMessageConflict, 379 queueType, 380 queueName, 381 messageID, 382 ) 383 } 384 385 return nil 386 } 387 388 func (s *queueV2Store) getQueue( 389 ctx context.Context, 390 queueType persistence.QueueV2Type, 391 name string, 392 ) (*Queue, error) { 393 return GetQueue(ctx, s.session, name, queueType) 394 } 395 396 func GetQueue( 397 ctx context.Context, 398 session gocql.Session, 399 queueName string, 400 queueType persistence.QueueV2Type, 401 ) (*Queue, error) { 402 var ( 403 queueBytes []byte 404 queueEncodingStr string 405 version int64 406 ) 407 408 err := session.Query(TemplateGetQueueQuery, queueType, queueName).WithContext(ctx).Scan( 409 &queueBytes, 410 &queueEncodingStr, 411 &version, 412 ) 413 if err != nil { 414 if gocql.IsNotFoundError(err) { 415 return nil, persistence.NewQueueNotFoundError(queueType, queueName) 416 } 417 return nil, gocql.ConvertError("QueueV2GetQueue", err) 418 } 419 return getQueueFromMetadata(queueType, queueName, queueBytes, queueEncodingStr, version) 420 } 421 422 func getQueueFromMetadata( 423 queueType persistence.QueueV2Type, 424 queueName string, 425 queueBytes []byte, 426 queueEncodingStr string, 427 version int64, 428 ) (*Queue, error) { 429 if queueEncodingStr != enums.ENCODING_TYPE_PROTO3.String() { 430 return nil, fmt.Errorf( 431 "%w: invalid queue encoding type: queue with type %v and name %v has invalid encoding", 432 serialization.NewUnknownEncodingTypeError(queueEncodingStr, enums.ENCODING_TYPE_PROTO3), 433 queueType, 434 queueName, 435 ) 436 } 437 438 q := &persistencespb.Queue{} 439 err := q.Unmarshal(queueBytes) 440 if err != nil { 441 return nil, serialization.NewDeserializationError( 442 enums.ENCODING_TYPE_PROTO3, 443 fmt.Errorf("%w: unmarshal queue payload: failed for queue with type %v and name %v", 444 err, queueType, queueName), 445 ) 446 } 447 448 return &Queue{ 449 Metadata: q, 450 Version: version, 451 }, nil 452 } 453 454 func (s *queueV2Store) getNextMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string) (int64, error) { 455 maxMessageID, ok, err := s.getMaxMessageID(ctx, queueType, queueName) 456 if err != nil { 457 return 0, err 458 } 459 if !ok { 460 return persistence.FirstQueueMessageID, nil 461 } 462 463 // The next message ID is the max message ID + 1. 464 return maxMessageID + 1, nil 465 } 466 467 func (s *queueV2Store) getMaxMessageID(ctx context.Context, queueType persistence.QueueV2Type, queueName string) (int64, bool, error) { 468 var maxMessageID int64 469 470 err := s.session.Query(TemplateGetMaxMessageIDQuery, queueType, queueName, 0).WithContext(ctx).Scan(&maxMessageID) 471 if err != nil { 472 if gocql.IsNotFoundError(err) { 473 return 0, false, nil 474 } 475 return 0, false, gocql.ConvertError("QueueV2GetMaxMessageID", err) 476 } 477 return maxMessageID, true, nil 478 } 479 480 func (s *queueV2Store) ListQueues( 481 ctx context.Context, 482 request *persistence.InternalListQueuesRequest, 483 ) (*persistence.InternalListQueuesResponse, error) { 484 if request.PageSize <= 0 { 485 return nil, persistence.ErrNonPositiveListQueuesPageSize 486 } 487 iter := s.session.Query( 488 templateGetQueueNamesQuery, 489 request.QueueType, 490 ).PageSize(request.PageSize).PageState(request.NextPageToken).WithContext(ctx).Iter() 491 492 var queues []persistence.QueueInfo 493 for { 494 var ( 495 queueName string 496 metadataBytes []byte 497 metadataEncoding string 498 version int64 499 ) 500 if !iter.Scan(&queueName, &metadataBytes, &metadataEncoding, &version) { 501 break 502 } 503 q, err := getQueueFromMetadata(request.QueueType, queueName, metadataBytes, metadataEncoding, version) 504 if err != nil { 505 return nil, err 506 } 507 partition, err := persistence.GetPartitionForQueueV2(request.QueueType, queueName, q.Metadata) 508 if err != nil { 509 return nil, err 510 } 511 nextMessageID, err := s.getNextMessageID(ctx, request.QueueType, queueName) 512 if err != nil { 513 return nil, err 514 } 515 messageCount := nextMessageID - partition.MinMessageId 516 queues = append(queues, persistence.QueueInfo{ 517 QueueName: queueName, 518 MessageCount: messageCount, 519 }) 520 } 521 if err := iter.Close(); err != nil { 522 return nil, gocql.ConvertError("QueueV2ListQueues", err) 523 } 524 return &persistence.InternalListQueuesResponse{ 525 Queues: queues, 526 NextPageToken: iter.PageState(), 527 }, nil 528 }