github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/publisher/combo/fanout.go (about)

     1  package combo
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/filecoin-project/bacalhau/pkg/model"
    10  	"github.com/filecoin-project/bacalhau/pkg/publisher"
    11  	"github.com/rs/zerolog/log"
    12  	"go.uber.org/multierr"
    13  )
    14  
    15  // A fanoutPublisher is a publisher that will try multiple publishers in
    16  // parallel. By default, publishers are not prioritized and fanoutPublisher will
    17  // return the result from the first one to succeed.
    18  // Other  publishers will continue to run but their results and errors from the
    19  // other publishers are also ignored. An error is only returned if all publishers fail to produce a result.
    20  // If isPrioritized flag is provided result from providers are prioritized in the order they are provided.
    21  // fanoutPublisher will wait for duration of the provided timeout for
    22  // prioritized publisher to return before moving on to the next one and returning their result.
    23  type fanoutPublisher struct {
    24  	publishers    []publisher.Publisher
    25  	isPrioritized bool
    26  	timeout       time.Duration
    27  }
    28  
    29  func NewFanoutPublisher(publishers ...publisher.Publisher) publisher.Publisher {
    30  	return &fanoutPublisher{
    31  		publishers,
    32  		false,
    33  		time.Duration(0),
    34  	}
    35  }
    36  
    37  func NewPrioritizedFanoutPublisher(timeout time.Duration, publishers ...publisher.Publisher) publisher.Publisher {
    38  	return &fanoutPublisher{
    39  		publishers,
    40  		true,
    41  		timeout,
    42  	}
    43  }
    44  
    45  type fanoutResult[T any, P any] struct {
    46  	Value  T
    47  	Sender P
    48  }
    49  
    50  // fanout runs the passed method for all publishers in parallel. It immediately
    51  // returns two channels from which the results can be read. Return values are
    52  // written immediately to the value channel. A single error is written to the
    53  // error channel only when all publishers have returned.
    54  func fanout[T any, P any](ctx context.Context, publishers []P, method func(P) (T, error)) (chan fanoutResult[T, P], chan error) {
    55  	valueChannel := make(chan fanoutResult[T, P], len(publishers))
    56  	internalErrorChannel := make(chan error, len(publishers))
    57  	externalErrorChannel := make(chan error, 1)
    58  
    59  	waitGroup := sync.WaitGroup{}
    60  	waitGroup.Add(len(publishers))
    61  
    62  	go func() {
    63  		waitGroup.Wait()
    64  		close(internalErrorChannel)
    65  		var multi error
    66  		for err := range internalErrorChannel {
    67  			multi = multierr.Append(multi, err)
    68  		}
    69  		externalErrorChannel <- multi
    70  		close(externalErrorChannel)
    71  	}()
    72  
    73  	runFunc := func(p P) {
    74  		value, err := method(p)
    75  		if err == nil {
    76  			valueChannel <- fanoutResult[T, P]{value, p}
    77  			log.Ctx(ctx).Debug().Str("Publisher", fmt.Sprintf("%T", p)).Interface("Value", value).Send()
    78  		} else {
    79  			internalErrorChannel <- err
    80  			log.Ctx(ctx).Error().Str("Publisher", fmt.Sprintf("%T", p)).Err(err).Send()
    81  		}
    82  		waitGroup.Done()
    83  	}
    84  
    85  	for _, publisher := range publishers {
    86  		go runFunc(publisher)
    87  	}
    88  
    89  	return valueChannel, externalErrorChannel
    90  }
    91  
    92  // IsInstalled implements publisher.Publisher
    93  func (f *fanoutPublisher) IsInstalled(ctx context.Context) (bool, error) {
    94  	ctx = log.Ctx(ctx).With().Str("Method", "IsInstalled").Logger().WithContext(ctx)
    95  
    96  	valueChannel, errorChannel := fanout(ctx, f.publishers, func(p publisher.Publisher) (bool, error) {
    97  		return p.IsInstalled(ctx)
    98  	})
    99  
   100  	// If we have a true result, return it right away. Else, wait for any other
   101  	// publisher that might return a true result. If none do, the errorChannel
   102  	// will close and if all publishers are actually fine err will just be nil.
   103  	for {
   104  		select {
   105  		case installed := <-valueChannel:
   106  			if installed.Value {
   107  				return installed.Value, nil
   108  			}
   109  		case err := <-errorChannel:
   110  			return false, err
   111  		}
   112  	}
   113  }
   114  
   115  // PublishShardResult implements publisher.Publisher
   116  func (f *fanoutPublisher) PublishShardResult(
   117  	ctx context.Context,
   118  	shard model.JobShard,
   119  	hostID string,
   120  	shardResultPath string,
   121  ) (model.StorageSpec, error) {
   122  	var err error
   123  	ctx = log.Ctx(ctx).With().Str("Method", "PublishShardResult").Logger().WithContext(ctx)
   124  
   125  	valueChannel, errorChannel := fanout(ctx, f.publishers, func(p publisher.Publisher) (model.StorageSpec, error) {
   126  		return p.PublishShardResult(ctx, shard, hostID, shardResultPath)
   127  	})
   128  
   129  	timeoutChannel := make(chan bool, 1)
   130  	results := map[publisher.Publisher]model.StorageSpec{}
   131  
   132  loop:
   133  	for {
   134  		select {
   135  		case value := <-valueChannel:
   136  			// if non-prioritized fanout publisher return immediately
   137  			if !f.isPrioritized {
   138  				return value.Value, nil
   139  			}
   140  
   141  			// if prioritized fanout publisher check if result is from publisher of the highest priority
   142  			// if that is true return immediately
   143  			if f.isPrioritized && value.Sender == f.publishers[0] {
   144  				return value.Value, nil
   145  			}
   146  
   147  			results[value.Sender] = value.Value
   148  
   149  			if len(results) == len(f.publishers) {
   150  				// break because everyone returned
   151  				break loop
   152  			}
   153  
   154  			// start timeout for other results when first result is returned
   155  			if len(results) == 1 {
   156  				go func() {
   157  					time.Sleep(f.timeout)
   158  					timeoutChannel <- true
   159  				}()
   160  			}
   161  
   162  		case <-timeoutChannel:
   163  			break loop
   164  		case err = <-errorChannel:
   165  			break loop
   166  		}
   167  	}
   168  
   169  	// loop trough publishers by priority and return
   170  	for _, pub := range f.publishers {
   171  		result, resultExists := results[pub]
   172  		if resultExists {
   173  			return result, nil
   174  		}
   175  	}
   176  
   177  	return model.StorageSpec{}, err
   178  }
   179  
// Compile-time check that *fanoutPublisher satisfies publisher.Publisher.
var _ publisher.Publisher = (*fanoutPublisher)(nil)