github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/publisher/combo/fanout.go (about) 1 package combo 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "time" 8 9 "github.com/filecoin-project/bacalhau/pkg/model" 10 "github.com/filecoin-project/bacalhau/pkg/publisher" 11 "github.com/rs/zerolog/log" 12 "go.uber.org/multierr" 13 ) 14 15 // A fanoutPublisher is a publisher that will try multiple publishers in 16 // parallel. By default, publishers are not prioritized and fanoutPublisher will 17 // return the result from the first one to succeed. 18 // Other publishers will continue to run but their results and errors from the 19 // other publishers are also ignored. An error is only returned if all publishers fail to produce a result. 20 // If isPrioritized flag is provided result from providers are prioritized in the order they are provided. 21 // fanoutPublisher will wait for duration of the provided timeout for 22 // prioritized publisher to return before moving on to the next one and returning their result. 23 type fanoutPublisher struct { 24 publishers []publisher.Publisher 25 isPrioritized bool 26 timeout time.Duration 27 } 28 29 func NewFanoutPublisher(publishers ...publisher.Publisher) publisher.Publisher { 30 return &fanoutPublisher{ 31 publishers, 32 false, 33 time.Duration(0), 34 } 35 } 36 37 func NewPrioritizedFanoutPublisher(timeout time.Duration, publishers ...publisher.Publisher) publisher.Publisher { 38 return &fanoutPublisher{ 39 publishers, 40 true, 41 timeout, 42 } 43 } 44 45 type fanoutResult[T any, P any] struct { 46 Value T 47 Sender P 48 } 49 50 // fanout runs the passed method for all publishers in parallel. It immediately 51 // returns two channels from which the results can be read. Return values are 52 // written immediately to the value channel. A single error is written to the 53 // error channel only when all publishers have returned. 54 func fanout[T any, P any](ctx context.Context, publishers []P, method func(P) (T, error)) (chan fanoutResult[T, P], chan error) { 55 valueChannel := make(chan fanoutResult[T, P], len(publishers)) 56 internalErrorChannel := make(chan error, len(publishers)) 57 externalErrorChannel := make(chan error, 1) 58 59 waitGroup := sync.WaitGroup{} 60 waitGroup.Add(len(publishers)) 61 62 go func() { 63 waitGroup.Wait() 64 close(internalErrorChannel) 65 var multi error 66 for err := range internalErrorChannel { 67 multi = multierr.Append(multi, err) 68 } 69 externalErrorChannel <- multi 70 close(externalErrorChannel) 71 }() 72 73 runFunc := func(p P) { 74 value, err := method(p) 75 if err == nil { 76 valueChannel <- fanoutResult[T, P]{value, p} 77 log.Ctx(ctx).Debug().Str("Publisher", fmt.Sprintf("%T", p)).Interface("Value", value).Send() 78 } else { 79 internalErrorChannel <- err 80 log.Ctx(ctx).Error().Str("Publisher", fmt.Sprintf("%T", p)).Err(err).Send() 81 } 82 waitGroup.Done() 83 } 84 85 for _, publisher := range publishers { 86 go runFunc(publisher) 87 } 88 89 return valueChannel, externalErrorChannel 90 } 91 92 // IsInstalled implements publisher.Publisher 93 func (f *fanoutPublisher) IsInstalled(ctx context.Context) (bool, error) { 94 ctx = log.Ctx(ctx).With().Str("Method", "IsInstalled").Logger().WithContext(ctx) 95 96 valueChannel, errorChannel := fanout(ctx, f.publishers, func(p publisher.Publisher) (bool, error) { 97 return p.IsInstalled(ctx) 98 }) 99 100 // If we have a true result, return it right away. Else, wait for any other 101 // publisher that might return a true result. If none do, the errorChannel 102 // will close and if all publishers are actually fine err will just be nil. 103 for { 104 select { 105 case installed := <-valueChannel: 106 if installed.Value { 107 return installed.Value, nil 108 } 109 case err := <-errorChannel: 110 return false, err 111 } 112 } 113 } 114 115 // PublishShardResult implements publisher.Publisher 116 func (f *fanoutPublisher) PublishShardResult( 117 ctx context.Context, 118 shard model.JobShard, 119 hostID string, 120 shardResultPath string, 121 ) (model.StorageSpec, error) { 122 var err error 123 ctx = log.Ctx(ctx).With().Str("Method", "PublishShardResult").Logger().WithContext(ctx) 124 125 valueChannel, errorChannel := fanout(ctx, f.publishers, func(p publisher.Publisher) (model.StorageSpec, error) { 126 return p.PublishShardResult(ctx, shard, hostID, shardResultPath) 127 }) 128 129 timeoutChannel := make(chan bool, 1) 130 results := map[publisher.Publisher]model.StorageSpec{} 131 132 loop: 133 for { 134 select { 135 case value := <-valueChannel: 136 // if non-prioritized fanout publisher return immediately 137 if !f.isPrioritized { 138 return value.Value, nil 139 } 140 141 // if prioritized fanout publisher check if result is from publisher of the highest priority 142 // if that is true return immediately 143 if f.isPrioritized && value.Sender == f.publishers[0] { 144 return value.Value, nil 145 } 146 147 results[value.Sender] = value.Value 148 149 if len(results) == len(f.publishers) { 150 // break because everyone returned 151 break loop 152 } 153 154 // start timeout for other results when first result is returned 155 if len(results) == 1 { 156 go func() { 157 time.Sleep(f.timeout) 158 timeoutChannel <- true 159 }() 160 } 161 162 case <-timeoutChannel: 163 break loop 164 case err = <-errorChannel: 165 break loop 166 } 167 } 168 169 // loop trough publishers by priority and return 170 for _, pub := range f.publishers { 171 result, resultExists := results[pub] 172 if resultExists { 173 return result, nil 174 } 175 } 176 177 return model.StorageSpec{}, err 178 } 179 180 var _ publisher.Publisher = (*fanoutPublisher)(nil)