github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/database/sqlcommon/message_sql.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package sqlcommon
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"fmt"
    23  
    24  	sq "github.com/Masterminds/squirrel"
    25  	"github.com/kaleido-io/firefly/internal/i18n"
    26  	"github.com/kaleido-io/firefly/internal/log"
    27  	"github.com/kaleido-io/firefly/pkg/database"
    28  	"github.com/kaleido-io/firefly/pkg/fftypes"
    29  )
    30  
    31  var (
    32  	msgColumns = []string{
    33  		"id",
    34  		"cid",
    35  		"mtype",
    36  		"author",
    37  		"created",
    38  		"namespace",
    39  		"topics",
    40  		"tag",
    41  		"group_hash",
    42  		"datahash",
    43  		"hash",
    44  		"pins",
    45  		"confirmed",
    46  		"tx_type",
    47  		"batch_id",
    48  		"local",
    49  	}
    50  	msgFilterTypeMap = map[string]string{
    51  		"type":    "mtype",
    52  		"txntype": "tx_type",
    53  		"batch":   "batch_id",
    54  		"group":   "group_hash",
    55  	}
    56  )
    57  
    58  func (s *SQLCommon) InsertMessageLocal(ctx context.Context, message *fftypes.Message) (err error) {
    59  	message.Local = true
    60  	return s.upsertMessageCommon(ctx, message, false, false, true /* local insert */)
    61  }
    62  
    63  func (s *SQLCommon) UpsertMessage(ctx context.Context, message *fftypes.Message, allowExisting, allowHashUpdate bool) (err error) {
    64  	return s.upsertMessageCommon(ctx, message, allowExisting, allowHashUpdate, false /* not local */)
    65  }
    66  
    67  func (s *SQLCommon) upsertMessageCommon(ctx context.Context, message *fftypes.Message, allowExisting, allowHashUpdate, isLocal bool) (err error) {
    68  	ctx, tx, autoCommit, err := s.beginOrUseTx(ctx)
    69  	if err != nil {
    70  		return err
    71  	}
    72  	defer s.rollbackTx(ctx, tx, autoCommit)
    73  
    74  	existing := false
    75  	if allowExisting {
    76  		// Do a select within the transaction to detemine if the UUID already exists
    77  		msgRows, err := s.queryTx(ctx, tx,
    78  			sq.Select("hash").
    79  				From("messages").
    80  				Where(sq.Eq{"id": message.Header.ID}),
    81  		)
    82  		if err != nil {
    83  			return err
    84  		}
    85  
    86  		existing = msgRows.Next()
    87  		if existing && !allowHashUpdate {
    88  			var hash *fftypes.Bytes32
    89  			_ = msgRows.Scan(&hash)
    90  			if !fftypes.SafeHashCompare(hash, message.Hash) {
    91  				msgRows.Close()
    92  				log.L(ctx).Errorf("Existing=%s New=%s", hash, message.Hash)
    93  				return database.HashMismatch
    94  			}
    95  		}
    96  		msgRows.Close()
    97  	}
    98  
    99  	if existing {
   100  
   101  		// Update the message
   102  		if err = s.updateTx(ctx, tx,
   103  			sq.Update("messages").
   104  				Set("cid", message.Header.CID).
   105  				Set("mtype", string(message.Header.Type)).
   106  				Set("author", message.Header.Author).
   107  				Set("created", message.Header.Created).
   108  				Set("namespace", message.Header.Namespace).
   109  				Set("topics", message.Header.Topics).
   110  				Set("tag", message.Header.Tag).
   111  				Set("group_hash", message.Header.Group).
   112  				Set("datahash", message.Header.DataHash).
   113  				Set("hash", message.Hash).
   114  				Set("pins", message.Pins).
   115  				Set("confirmed", message.Confirmed).
   116  				Set("tx_type", message.Header.TxType).
   117  				Set("batch_id", message.BatchID).
   118  				// Intentionally does NOT include the "local" column
   119  				Where(sq.Eq{"id": message.Header.ID}),
   120  		); err != nil {
   121  			return err
   122  		}
   123  	} else {
   124  		sequence, err := s.insertTx(ctx, tx,
   125  			sq.Insert("messages").
   126  				Columns(msgColumns...).
   127  				Values(
   128  					message.Header.ID,
   129  					message.Header.CID,
   130  					string(message.Header.Type),
   131  					message.Header.Author,
   132  					message.Header.Created,
   133  					message.Header.Namespace,
   134  					message.Header.Topics,
   135  					message.Header.Tag,
   136  					message.Header.Group,
   137  					message.Header.DataHash,
   138  					message.Hash,
   139  					message.Pins,
   140  					message.Confirmed,
   141  					message.Header.TxType,
   142  					message.BatchID,
   143  					isLocal,
   144  				),
   145  		)
   146  		if err != nil {
   147  			return err
   148  		}
   149  
   150  		s.postCommitEvent(tx, func() {
   151  			s.callbacks.MessageCreated(sequence)
   152  		})
   153  
   154  	}
   155  
   156  	if err = s.updateMessageDataRefs(ctx, tx, message, existing); err != nil {
   157  		return err
   158  	}
   159  
   160  	return s.commitTx(ctx, tx, autoCommit)
   161  }
   162  
   163  func (s *SQLCommon) updateMessageDataRefs(ctx context.Context, tx *txWrapper, message *fftypes.Message, existing bool) error {
   164  
   165  	if existing {
   166  		if err := s.deleteTx(ctx, tx,
   167  			sq.Delete("messages_data").
   168  				Where(sq.And{
   169  					sq.Eq{"message_id": message.Header.ID},
   170  				}),
   171  		); err != nil {
   172  			return err
   173  		}
   174  	}
   175  
   176  	// Run through the ones in the message, finding ones that already exist, and ones that need to be created
   177  	for msgDataRefIDx, msgDataRef := range message.Data {
   178  		if msgDataRef.ID == nil {
   179  			return i18n.NewError(ctx, i18n.MsgNullDataReferenceID, msgDataRefIDx)
   180  		}
   181  		if msgDataRef.Hash == nil {
   182  			return i18n.NewError(ctx, i18n.MsgMissingDataHashIndex, msgDataRefIDx)
   183  		}
   184  		// Add the linkage
   185  		if _, err := s.insertTx(ctx, tx,
   186  			sq.Insert("messages_data").
   187  				Columns(
   188  					"message_id",
   189  					"data_id",
   190  					"data_hash",
   191  					"data_idx",
   192  				).
   193  				Values(
   194  					message.Header.ID,
   195  					msgDataRef.ID,
   196  					msgDataRef.Hash,
   197  					msgDataRefIDx,
   198  				),
   199  		); err != nil {
   200  			return err
   201  		}
   202  	}
   203  
   204  	return nil
   205  
   206  }
   207  
   208  // Why not a LEFT JOIN you ask? ... well we need to be able to reliably perform a LIMIT on
   209  // the number of messages, and it seems there isn't a clean and cross-database
   210  // way for a single-query option. So a two-query option ended up being simplest.
   211  // See commit e304161a30b8044a42b5bac3fcfca7e7bd8f8ab7 for the abandoned changeset
   212  // that implemented LEFT JOIN
   213  func (s *SQLCommon) loadDataRefs(ctx context.Context, msgs []*fftypes.Message) error {
   214  
   215  	msgIDs := make([]string, len(msgs))
   216  	for i, m := range msgs {
   217  		if m != nil {
   218  			msgIDs[i] = m.Header.ID.String()
   219  		}
   220  	}
   221  
   222  	existingRefs, err := s.query(ctx,
   223  		sq.Select(
   224  			"message_id",
   225  			"data_id",
   226  			"data_hash",
   227  			"data_idx",
   228  		).
   229  			From("messages_data").
   230  			Where(sq.Eq{"message_id": msgIDs}).
   231  			OrderBy("data_idx"),
   232  	)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	defer existingRefs.Close()
   237  
   238  	for existingRefs.Next() {
   239  		var msgID fftypes.UUID
   240  		var dataID fftypes.UUID
   241  		var dataHash fftypes.Bytes32
   242  		var dataIDx int
   243  		if err = existingRefs.Scan(&msgID, &dataID, &dataHash, &dataIDx); err != nil {
   244  			return i18n.WrapError(ctx, err, i18n.MsgDBReadErr, "messages_data")
   245  		}
   246  		for _, m := range msgs {
   247  			if *m.Header.ID == msgID {
   248  				m.Data = append(m.Data, &fftypes.DataRef{
   249  					ID:   &dataID,
   250  					Hash: &dataHash,
   251  				})
   252  			}
   253  		}
   254  	}
   255  	// Ensure we return an empty array if no entries, and a consistent order for the data
   256  	for _, m := range msgs {
   257  		if m.Data == nil {
   258  			m.Data = fftypes.DataRefs{}
   259  		}
   260  	}
   261  
   262  	return nil
   263  }
   264  
   265  func (s *SQLCommon) msgResult(ctx context.Context, row *sql.Rows) (*fftypes.Message, error) {
   266  	var msg fftypes.Message
   267  	err := row.Scan(
   268  		&msg.Header.ID,
   269  		&msg.Header.CID,
   270  		&msg.Header.Type,
   271  		&msg.Header.Author,
   272  		&msg.Header.Created,
   273  		&msg.Header.Namespace,
   274  		&msg.Header.Topics,
   275  		&msg.Header.Tag,
   276  		&msg.Header.Group,
   277  		&msg.Header.DataHash,
   278  		&msg.Hash,
   279  		&msg.Pins,
   280  		&msg.Confirmed,
   281  		&msg.Header.TxType,
   282  		&msg.BatchID,
   283  		&msg.Local,
   284  		// Must be added to the list of columns in all selects
   285  		&msg.Sequence,
   286  	)
   287  	if err != nil {
   288  		return nil, i18n.WrapError(ctx, err, i18n.MsgDBReadErr, "messages")
   289  	}
   290  	return &msg, nil
   291  }
   292  
   293  func (s *SQLCommon) GetMessageByID(ctx context.Context, id *fftypes.UUID) (message *fftypes.Message, err error) {
   294  
   295  	cols := append([]string{}, msgColumns...)
   296  	cols = append(cols, s.provider.SequenceField(""))
   297  	rows, err := s.query(ctx,
   298  		sq.Select(cols...).
   299  			From("messages").
   300  			Where(sq.Eq{"id": id}),
   301  	)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	defer rows.Close()
   306  
   307  	if !rows.Next() {
   308  		log.L(ctx).Debugf("Message '%s' not found", id)
   309  		return nil, nil
   310  	}
   311  
   312  	msg, err := s.msgResult(ctx, rows)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	rows.Close()
   318  	if err = s.loadDataRefs(ctx, []*fftypes.Message{msg}); err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	return msg, nil
   323  }
   324  
   325  func (s *SQLCommon) getMessagesQuery(ctx context.Context, query sq.SelectBuilder) (message []*fftypes.Message, err error) {
   326  	rows, err := s.query(ctx, query)
   327  	if err != nil {
   328  		return nil, err
   329  	}
   330  	defer rows.Close()
   331  
   332  	msgs := []*fftypes.Message{}
   333  	for rows.Next() {
   334  		msg, err := s.msgResult(ctx, rows)
   335  		if err != nil {
   336  			return nil, err
   337  		}
   338  		msgs = append(msgs, msg)
   339  	}
   340  
   341  	rows.Close()
   342  	if len(msgs) > 0 {
   343  		if err = s.loadDataRefs(ctx, msgs); err != nil {
   344  			return nil, err
   345  		}
   346  	}
   347  
   348  	return msgs, err
   349  }
   350  
   351  func (s *SQLCommon) GetMessages(ctx context.Context, filter database.Filter) (message []*fftypes.Message, err error) {
   352  	cols := append([]string{}, msgColumns...)
   353  	cols = append(cols, s.provider.SequenceField(""))
   354  	query, err := s.filterSelect(ctx, "", sq.Select(cols...).From("messages"), filter, msgFilterTypeMap)
   355  	if err != nil {
   356  		return nil, err
   357  	}
   358  	return s.getMessagesQuery(ctx, query)
   359  }
   360  
   361  func (s *SQLCommon) GetMessagesForData(ctx context.Context, dataID *fftypes.UUID, filter database.Filter) (message []*fftypes.Message, err error) {
   362  	cols := make([]string, len(msgColumns)+1)
   363  	for i, col := range msgColumns {
   364  		cols[i] = fmt.Sprintf("m.%s", col)
   365  	}
   366  	cols[len(msgColumns)] = s.provider.SequenceField("m")
   367  	query, err := s.filterSelect(ctx, "m", sq.Select(cols...).From("messages_data AS md"), filter, msgFilterTypeMap,
   368  		sq.Eq{"md.data_id": dataID})
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	query = query.LeftJoin("messages AS m ON m.id = md.message_id")
   374  	return s.getMessagesQuery(ctx, query)
   375  }
   376  
   377  func (s *SQLCommon) GetMessageRefs(ctx context.Context, filter database.Filter) ([]*fftypes.MessageRef, error) {
   378  	query, err := s.filterSelect(ctx, "", sq.Select("id", s.provider.SequenceField(""), "hash").From("messages"), filter, msgFilterTypeMap)
   379  	if err != nil {
   380  		return nil, err
   381  	}
   382  	rows, err := s.query(ctx, query)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	defer rows.Close()
   387  
   388  	msgRefs := []*fftypes.MessageRef{}
   389  	for rows.Next() {
   390  		var msgRef fftypes.MessageRef
   391  		if err = rows.Scan(&msgRef.ID, &msgRef.Sequence, &msgRef.Hash); err != nil {
   392  			return nil, i18n.WrapError(ctx, err, i18n.MsgDBReadErr, "messages")
   393  		}
   394  		msgRefs = append(msgRefs, &msgRef)
   395  	}
   396  	return msgRefs, nil
   397  }
   398  
   399  func (s *SQLCommon) UpdateMessage(ctx context.Context, msgid *fftypes.UUID, update database.Update) (err error) {
   400  	return s.UpdateMessages(ctx, database.MessageQueryFactory.NewFilter(ctx).Eq("id", msgid), update)
   401  }
   402  
   403  func (s *SQLCommon) UpdateMessages(ctx context.Context, filter database.Filter, update database.Update) (err error) {
   404  
   405  	ctx, tx, autoCommit, err := s.beginOrUseTx(ctx)
   406  	if err != nil {
   407  		return err
   408  	}
   409  	defer s.rollbackTx(ctx, tx, autoCommit)
   410  
   411  	query, err := s.buildUpdate(sq.Update("messages"), update, msgFilterTypeMap)
   412  	if err != nil {
   413  		return err
   414  	}
   415  
   416  	query, err = s.filterUpdate(ctx, "", query, filter, opFilterTypeMap)
   417  	if err != nil {
   418  		return err
   419  	}
   420  
   421  	err = s.updateTx(ctx, tx, query)
   422  	if err != nil {
   423  		return err
   424  	}
   425  
   426  	return s.commitTx(ctx, tx, autoCommit)
   427  }