github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/pubsub/libp2p/pubsub.go (about)

     1  package libp2p
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  	realsync "sync"
     8  
     9  	"github.com/filecoin-project/bacalhau/pkg/logger"
    10  	"github.com/filecoin-project/bacalhau/pkg/model"
    11  	"github.com/filecoin-project/bacalhau/pkg/pubsub"
    12  	"github.com/filecoin-project/bacalhau/pkg/system"
    13  	libp2p_pubsub "github.com/libp2p/go-libp2p-pubsub"
    14  	"github.com/libp2p/go-libp2p/core/host"
    15  	"github.com/rs/zerolog/log"
    16  )
    17  
    18  type PubSubParams struct {
    19  	Host        host.Host
    20  	TopicName   string
    21  	PubSub      *libp2p_pubsub.PubSub
    22  	IgnoreLocal bool
    23  }
    24  type PubSub[T any] struct {
    25  	hostID      string
    26  	topicName   string
    27  	pubSub      *libp2p_pubsub.PubSub
    28  	ignoreLocal bool
    29  
    30  	topic        *libp2p_pubsub.Topic
    31  	subscription *libp2p_pubsub.Subscription
    32  
    33  	subscriber     pubsub.Subscriber[T]
    34  	subscriberOnce realsync.Once
    35  	closeOnce      realsync.Once
    36  }
    37  
    38  func NewPubSub[T any](params PubSubParams) (*PubSub[T], error) {
    39  	topic, err := params.PubSub.Join(params.TopicName)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	newPubSub := &PubSub[T]{
    44  		hostID:      params.Host.ID().String(),
    45  		pubSub:      params.PubSub,
    46  		topic:       topic,
    47  		topicName:   params.TopicName,
    48  		ignoreLocal: params.IgnoreLocal,
    49  	}
    50  	return newPubSub, nil
    51  }
    52  
    53  func (p *PubSub[T]) Publish(ctx context.Context, message T) error {
    54  	ctx, span := system.NewSpan(ctx, system.GetTracer(), "pkg/pubsub/libp2p.Publish.Publish")
    55  	defer span.End()
    56  
    57  	payload, err := model.JSONMarshalWithMax(message)
    58  	if err != nil {
    59  		return err
    60  	}
    61  
    62  	log.Ctx(ctx).Trace().Msgf("Sending message %+v", message)
    63  	return p.topic.Publish(ctx, payload)
    64  }
    65  
    66  func (p *PubSub[T]) Subscribe(ctx context.Context, subscriber pubsub.Subscriber[T]) (err error) {
    67  	var firstSubscriber bool
    68  	p.subscriberOnce.Do(func() {
    69  		// register the subscriber
    70  		p.subscriber = subscriber
    71  
    72  		p.subscription, err = p.topic.Subscribe()
    73  		if err != nil {
    74  			return
    75  		}
    76  
    77  		// start listening for events
    78  		go p.listenForEvents()
    79  		firstSubscriber = true
    80  	})
    81  	if err != nil {
    82  		return err
    83  	}
    84  	if !firstSubscriber {
    85  		err = errors.New("only a single subscriber is allowed. Use ChainedSubscriber to chain multiple subscribers")
    86  	}
    87  	return err
    88  }
    89  
    90  func (p *PubSub[T]) listenForEvents() {
    91  	ctx := logger.ContextWithNodeIDLogger(context.Background(), p.hostID)
    92  	for {
    93  		msg, err := p.subscription.Next(ctx)
    94  		if err != nil {
    95  			if err == context.Canceled || err == context.DeadlineExceeded || err == libp2p_pubsub.ErrSubscriptionCancelled {
    96  				log.Ctx(ctx).Trace().Msgf("libp2p pubsub shutting down: %v", err)
    97  			} else {
    98  				log.Ctx(ctx).Error().Msgf(
    99  					"libp2p encountered an unexpected error, shutting down: %v", err)
   100  			}
   101  			return
   102  		}
   103  		if p.ignoreLocal && msg.GetFrom().String() == p.hostID {
   104  			continue
   105  		}
   106  		p.readMessage(ctx, msg)
   107  	}
   108  }
   109  
   110  func (p *PubSub[T]) readMessage(ctx context.Context, msg *libp2p_pubsub.Message) {
   111  	// TODO: we would enforce the claims to SourceNodeID here
   112  	// i.e. msg.ReceivedFrom() should match msg.Data.JobEvent.SourceNodeID
   113  	var payload T
   114  	err := model.JSONUnmarshalWithMax(msg.Data, &payload)
   115  	if err != nil {
   116  		log.Ctx(ctx).Error().Msgf("error unmarshalling libp2p payload: %v", err)
   117  		return
   118  	}
   119  
   120  	err = p.subscriber.Handle(ctx, payload)
   121  	if err != nil {
   122  		log.Ctx(ctx).Error().Err(err).Msgf("error in handle message of type: %s", reflect.TypeOf(payload))
   123  	}
   124  }
   125  
   126  func (p *PubSub[T]) Close(ctx context.Context) (err error) {
   127  	p.closeOnce.Do(func() {
   128  		if p.subscription != nil {
   129  			p.subscription.Cancel()
   130  		}
   131  		if p.topic != nil {
   132  			err = p.topic.Close()
   133  		}
   134  	})
   135  	if err != nil {
   136  		return err
   137  	}
   138  	log.Ctx(ctx).Info().Msgf("done closing libp2p pubsub for topic %s", p.topicName)
   139  	return nil
   140  }
   141  
   142  // compile-time interface assertions
   143  var _ pubsub.PubSub[string] = (*PubSub[string])(nil)