gitlab.com/SkynetLabs/skyd@v1.6.9/skymodules/renter/skynetportals/skynetportals.go (about)

     1  package skynetportals
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"strings"
     8  	"sync"
     9  
    10  	"gitlab.com/NebulousLabs/encoding"
    11  	"gitlab.com/NebulousLabs/errors"
    12  	"gitlab.com/SkynetLabs/skyd/skymodules"
    13  	"go.sia.tech/siad/modules"
    14  	"go.sia.tech/siad/persist"
    15  	"go.sia.tech/siad/types"
    16  )
    17  
const (
	// persistFile is the name of the persist file that holds the Skynet
	// portal list.
	persistFile string = "skynetportals"

	// persistSize is the size in bytes of one persisted portal entry: the
	// zero-padded address (modules.MaxEncodedNetAddressLength bytes) plus one
	// byte for each of the `public` and `listed` flags.
	persistSize uint64 = modules.MaxEncodedNetAddressLength + 2
)
    26  
var (
	// ErrSkynetPortalsValidation describes a failed validation of changes to
	// the Skynet portals list. UpdatePortals adds its message as context to
	// the underlying validation error.
	ErrSkynetPortalsValidation = errors.New("could not validate additions and removals")

	// metadataHeader is the header of the metadata for the persist file.
	metadataHeader = types.NewSpecifier("SkynetPortals\n")

	// metadataVersion is the version of the persist file. It is passed to
	// persist.NewAppendOnlyPersist together with metadataHeader.
	metadataVersion = types.NewSpecifier("v1.4.8\n")
)
    38  
type (
	// SkynetPortals manages a list of known Skynet portals by persisting the
	// list to disk.
	//
	// Every change is appended to the persist file as a persistEntry, and New
	// rebuilds the in-memory map by replaying those entries.
	SkynetPortals struct {
		// staticAop is the append-only persist file that portal changes are
		// written to.
		staticAop *persist.AppendOnlyPersist

		// portals is a map of portal addresses to public status.
		portals map[modules.NetAddress]bool

		// mu guards portals.
		mu sync.Mutex
	}

	// persistEntry is a single record in the persistence file. It holds a
	// Skynet portal address, its public flag, and whether the record adds the
	// portal to the list (listed == true) or removes it (listed == false).
	persistEntry struct {
		address modules.NetAddress
		public  bool
		listed  bool
	}
)
    59  
    60  // New returns an initialized SkynetPortals.
    61  func New(persistDir string) (*SkynetPortals, error) {
    62  	// Initialize the persistence of the portal list.
    63  	aop, reader, err := persist.NewAppendOnlyPersist(persistDir, persistFile, metadataHeader, metadataVersion)
    64  	if err != nil {
    65  		return nil, errors.AddContext(err, fmt.Sprintf("unable to initialize the skynet portal list persistence at '%v'", aop.FilePath()))
    66  	}
    67  
    68  	sp := &SkynetPortals{
    69  		staticAop: aop,
    70  	}
    71  	portals, err := unmarshalObjects(reader)
    72  	if err != nil {
    73  		return nil, errors.AddContext(err, "unable to unmarshal persist objects")
    74  	}
    75  	sp.portals = portals
    76  
    77  	return sp, nil
    78  }
    79  
    80  // Close closes and frees associated resources.
    81  func (sp *SkynetPortals) Close() error {
    82  	return sp.staticAop.Close()
    83  }
    84  
    85  // Portals returns the list of known Skynet portals.
    86  func (sp *SkynetPortals) Portals() []skymodules.SkynetPortal {
    87  	sp.mu.Lock()
    88  	defer sp.mu.Unlock()
    89  
    90  	var portals []skymodules.SkynetPortal
    91  	for addr, public := range sp.portals {
    92  		portal := skymodules.SkynetPortal{
    93  			Address: addr,
    94  			Public:  public,
    95  		}
    96  		portals = append(portals, portal)
    97  	}
    98  	return portals
    99  }
   100  
   101  // UpdatePortals updates the list of known Skynet portals.
   102  func (sp *SkynetPortals) UpdatePortals(additions []skymodules.SkynetPortal, removals []modules.NetAddress) error {
   103  	sp.mu.Lock()
   104  	defer sp.mu.Unlock()
   105  
   106  	// Convert portal addresses to lowercase for case-insensitivity.
   107  	addPortals := make([]skymodules.SkynetPortal, len(additions))
   108  	for i, portalInfo := range additions {
   109  		address := modules.NetAddress(strings.ToLower(string(portalInfo.Address)))
   110  		portalInfo.Address = address
   111  		addPortals[i] = portalInfo
   112  	}
   113  	removePortals := make([]modules.NetAddress, len(removals))
   114  	for i, address := range removals {
   115  		address = modules.NetAddress(strings.ToLower(string(address)))
   116  		removePortals[i] = address
   117  	}
   118  
   119  	// Validate now before we start making changes.
   120  	err := sp.validatePortalChanges(additions, removals)
   121  	if err != nil {
   122  		return errors.AddContext(err, ErrSkynetPortalsValidation.Error())
   123  	}
   124  
   125  	buf, err := sp.marshalObjects(additions, removals)
   126  	if err != nil {
   127  		return errors.AddContext(err, fmt.Sprintf("unable to update skynet portal list persistence at '%v'", sp.staticAop.FilePath()))
   128  	}
   129  	_, err = sp.staticAop.Write(buf.Bytes())
   130  	return errors.AddContext(err, fmt.Sprintf("unable to update skynet portal list persistence at '%v'", sp.staticAop.FilePath()))
   131  }
   132  
   133  // marshalObjects marshals the given objects into a byte buffer.
   134  //
   135  // NOTE: this method does not check for duplicate additions or removals
   136  func (sp *SkynetPortals) marshalObjects(additions []skymodules.SkynetPortal, removals []modules.NetAddress) (bytes.Buffer, error) {
   137  	// Create buffer for encoder
   138  	var buf bytes.Buffer
   139  	// Create and encode the persist portals
   140  	listed := true
   141  	for _, portal := range additions {
   142  		// Add portal to map
   143  		sp.portals[portal.Address] = portal.Public
   144  
   145  		// Marshal the update
   146  		pe := persistEntry{portal.Address, portal.Public, listed}
   147  		err := pe.MarshalSia(&buf)
   148  		if err != nil {
   149  			return bytes.Buffer{}, errors.AddContext(err, "unable to encode persisted portal")
   150  		}
   151  	}
   152  	listed = false
   153  	for _, address := range removals {
   154  		// Remove portal from map
   155  		public, exists := sp.portals[address]
   156  		if !exists {
   157  			return bytes.Buffer{}, fmt.Errorf("address %v does not exist", address)
   158  		}
   159  		delete(sp.portals, address)
   160  
   161  		// Marshal the update
   162  		pe := persistEntry{address, public, listed}
   163  		err := pe.MarshalSia(&buf)
   164  		if err != nil {
   165  			return bytes.Buffer{}, errors.AddContext(err, "unable to encode persisted portal")
   166  		}
   167  	}
   168  
   169  	return buf, nil
   170  }
   171  
   172  // unmarshalObjects unmarshals the sia encoded objects.
   173  func unmarshalObjects(reader io.Reader) (map[modules.NetAddress]bool, error) {
   174  	portals := make(map[modules.NetAddress]bool)
   175  	// Unmarshal portals one by one until EOF.
   176  	for {
   177  		var pe persistEntry
   178  		err := pe.UnmarshalSia(reader)
   179  		if errors.Contains(err, io.EOF) {
   180  			break
   181  		}
   182  		if err != nil {
   183  			return nil, err
   184  		}
   185  		if !pe.listed {
   186  			delete(portals, pe.address)
   187  			continue
   188  		}
   189  		portals[pe.address] = pe.public
   190  	}
   191  	return portals, nil
   192  }
   193  
   194  // MarshalSia implements the encoding.SiaMarshaler interface.
   195  //
   196  // TODO: Remove these custom marshal functions and use encoding marshal
   197  // functions. Note that removing these changes the marshal format and is not
   198  // backwards-compatible.
   199  func (pe persistEntry) MarshalSia(w io.Writer) error {
   200  	if len(pe.address) > modules.MaxEncodedNetAddressLength {
   201  		return fmt.Errorf("given address %v does not fit in %v bytes", pe.address, modules.MaxEncodedNetAddressLength)
   202  	}
   203  	e := encoding.NewEncoder(w)
   204  	// Create a padded buffer so that we always write the same amount of bytes.
   205  	buf := make([]byte, modules.MaxEncodedNetAddressLength)
   206  	copy(buf, pe.address)
   207  	e.Write(buf)
   208  	e.WriteBool(pe.public)
   209  	e.WriteBool(pe.listed)
   210  	return e.Err()
   211  }
   212  
   213  // UnmarshalSia implements the encoding.SiaUnmarshaler interface.
   214  func (pe *persistEntry) UnmarshalSia(r io.Reader) error {
   215  	*pe = persistEntry{}
   216  	d := encoding.NewDecoder(r, encoding.DefaultAllocLimit)
   217  	// Read into a padded buffer and extract the address string.
   218  	buf := make([]byte, modules.MaxEncodedNetAddressLength)
   219  	n, err := d.Read(buf)
   220  	if err != nil {
   221  		return errors.AddContext(err, "unable to read address")
   222  	}
   223  	if n != len(buf) {
   224  		return errors.New("did not read address correctly")
   225  	}
   226  	end := bytes.IndexByte(buf, 0)
   227  	if end == -1 {
   228  		end = len(buf)
   229  	}
   230  	pe.address = modules.NetAddress(string(buf[:end]))
   231  	pe.public = d.NextBool()
   232  	pe.listed = d.NextBool()
   233  	err = d.Err()
   234  	return err
   235  }
   236  
   237  // validatePortalChanges validates the changes to be made to the Skynet portals
   238  // list.
   239  func (sp *SkynetPortals) validatePortalChanges(additions []skymodules.SkynetPortal, removals []modules.NetAddress) error {
   240  	// Check for nil input
   241  	if len(additions)+len(removals) == 0 {
   242  		return errors.New("no portals being added or removed")
   243  	}
   244  
   245  	additionsMap := make(map[modules.NetAddress]struct{})
   246  	for _, addition := range additions {
   247  		address := addition.Address
   248  		if err := address.IsStdValid(); err != nil {
   249  			return errors.New("invalid network address: " + err.Error())
   250  		}
   251  		additionsMap[address] = struct{}{}
   252  	}
   253  	// Check that each removal is valid.
   254  	for _, removalAddress := range removals {
   255  		if err := removalAddress.IsStdValid(); err != nil {
   256  			return errors.New("invalid network address: " + err.Error())
   257  		}
   258  		if _, exists := sp.portals[removalAddress]; !exists {
   259  			if _, added := additionsMap[removalAddress]; !added {
   260  				return errors.New("address " + string(removalAddress) + " not already present in list of portals or being added")
   261  			}
   262  		}
   263  	}
   264  	return nil
   265  }