gitlab.com/SkynetLabs/skyd@v1.6.9/skymodules/renter/workerrpc_test.go (about)

     1  package renter
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  	"testing"
     7  	"time"
     8  
     9  	"gitlab.com/NebulousLabs/encoding"
    10  	"gitlab.com/NebulousLabs/errors"
    11  	"gitlab.com/NebulousLabs/fastrand"
    12  	"gitlab.com/SkynetLabs/skyd/build"
    13  	"go.sia.tech/siad/crypto"
    14  	"go.sia.tech/siad/modules"
    15  	"go.sia.tech/siad/types"
    16  )
    17  
    18  // TestMDMBandwidthCost is a unit test for mdmBandwidthCost.
    19  func TestMDMBandwidthCost(t *testing.T) {
    20  	pt := newDefaultPriceTable()
    21  	pt.UploadBandwidthCost = types.SiacoinPrecision
    22  	pt.DownloadBandwidthCost = types.SiacoinPrecision
    23  
    24  	// compute bandwidth cost for 100 bytes of upload and 200 of download.
    25  	ul := uint64(100)
    26  	dl := uint64(200)
    27  	bandwidthCost, refund := mdmBandwidthCost(pt, ul, dl)
    28  
    29  	// compute the expected value.
    30  	expectedCost := modules.MDMBandwidthCost(pt, ul, dl)
    31  	if !bandwidthCost.Equals(expectedCost) {
    32  		t.Fatal("mismatch", bandwidthCost, expectedCost)
    33  	}
    34  
    35  	// check for a positive refund - less bandwidth was used than expected.
    36  	ref := refund(ul-10, dl-20)
    37  	expectedRef := bandwidthCost.Sub(modules.MDMBandwidthCost(pt, ul-10, dl-20))
    38  	if !ref.Equals(expectedRef) {
    39  		t.Log(bandwidthCost)
    40  		t.Fatal("mismatch", ref, expectedRef)
    41  	}
    42  
    43  	// check for a negative refund - more bandwidth was used than expected.
    44  	ref = refund(ul+10, dl+20)
    45  	if !ref.IsZero() {
    46  		t.Fatal("should be zero", ref)
    47  	}
    48  
    49  	// exact match
    50  	ref = refund(ul, dl)
    51  	if !ref.IsZero() {
    52  		t.Fatal("should be zero", ref)
    53  	}
    54  }
    55  
    56  // TestUseHostBlockHeight verifies we use the host's blockheight.
    57  func TestUseHostBlockHeight(t *testing.T) {
    58  	if testing.Short() {
    59  		t.SkipNow()
    60  	}
    61  	t.Parallel()
    62  
    63  	// create a new worker tester
    64  	wt, err := newWorkerTester(t.Name())
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	defer func() {
    69  		err := wt.Close()
    70  		if err != nil {
    71  			t.Fatal(err)
    72  		}
    73  	}()
    74  	w := wt.worker
    75  
    76  	// manually corrupt the price table's host blockheight
    77  	wpt := w.staticPriceTable()
    78  	hbh := wpt.staticPriceTable.HostBlockHeight // save host blockheight
    79  	var pt modules.RPCPriceTable
    80  	err = encoding.Unmarshal(encoding.Marshal(wpt.staticPriceTable), &pt)
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	pt.HostBlockHeight += 1e3
    85  
    86  	wptc := new(workerPriceTable)
    87  	wptc.staticExpiryTime = wpt.staticExpiryTime
    88  	wptc.staticUpdateTime = wpt.staticUpdateTime
    89  	wptc.staticPriceTable = pt
    90  	w.staticSetPriceTable(wptc)
    91  
    92  	// create a dummy program
    93  	pb := modules.NewProgramBuilder(&pt, 0)
    94  	pb.AddHasSectorInstruction(crypto.Hash{})
    95  	p, data := pb.Program()
    96  	cost, _, _ := pb.Cost(true)
    97  	jhs := new(jobHasSector)
    98  	jhs.staticSectors = []crypto.Hash{{1, 2, 3}}
    99  	ulBandwidth, dlBandwidth := jhs.callExpectedBandwidth()
   100  	bandwidthCost, bandwidthRefund := mdmBandwidthCost(pt, ulBandwidth, dlBandwidth)
   101  	cost = cost.Add(bandwidthCost)
   102  
   103  	// execute the program
   104  	_, _, err = w.managedExecuteProgram(p, data, types.FileContractID{}, categoryDownload, cost, bandwidthRefund)
   105  	if err == nil || !strings.Contains(err.Error(), "ephemeral account withdrawal message expires too far into the future") {
   106  		t.Fatal("Unexpected error", err)
   107  	}
   108  
   109  	// revert the corruption to assert success
   110  	wpt = w.staticPriceTable()
   111  	err = encoding.Unmarshal(encoding.Marshal(wpt.staticPriceTable), &pt)
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	pt.HostBlockHeight = hbh
   116  
   117  	wptc = new(workerPriceTable)
   118  	wptc.staticExpiryTime = wpt.staticExpiryTime
   119  	wptc.staticUpdateTime = wpt.staticUpdateTime
   120  	wptc.staticPriceTable = pt
   121  	w.staticSetPriceTable(wptc)
   122  
   123  	// execute the program
   124  	_, _, err = w.managedExecuteProgram(p, data, types.FileContractID{}, categoryDownload, cost, bandwidthRefund)
   125  	if err != nil {
   126  		t.Fatal("Unexpected error", err)
   127  	}
   128  }
   129  
   130  // TestExecuteProgramUsedBandwidth verifies the bandwidth used by executing
   131  // various MDM programs on the host
   132  func TestExecuteProgramUsedBandwidth(t *testing.T) {
   133  	if testing.Short() {
   134  		t.SkipNow()
   135  	}
   136  	t.Parallel()
   137  
   138  	// create a new worker tester
   139  	wt, err := newWorkerTester(t.Name())
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  	defer func() {
   144  		err := wt.Close()
   145  		if err != nil {
   146  			t.Fatal(err)
   147  		}
   148  	}()
   149  
   150  	// block until worker has funded the EA before starting the tests.
   151  	err = build.Retry(100, 100*time.Millisecond, func() error {
   152  		if wt.worker.staticAccount.managedAvailableBalance().IsZero() {
   153  			return errors.New("balance is zero")
   154  		}
   155  		return nil
   156  	})
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  
   161  	t.Run("HasSector", func(t *testing.T) {
   162  		testExecuteProgramUsedBandwidthHasSector(t, wt)
   163  	})
   164  
   165  	t.Run("ReadSector", func(t *testing.T) {
   166  		testExecuteProgramUsedBandwidthReadSector(t, wt)
   167  	})
   168  }
   169  
   170  // testExecuteProgramUsedBandwidthHasSector verifies the bandwidth consumed by a
   171  // HasSector program
   172  func testExecuteProgramUsedBandwidthHasSector(t *testing.T, wt *workerTester) {
   173  	w := wt.worker
   174  
   175  	// create a dummy program
   176  	pt := wt.staticPriceTable().staticPriceTable
   177  	pb := modules.NewProgramBuilder(&pt, 0)
   178  	pb.AddHasSectorInstruction(crypto.Hash{})
   179  	p, data := pb.Program()
   180  	cost, _, _ := pb.Cost(true)
   181  
   182  	jhs := new(jobHasSector)
   183  	jhs.staticSectors = []crypto.Hash{{1, 2, 3}}
   184  	ulBandwidth, dlBandwidth := jhs.callExpectedBandwidth()
   185  
   186  	bandwidthCost, bandwidthRefund := mdmBandwidthCost(pt, ulBandwidth, dlBandwidth)
   187  	cost = cost.Add(bandwidthCost)
   188  
   189  	// execute it
   190  	_, limit, err := w.managedExecuteProgram(p, data, types.FileContractID{}, categoryDownload, cost, bandwidthRefund)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  
   195  	// ensure bandwidth is as we expected
   196  	expectedDownload := uint64(1460)
   197  	if limit.Downloaded() != expectedDownload {
   198  		t.Errorf("Expected HasSector program to consume %v download bandwidth, instead it consumed %v", expectedDownload, limit.Downloaded())
   199  	}
   200  
   201  	expectedUpload := uint64(1460)
   202  	if limit.Uploaded() != expectedUpload {
   203  		t.Errorf("Expected HasSector program to consume %v upload bandwidth, instead it consumed %v", expectedUpload, limit.Uploaded())
   204  	}
   205  
   206  	// log the bandwidth used
   207  	t.Logf("Used bandwidth (has sector program): %v down, %v up", limit.Downloaded(), limit.Uploaded())
   208  }
   209  
   210  // testExecuteProgramUsedBandwidthReadSector verifies the bandwidth consumed by
   211  // a ReadSector program
   212  func testExecuteProgramUsedBandwidthReadSector(t *testing.T, wt *workerTester) {
   213  	w := wt.worker
   214  	sectorData := fastrand.Bytes(int(modules.SectorSize))
   215  	sectorRoot := crypto.MerkleRoot(sectorData)
   216  	err := wt.host.AddSector(sectorRoot, sectorData)
   217  	if err != nil {
   218  		t.Fatal("could not add sector to host")
   219  	}
   220  
   221  	// create a dummy program
   222  	pt := wt.staticPriceTable().staticPriceTable
   223  	pb := modules.NewProgramBuilder(&pt, 0)
   224  	pb.AddReadSectorInstruction(modules.SectorSize, 0, sectorRoot, true)
   225  	p, data := pb.Program()
   226  	cost, _, _ := pb.Cost(true)
   227  
   228  	// create job metadata
   229  	jobMetadata := jobReadMetadata{
   230  		staticWorker:           w,
   231  		staticSectorRoot:       sectorRoot,
   232  		staticSpendingCategory: categoryDownload,
   233  	}
   234  
   235  	// create read sector job
   236  	readSectorRespChan := make(chan *jobReadResponse)
   237  	jrs := w.newJobReadSector(context.Background(), w.staticJobReadQueue, readSectorRespChan, jobMetadata, sectorRoot, 0, modules.SectorSize)
   238  
   239  	ulBandwidth, dlBandwidth := jrs.callExpectedBandwidth()
   240  	bandwidthCost, bandwidthRefund := mdmBandwidthCost(pt, ulBandwidth, dlBandwidth)
   241  	cost = cost.Add(bandwidthCost)
   242  
   243  	// execute it
   244  	_, limit, err := w.managedExecuteProgram(p, data, types.FileContractID{}, categoryDownload, cost, bandwidthRefund)
   245  	if err != nil {
   246  		t.Fatal(err)
   247  	}
   248  
   249  	// ensure bandwidth is as we expected
   250  	expectedDownload := uint64(4380)
   251  	if limit.Downloaded() != expectedDownload {
   252  		t.Errorf("Expected ReadSector program to consume %v download bandwidth, instead it consumed %v", expectedDownload, limit.Downloaded())
   253  	}
   254  
   255  	expectedUpload := uint64(1460)
   256  	if limit.Uploaded() != expectedUpload {
   257  		t.Errorf("Expected ReadSector program to consume %v upload bandwidth, instead it consumed %v", expectedUpload, limit.Uploaded())
   258  	}
   259  
   260  	// log the bandwidth used
   261  	t.Logf("Used bandwidth (read sector program): %v down, %v up", limit.Downloaded(), limit.Uploaded())
   262  }