github.com/storacha/go-ucanto@v0.7.2/client/retrieval/connection_test.go (about)

     1  package retrieval
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"testing"
    12  
    13  	prime "github.com/ipld/go-ipld-prime"
    14  	"github.com/multiformats/go-multihash"
    15  	"github.com/storacha/go-ucanto/core/dag/blockstore"
    16  	"github.com/storacha/go-ucanto/core/delegation"
    17  	"github.com/storacha/go-ucanto/core/invocation"
    18  	"github.com/storacha/go-ucanto/core/ipld"
    19  	"github.com/storacha/go-ucanto/core/receipt"
    20  	"github.com/storacha/go-ucanto/core/receipt/fx"
    21  	"github.com/storacha/go-ucanto/core/result"
    22  	"github.com/storacha/go-ucanto/core/result/failure"
    23  	"github.com/storacha/go-ucanto/core/schema"
    24  	ed25519 "github.com/storacha/go-ucanto/principal/ed25519/signer"
    25  	"github.com/storacha/go-ucanto/server"
    26  	"github.com/storacha/go-ucanto/server/retrieval"
    27  	"github.com/storacha/go-ucanto/testing/fixtures"
    28  	"github.com/storacha/go-ucanto/testing/helpers"
    29  	"github.com/storacha/go-ucanto/testing/helpers/printer"
    30  	thttp "github.com/storacha/go-ucanto/transport/http"
    31  	"github.com/storacha/go-ucanto/ucan"
    32  	"github.com/storacha/go-ucanto/validator"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  type serveCaveats struct {
    37  	Digest []byte
    38  	Range  []int
    39  }
    40  
    41  var serveTS = helpers.Must(prime.LoadSchemaBytes([]byte(`
    42  	type ServeCaveats struct {
    43  		digest Bytes
    44  		range [Int]
    45  	}
    46  	type ServeOk struct {
    47  		digest Bytes
    48  		range [Int]
    49  	}
    50  `)))
    51  
    52  func (sc serveCaveats) ToIPLD() (ipld.Node, error) {
    53  	return ipld.WrapWithRecovery(&sc, serveTS.TypeByName("ServeCaveats"))
    54  }
    55  
    56  type serveOk struct {
    57  	Digest []byte
    58  	Range  []int
    59  }
    60  
    61  func (so serveOk) ToIPLD() (ipld.Node, error) {
    62  	return ipld.WrapWithRecovery(&so, serveTS.TypeByName("ServeOk"))
    63  }
    64  
    65  var serveCaveatsReader = schema.Struct[serveCaveats](serveTS.TypeByName("ServeCaveats"), nil)
    66  
    67  var serve = validator.NewCapability(
    68  	"content/serve",
    69  	schema.DIDString(),
    70  	serveCaveatsReader,
    71  	validator.DefaultDerives,
    72  )
    73  
    74  func mkDelegationChain(t *testing.T, rootIssuer ucan.Signer, endAudience ucan.Principal, can ucan.Ability, len int) delegation.Delegation {
    75  	require.GreaterOrEqual(t, len, 1)
    76  
    77  	var dlg delegation.Delegation
    78  	var proof delegation.Delegation
    79  
    80  	iss := rootIssuer
    81  	aud, err := ed25519.Generate()
    82  	require.NoError(t, err)
    83  
    84  	for range len - 1 {
    85  		var opts []delegation.Option
    86  		if proof != nil {
    87  			opts = append(opts, delegation.WithProof(delegation.FromDelegation(proof)))
    88  		}
    89  		dlg, err = delegation.Delegate(
    90  			iss,
    91  			aud,
    92  			[]ucan.Capability[ucan.NoCaveats]{
    93  				ucan.NewCapability(can, rootIssuer.DID().String(), ucan.NoCaveats{}),
    94  			},
    95  			opts...,
    96  		)
    97  		require.NoError(t, err)
    98  		iss = aud
    99  		aud, err = ed25519.Generate()
   100  		require.NoError(t, err)
   101  		proof = dlg
   102  	}
   103  
   104  	var opts []delegation.Option
   105  	if proof != nil {
   106  		opts = append(opts, delegation.WithProof(delegation.FromDelegation(proof)))
   107  	}
   108  	dlg, err = delegation.Delegate(
   109  		iss,
   110  		endAudience,
   111  		[]ucan.Capability[ucan.NoCaveats]{
   112  			ucan.NewCapability(can, rootIssuer.DID().String(), ucan.NoCaveats{}),
   113  		},
   114  		opts...,
   115  	)
   116  	require.NoError(t, err)
   117  
   118  	return dlg
   119  }
   120  
   121  func calcHeadersSize(h http.Header) int {
   122  	var buf bytes.Buffer
   123  	h.Write(&buf)
   124  	return buf.Len()
   125  }
   126  
   127  var kb = 1024
   128  
   129  // newRetrievalHTTPServer creates a HTTP server that will send a 431 response
   130  // when HTTP headers exceed 2KiB, but otherwise calls the UCAN server as usual
   131  func newRetrievalHTTPServer(t *testing.T, server server.ServerView[retrieval.Service]) *httptest.Server {
   132  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   133  		t.Logf("-> %s %s", r.Method, r.URL)
   134  		printer.PrintHeaders(t, r.Header)
   135  		size := calcHeadersSize(r.Header)
   136  		t.Logf("Total size of headers: %s", printer.SprintBytes(t, size))
   137  
   138  		if size > 2*kb {
   139  			t.Logf("<- %d %s", http.StatusRequestHeaderFieldsTooLarge, http.StatusText(http.StatusRequestHeaderFieldsTooLarge))
   140  			w.WriteHeader(http.StatusRequestHeaderFieldsTooLarge)
   141  			return
   142  		}
   143  
   144  		resp, err := server.Request(r.Context(), thttp.NewInboundRequest(r.URL, r.Body, r.Header))
   145  		require.NoError(t, err)
   146  
   147  		t.Logf("<- %d %s", resp.Status(), http.StatusText(resp.Status()))
   148  		printer.PrintHeaders(t, resp.Headers())
   149  		t.Logf("Total size of headers: %s", printer.SprintBytes(t, calcHeadersSize(resp.Headers())))
   150  
   151  		for name, values := range resp.Headers() {
   152  			for _, value := range values {
   153  				w.Header().Add(name, value)
   154  			}
   155  		}
   156  		w.WriteHeader(resp.Status())
   157  		body := resp.Body()
   158  		if body != nil {
   159  			// log out the "not extended" dag-json response for debugging purposes
   160  			if resp.Status() == http.StatusNotExtended {
   161  				bodyBytes, err := io.ReadAll(body)
   162  				require.NoError(t, err)
   163  				t.Logf("Body: %s", string(bodyBytes))
   164  				body = io.NopCloser(bytes.NewReader(bodyBytes))
   165  			}
   166  			_, err := io.Copy(w, body)
   167  			require.NoError(t, err)
   168  		}
   169  	}))
   170  }
   171  
   172  type testDelegationCache struct {
   173  	t    *testing.T
   174  	data map[string]delegation.Delegation
   175  }
   176  
   177  func (c *testDelegationCache) Get(ctx context.Context, root ipld.Link) (delegation.Delegation, bool, error) {
   178  	d, ok := c.data[root.String()]
   179  	if ok {
   180  		c.t.Logf("CACHE HIT: %s", root.String())
   181  	} else {
   182  		c.t.Logf("CACHE MISS: %s", root.String())
   183  	}
   184  	return d, ok, nil
   185  }
   186  
   187  func (c *testDelegationCache) Put(ctx context.Context, d delegation.Delegation) error {
   188  	c.data[d.Link().String()] = d
   189  	c.t.Logf("CACHE PUT: %s", d.Link().String())
   190  	return nil
   191  }
   192  
   193  func newTestDelegationCache(t *testing.T) *testDelegationCache {
   194  	return &testDelegationCache{t: t, data: map[string]delegation.Delegation{}}
   195  }
   196  
   197  func TestExecute(t *testing.T) {
   198  	chainLengths := []int{1, 5, 10}
   199  	for _, length := range chainLengths {
   200  		t.Run(fmt.Sprintf("retrieval via partitioned request (proof chain of %d delegations)", length), func(t *testing.T) {
   201  			dlg := mkDelegationChain(t, fixtures.Service, fixtures.Alice, serve.Can(), length)
   202  			data := helpers.RandomBytes(512)
   203  
   204  			// create a retrieval server that will send bytes back for an authorized
   205  			// UCAN invocation sent in HTTP headers of the GET request
   206  			server, err := retrieval.NewServer(
   207  				fixtures.Service,
   208  				retrieval.WithServiceMethod(
   209  					serve.Can(),
   210  					retrieval.Provide(
   211  						serve,
   212  						func(ctx context.Context, cap ucan.Capability[serveCaveats], inv invocation.Invocation, ictx server.InvocationContext, req retrieval.Request) (result.Result[serveOk, failure.IPLDBuilderFailure], fx.Effects, retrieval.Response, error) {
   213  							t.Logf("Handling %s: %s", serve.Can(), req.URL.String())
   214  							t.Log("Invocation:")
   215  							printer.PrintDelegation(t, inv, 0)
   216  							nb := cap.Nb()
   217  							result := result.Ok[serveOk, failure.IPLDBuilderFailure](serveOk(nb))
   218  							start, end := nb.Range[0], nb.Range[1]
   219  							// ensure the requested range matches the HTTP request headers
   220  							if req.Headers.Get("Range") != fmt.Sprintf("bytes=%d-%d", start, end) {
   221  								return nil, nil, retrieval.Response{Status: http.StatusBadRequest}, nil
   222  							}
   223  							length := end - start + 1
   224  							headers := http.Header{}
   225  							headers.Set("Content-Length", fmt.Sprintf("%d", length))
   226  							headers.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(data)))
   227  							response := retrieval.Response{
   228  								Status:  http.StatusPartialContent,
   229  								Headers: headers,
   230  								Body:    io.NopCloser(bytes.NewReader(data[start : end+1])),
   231  							}
   232  							return result, nil, response, nil
   233  						},
   234  					),
   235  				),
   236  				retrieval.WithDelegationCache(newTestDelegationCache(t)),
   237  			)
   238  			require.NoError(t, err)
   239  
   240  			httpServer := newRetrievalHTTPServer(t, server)
   241  			defer httpServer.Close()
   242  
   243  			// make a UCAN authorized retrieval request for some bytes from the data
   244  
   245  			// identify the data
   246  			digest, err := multihash.Sum(data, multihash.SHA2_256, -1)
   247  			require.NoError(t, err)
   248  
   249  			// specify the byte range we want to receive (inclusive)
   250  			contentRange := []int{100, 200}
   251  
   252  			url, err := url.Parse(httpServer.URL)
   253  			require.NoError(t, err)
   254  
   255  			headers := http.Header{}
   256  			headers.Set("Range", fmt.Sprintf("bytes=%d-%d", contentRange[0], contentRange[1]))
   257  
   258  			// the URL doesn't really have a consequence on this test, but it can be
   259  			// used to idenitfy the data if not done so in the invocation caveats
   260  			conn, err := NewConnection(
   261  				fixtures.Service,
   262  				url.JoinPath("blob", "z"+digest.B58String()),
   263  				WithHeaders(headers),
   264  			)
   265  			require.NoError(t, err)
   266  
   267  			inv, err := serve.Invoke(
   268  				fixtures.Alice,
   269  				fixtures.Service,
   270  				fixtures.Service.DID().String(),
   271  				serveCaveats{Digest: digest, Range: contentRange},
   272  				delegation.WithProof(delegation.FromDelegation(dlg)),
   273  			)
   274  			require.NoError(t, err)
   275  
   276  			// send the invocation, and receive the execution response _as well as_ the
   277  			// HTTP response!
   278  			xRes, hRes, err := Execute(t.Context(), inv, conn)
   279  			require.NoError(t, err)
   280  			require.NotNil(t, xRes)
   281  			require.NotNil(t, hRes)
   282  
   283  			rcptLink, ok := xRes.Get(inv.Link())
   284  			require.True(t, ok)
   285  
   286  			bs, err := blockstore.NewBlockReader(blockstore.WithBlocksIterator(xRes.Blocks()))
   287  			require.NoError(t, err)
   288  
   289  			rcpt, err := receipt.NewAnyReceipt(rcptLink, bs)
   290  			require.NoError(t, err)
   291  
   292  			// verify the receipt is not an error, and that the info matches the
   293  			// invocation caveats
   294  			o, x := result.Unwrap(rcpt.Out())
   295  			require.Nil(t, x)
   296  
   297  			sok, err := ipld.Rebind[serveOk](o, serveTS.TypeByName("ServeOk"))
   298  			require.NoError(t, err)
   299  			require.Equal(t, digest, multihash.Multihash(sok.Digest))
   300  			require.Equal(t, []int{100, 200}, sok.Range)
   301  
   302  			// verify the data in the HTTP body is what we asked for
   303  			body, err := io.ReadAll(hRes.Body())
   304  			require.NoError(t, err)
   305  			require.Equal(t, data[100:200+1], body)
   306  		})
   307  	}
   308  }