github.com/vmware/transport-go@v1.3.4/bus/transaction.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package bus
     5  
     6  import (
     7  	"fmt"
     8  	"github.com/google/uuid"
     9  	"github.com/vmware/transport-go/model"
    10  	"sync"
    11  )
    12  
    13  type transactionType int
    14  
    15  const (
    16  	asyncTransaction transactionType = iota
    17  	syncTransaction
    18  )
    19  
    20  type BusTransactionReadyFunction func(responses []*model.Message)
    21  
    22  type BusTransaction interface {
    23  	// Sends a request to a channel as a part of this transaction.
    24  	SendRequest(channel string, payload interface{}) error
    25  	//  Wait for a store to be initialized as a part of this transaction.
    26  	WaitForStoreReady(storeName string) error
    27  	// Registers a new complete handler. Once all responses to requests have been received,
    28  	// the transaction is complete.
    29  	OnComplete(completeHandler BusTransactionReadyFunction) error
    30  	// Register a new error handler. If an error is thrown by any of the responders, the transaction
    31  	// is aborted and the error sent to the registered errorHandlers.
    32  	OnError(errorHandler MessageErrorFunction) error
    33  	// Commit the transaction, all requests will be sent and will wait for responses.
    34  	// Once all the responses are in, onComplete handlers will be called with the responses.
    35  	Commit() error
    36  }
    37  
    38  type transactionState int
    39  
    40  const (
    41  	uncommittedState transactionState = iota
    42  	committedState
    43  	completedState
    44  	abortedState
    45  )
    46  
    47  type busTransactionRequest struct {
    48  	requestIndex int
    49  	storeName    string
    50  	channelName  string
    51  	payload      interface{}
    52  }
    53  
    54  type busTransaction struct {
    55  	transactionType    transactionType
    56  	state              transactionState
    57  	lock               sync.Mutex
    58  	requests           []*busTransactionRequest
    59  	responses          []*model.Message
    60  	onCompleteHandlers []BusTransactionReadyFunction
    61  	onErrorHandlers    []MessageErrorFunction
    62  	bus                EventBus
    63  	completedRequests  int
    64  }
    65  
    66  func newBusTransaction(bus EventBus, transactionType transactionType) BusTransaction {
    67  	transaction := new(busTransaction)
    68  
    69  	transaction.bus = bus
    70  	transaction.state = uncommittedState
    71  	transaction.transactionType = transactionType
    72  	transaction.requests = make([]*busTransactionRequest, 0)
    73  	transaction.onCompleteHandlers = make([]BusTransactionReadyFunction, 0)
    74  	transaction.onErrorHandlers = make([]MessageErrorFunction, 0)
    75  	transaction.completedRequests = 0
    76  
    77  	return transaction
    78  }
    79  
    80  func (tr *busTransaction) checkUncommittedState() error {
    81  	if tr.state != uncommittedState {
    82  		return fmt.Errorf("transaction has already been committed")
    83  	}
    84  	return nil
    85  }
    86  
    87  func (tr *busTransaction) SendRequest(channel string, payload interface{}) error {
    88  	tr.lock.Lock()
    89  	defer tr.lock.Unlock()
    90  
    91  	if err := tr.checkUncommittedState(); err != nil {
    92  		return err
    93  	}
    94  
    95  	tr.requests = append(tr.requests, &busTransactionRequest{
    96  		channelName:  channel,
    97  		payload:      payload,
    98  		requestIndex: len(tr.requests),
    99  	})
   100  
   101  	return nil
   102  }
   103  
   104  func (tr *busTransaction) WaitForStoreReady(storeName string) error {
   105  	tr.lock.Lock()
   106  	defer tr.lock.Unlock()
   107  
   108  	if err := tr.checkUncommittedState(); err != nil {
   109  		return err
   110  	}
   111  
   112  	if tr.bus.GetStoreManager().GetStore(storeName) == nil {
   113  		return fmt.Errorf("cannot find store '%s'", storeName)
   114  	}
   115  
   116  	tr.requests = append(tr.requests, &busTransactionRequest{
   117  		storeName:    storeName,
   118  		requestIndex: len(tr.requests),
   119  	})
   120  
   121  	return nil
   122  }
   123  
   124  func (tr *busTransaction) OnComplete(completeHandler BusTransactionReadyFunction) error {
   125  	tr.lock.Lock()
   126  	defer tr.lock.Unlock()
   127  
   128  	if err := tr.checkUncommittedState(); err != nil {
   129  		return err
   130  	}
   131  
   132  	tr.onCompleteHandlers = append(tr.onCompleteHandlers, completeHandler)
   133  	return nil
   134  }
   135  
   136  func (tr *busTransaction) OnError(errorHandler MessageErrorFunction) error {
   137  	tr.lock.Lock()
   138  	defer tr.lock.Unlock()
   139  
   140  	if err := tr.checkUncommittedState(); err != nil {
   141  		return err
   142  	}
   143  
   144  	tr.onErrorHandlers = append(tr.onErrorHandlers, errorHandler)
   145  	return nil
   146  }
   147  
   148  func (tr *busTransaction) Commit() error {
   149  	tr.lock.Lock()
   150  	defer tr.lock.Unlock()
   151  
   152  	if err := tr.checkUncommittedState(); err != nil {
   153  		return err
   154  	}
   155  
   156  	if len(tr.requests) == 0 {
   157  		return fmt.Errorf("cannot commit empty transaction")
   158  	}
   159  
   160  	tr.state = committedState
   161  
   162  	// init responses slice
   163  	tr.responses = make([]*model.Message, len(tr.requests))
   164  
   165  	if tr.transactionType == asyncTransaction {
   166  		tr.startAsyncTransaction()
   167  	} else {
   168  		tr.startSyncTransaction()
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (tr *busTransaction) startSyncTransaction() {
   175  	tr.executeRequest(tr.requests[0])
   176  }
   177  
   178  func (tr *busTransaction) executeRequest(request *busTransactionRequest) {
   179  	if request.storeName != "" {
   180  		tr.waitForStore(request)
   181  	} else {
   182  		tr.sendRequest(request)
   183  	}
   184  }
   185  
   186  func (tr *busTransaction) startAsyncTransaction() {
   187  	for _, req := range tr.requests {
   188  		tr.executeRequest(req)
   189  	}
   190  }
   191  
   192  func (tr *busTransaction) sendRequest(req *busTransactionRequest) {
   193  	reqId := uuid.New()
   194  
   195  	mh, err := tr.bus.ListenOnceForDestination(req.channelName, &reqId)
   196  	if err != nil {
   197  		tr.onTransactionError(err)
   198  		return
   199  	}
   200  
   201  	mh.Handle(func(message *model.Message) {
   202  		tr.onTransactionRequestSuccess(req, message)
   203  	}, func(e error) {
   204  		tr.onTransactionError(e)
   205  	})
   206  
   207  	tr.bus.SendRequestMessage(req.channelName, req.payload, &reqId)
   208  }
   209  
   210  func (tr *busTransaction) onTransactionError(err error) {
   211  	tr.lock.Lock()
   212  	defer tr.lock.Unlock()
   213  
   214  	if tr.state == abortedState {
   215  		return
   216  	}
   217  
   218  	tr.state = abortedState
   219  	for _, errorHandler := range tr.onErrorHandlers {
   220  		go errorHandler(err)
   221  	}
   222  }
   223  
   224  func (tr *busTransaction) waitForStore(req *busTransactionRequest) {
   225  	store := tr.bus.GetStoreManager().GetStore(req.storeName)
   226  	if store == nil {
   227  		tr.onTransactionError(fmt.Errorf("cannot find store '%s'", req.storeName))
   228  		return
   229  	}
   230  	store.WhenReady(func() {
   231  		tr.onTransactionRequestSuccess(req, &model.Message{
   232  			Direction: model.ResponseDir,
   233  			Payload:   store.AllValuesAsMap(),
   234  		})
   235  	})
   236  }
   237  
   238  func (tr *busTransaction) onTransactionRequestSuccess(req *busTransactionRequest, message *model.Message) {
   239  	var triggerOnCompleteHandler = false
   240  	tr.lock.Lock()
   241  
   242  	if tr.state == abortedState {
   243  		tr.lock.Unlock()
   244  		return
   245  	}
   246  
   247  	tr.responses[req.requestIndex] = message
   248  	tr.completedRequests++
   249  
   250  	if tr.completedRequests == len(tr.requests) {
   251  		tr.state = completedState
   252  		triggerOnCompleteHandler = true
   253  	}
   254  
   255  	tr.lock.Unlock()
   256  
   257  	if triggerOnCompleteHandler {
   258  		for _, completeHandler := range tr.onCompleteHandlers {
   259  			go completeHandler(tr.responses)
   260  		}
   261  		return
   262  	}
   263  
   264  	// If this is a sync transaction execute the next request
   265  	if tr.transactionType == syncTransaction && req.requestIndex < len(tr.requests)-1 {
   266  		tr.executeRequest(tr.requests[req.requestIndex+1])
   267  	}
   268  }