go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/context.go (about)

     1  package onet
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"sync"
     8  
     9  	"go.dedis.ch/onet/v3/log"
    10  	"go.dedis.ch/onet/v3/network"
    11  	bbolt "go.etcd.io/bbolt"
    12  	"golang.org/x/xerrors"
    13  )
    14  
    15  // Context represents the methods that are available to a service.
    16  type Context struct {
    17  	overlay           *Overlay
    18  	server            *Server
    19  	serviceID         ServiceID
    20  	manager           *serviceManager
    21  	bucketName        []byte
    22  	bucketVersionName []byte
    23  }
    24  
    25  // defaultContext is the implementation of the Context interface. It is
    26  // instantiated for each Service.
    27  func newContext(c *Server, o *Overlay, servID ServiceID, manager *serviceManager) *Context {
    28  	ctx := &Context{
    29  		overlay:           o,
    30  		server:            c,
    31  		serviceID:         servID,
    32  		manager:           manager,
    33  		bucketName:        []byte(ServiceFactory.Name(servID)),
    34  		bucketVersionName: []byte(ServiceFactory.Name(servID) + "version"),
    35  	}
    36  	err := manager.db.Update(func(tx *bbolt.Tx) error {
    37  		_, err := tx.CreateBucketIfNotExists(ctx.bucketName)
    38  		if err != nil {
    39  			return xerrors.Errorf("creating bucket: %v", err)
    40  		}
    41  		_, err = tx.CreateBucketIfNotExists(ctx.bucketVersionName)
    42  		if err != nil {
    43  			return xerrors.Errorf("creating bucket: %v", err)
    44  		}
    45  		return nil
    46  	})
    47  	if err != nil {
    48  		log.Panic("Failed to create bucket: " + err.Error())
    49  	}
    50  	return ctx
    51  }
    52  
    53  // NewTreeNodeInstance creates a TreeNodeInstance that is bound to a
    54  // service instead of the Overlay.
    55  func (c *Context) NewTreeNodeInstance(t *Tree, tn *TreeNode, protoName string) *TreeNodeInstance {
    56  	io := c.overlay.protoIO.getByName(protoName)
    57  	return c.overlay.NewTreeNodeInstanceFromService(t, tn, ProtocolNameToID(protoName), c.serviceID, io)
    58  }
    59  
    60  // SendRaw sends a message to the ServerIdentity.
    61  func (c *Context) SendRaw(si *network.ServerIdentity, msg interface{}) error {
    62  	_, err := c.server.Send(si, msg)
    63  	if err != nil {
    64  		xerrors.Errorf("sending message: %v", err)
    65  	}
    66  	return nil
    67  }
    68  
    69  // ServerIdentity returns this server's identity.
    70  func (c *Context) ServerIdentity() *network.ServerIdentity {
    71  	return c.server.ServerIdentity
    72  }
    73  
    74  // Suite returns the suite for the context's associated server.
    75  func (c *Context) Suite() network.Suite {
    76  	return c.server.Suite()
    77  }
    78  
    79  // ServiceID returns the service-id.
    80  func (c *Context) ServiceID() ServiceID {
    81  	return c.serviceID
    82  }
    83  
    84  // CreateProtocol returns a ProtocolInstance bound to the service.
    85  func (c *Context) CreateProtocol(name string, t *Tree) (ProtocolInstance, error) {
    86  	pi, err := c.overlay.CreateProtocol(name, t, c.serviceID)
    87  	if err != nil {
    88  		return nil, xerrors.Errorf("creating protocol: %v", err)
    89  	}
    90  
    91  	return pi, nil
    92  }
    93  
    94  // ProtocolRegister signs up a new protocol to this Server. Contrary go
    95  // GlobalProtocolRegister, the protocol registered here is tied to that server.
    96  // This is useful for simulations where more than one Server exists in the
    97  // global namespace.
    98  // It returns the ID of the protocol.
    99  func (c *Context) ProtocolRegister(name string, protocol NewProtocol) (ProtocolID, error) {
   100  	id, err := c.server.ProtocolRegister(name, protocol)
   101  	if err != nil {
   102  		return id, xerrors.Errorf("protocol registration: %v", err)
   103  	}
   104  	return id, nil
   105  }
   106  
   107  // RegisterProtocolInstance registers a new instance of a protocol using overlay.
   108  func (c *Context) RegisterProtocolInstance(pi ProtocolInstance) error {
   109  	err := c.overlay.RegisterProtocolInstance(pi)
   110  	if err != nil {
   111  		return xerrors.Errorf("protocol instance regisration: %v", err)
   112  	}
   113  	return nil
   114  }
   115  
   116  // ReportStatus returns all status of the services.
   117  func (c *Context) ReportStatus() map[string]*Status {
   118  	return c.server.statusReporterStruct.ReportStatus()
   119  }
   120  
   121  // RegisterStatusReporter registers a new StatusReporter.
   122  func (c *Context) RegisterStatusReporter(name string, s StatusReporter) {
   123  	c.server.statusReporterStruct.RegisterStatusReporter(name, s)
   124  }
   125  
   126  // RegisterProcessor overrides the RegisterProcessor methods of the Dispatcher.
   127  // It delegates the dispatching to the serviceManager.
   128  func (c *Context) RegisterProcessor(p network.Processor, msgType network.MessageTypeID) {
   129  	c.manager.registerProcessor(p, msgType)
   130  }
   131  
   132  // RegisterProcessorFunc takes a message-type and a function that will be called
   133  // if this message-type is received.
   134  func (c *Context) RegisterProcessorFunc(msgType network.MessageTypeID, fn func(*network.Envelope) error) {
   135  	c.manager.registerProcessorFunc(msgType, fn)
   136  }
   137  
   138  // RegisterMessageProxy registers a message proxy only for this server /
   139  // overlay
   140  func (c *Context) RegisterMessageProxy(m MessageProxy) {
   141  	c.overlay.RegisterMessageProxy(m)
   142  }
   143  
   144  // Service returns the corresponding service.
   145  func (c *Context) Service(name string) Service {
   146  	return c.manager.service(name)
   147  }
   148  
   149  // String returns the host it's running on.
   150  func (c *Context) String() string {
   151  	return c.server.ServerIdentity.String()
   152  }
   153  
   154  var testContextData = struct {
   155  	service map[string][]byte
   156  	sync.Mutex
   157  }{service: make(map[string][]byte, 0)}
   158  
   159  // The ContextDB interface allows for easy testing in the services.
   160  type ContextDB interface {
   161  	Load(key []byte) (interface{}, error)
   162  	LoadRaw(key []byte) ([]byte, error)
   163  	LoadVersion() (int, error)
   164  	SaveVersion(version int) error
   165  }
   166  
   167  // Save takes a key and an interface. The interface will be network.Marshal'ed
   168  // and saved in the database under the bucket named after the service name.
   169  //
   170  // The data will be stored in a different bucket for every service.
   171  func (c *Context) Save(key []byte, data interface{}) error {
   172  	buf, err := network.Marshal(data)
   173  	if err != nil {
   174  		return xerrors.Errorf("marshaling: %v", err)
   175  	}
   176  	err = c.manager.db.Update(func(tx *bbolt.Tx) error {
   177  		b := tx.Bucket(c.bucketName)
   178  		return b.Put(key, buf)
   179  	})
   180  	if err != nil {
   181  		return xerrors.Errorf("tx error: %v", err)
   182  	}
   183  	return nil
   184  }
   185  
   186  // Load takes a key and returns the network.Unmarshaled data.
   187  // Returns a nil value if the key does not exist.
   188  func (c *Context) Load(key []byte) (interface{}, error) {
   189  	var buf []byte
   190  	err := c.manager.db.View(func(tx *bbolt.Tx) error {
   191  		v := tx.Bucket(c.bucketName).Get(key)
   192  		if v == nil {
   193  			return nil
   194  		}
   195  
   196  		buf = make([]byte, len(v))
   197  		copy(buf, v)
   198  		return nil
   199  	})
   200  	if err != nil {
   201  		return nil, xerrors.Errorf("tx error: %v", err)
   202  	}
   203  
   204  	if buf == nil {
   205  		return nil, nil
   206  	}
   207  
   208  	_, ret, err := network.Unmarshal(buf, c.server.suite)
   209  	if err != nil {
   210  		return nil, xerrors.Errorf("unmarshaling: %v")
   211  	}
   212  
   213  	return ret, nil
   214  }
   215  
   216  // LoadRaw takes a key and returns the raw, unmarshalled data.
   217  // Returns a nil value if the key does not exist.
   218  func (c *Context) LoadRaw(key []byte) ([]byte, error) {
   219  	var buf []byte
   220  	err := c.manager.db.View(func(tx *bbolt.Tx) error {
   221  		v := tx.Bucket(c.bucketName).Get(key)
   222  		if v == nil {
   223  			return nil
   224  		}
   225  
   226  		buf = make([]byte, len(v))
   227  		copy(buf, v)
   228  		return nil
   229  	})
   230  	if err != nil {
   231  		return nil, xerrors.Errorf("tx error: %v", err)
   232  	}
   233  	return buf, nil
   234  }
   235  
   236  var dbVersion = []byte("dbVersion")
   237  
   238  // LoadVersion returns the version of the database, or 0 if
   239  // no version has been found.
   240  func (c *Context) LoadVersion() (int, error) {
   241  	var buf []byte
   242  	err := c.manager.db.View(func(tx *bbolt.Tx) error {
   243  		v := tx.Bucket(c.bucketVersionName).Get(dbVersion)
   244  		if v == nil {
   245  			return nil
   246  		}
   247  
   248  		buf = make([]byte, len(v))
   249  		copy(buf, v)
   250  		return nil
   251  	})
   252  
   253  	if err != nil {
   254  		return -1, xerrors.Errorf("tx error: %v", err)
   255  	}
   256  
   257  	if len(buf) == 0 {
   258  		return 0, nil
   259  	}
   260  	var version int32
   261  	err = binary.Read(bytes.NewReader(buf), binary.LittleEndian, &version)
   262  	if err != nil {
   263  		return -1, xerrors.Errorf("bytes to int: %v", err)
   264  	}
   265  	return int(version), nil
   266  }
   267  
   268  // SaveVersion stores the given version as the current database version.
   269  func (c *Context) SaveVersion(version int) error {
   270  	buf := bytes.NewBuffer(nil)
   271  	err := binary.Write(buf, binary.LittleEndian, int32(version))
   272  	if err != nil {
   273  		return xerrors.Errorf("int to bytes: %v", err)
   274  	}
   275  	err = c.manager.db.Update(func(tx *bbolt.Tx) error {
   276  		b := tx.Bucket(c.bucketVersionName)
   277  		return b.Put(dbVersion, buf.Bytes())
   278  	})
   279  	if err != nil {
   280  		return xerrors.Errorf("tx error: %v", err)
   281  	}
   282  	return nil
   283  }
   284  
   285  // GetAdditionalBucket makes sure that a bucket with the given name
   286  // exists, by eventually creating it, and returns the created bucket name,
   287  // which is the servicename + "_" + the given name.
   288  //
   289  // This function should only be used if the Load and Save functions are not sufficient.
   290  // Additionally, the user should not create buckets directly on the DB but always
   291  // call this function to create new buckets to avoid bucket name conflicts.
   292  func (c *Context) GetAdditionalBucket(name []byte) (*bbolt.DB, []byte) {
   293  	// make a copy to insure c.bucketName is not written
   294  	bucketName := make([]byte, len(c.bucketName))
   295  	copy(bucketName, c.bucketName)
   296  
   297  	fullName := append(append(bucketName, byte('_')), name...)
   298  	err := c.manager.db.Update(func(tx *bbolt.Tx) error {
   299  		_, err := tx.CreateBucketIfNotExists(fullName)
   300  		if err != nil {
   301  			return xerrors.Errorf("create bucket: %v", err)
   302  		}
   303  		return nil
   304  	})
   305  	if err != nil {
   306  		panic(xerrors.Errorf("tx error: %v", err))
   307  	}
   308  	return c.manager.db, fullName
   309  }
   310  
   311  // SetValidPeers sets the set of peers with which the server underlying this
   312  // context can communicate.
   313  func (c *Context) SetValidPeers(peerID network.PeerSetID,
   314  	peers []*network.ServerIdentity) {
   315  	c.server.SetValidPeers(peerID, peers)
   316  }
   317  
   318  // GetValidPeers returns the set of peers with which the server underlying this
   319  // context can communicate.
   320  // The return value is `nil` in case the set of valid peers has not yet been
   321  // initialized, meaning that all peers are valid.
   322  func (c *Context) GetValidPeers(peerID network.PeerSetID) []network.
   323  	ServerIdentityID {
   324  	return c.server.GetValidPeers(peerID)
   325  }
   326  
   327  // NewPeerSetID creates a new PeerSetID identifying a subset of valid peers.
   328  // This is to be used by services, providing their own specific identifier,
   329  // e.g. the SkipChainID for ByzCoin.
   330  func (c *Context) NewPeerSetID(data []byte) network.PeerSetID {
   331  	// Compute the PeerSetID as hash(serviceID | data)
   332  	h := sha256.New()
   333  	h.Write(c.serviceID[:])
   334  	h.Write(data)
   335  
   336  	return network.NewPeerSetID(h.Sum(nil))
   337  }