github.com/storacha/go-ucanto@v0.7.2/server/server.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"os"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/storacha/go-ucanto/core/dag/blockstore"
    14  	"github.com/storacha/go-ucanto/core/delegation"
    15  	"github.com/storacha/go-ucanto/core/invocation"
    16  	"github.com/storacha/go-ucanto/core/ipld"
    17  	"github.com/storacha/go-ucanto/core/message"
    18  	"github.com/storacha/go-ucanto/core/receipt"
    19  	"github.com/storacha/go-ucanto/core/receipt/ran"
    20  	"github.com/storacha/go-ucanto/core/result"
    21  	"github.com/storacha/go-ucanto/core/result/failure"
    22  	"github.com/storacha/go-ucanto/did"
    23  	"github.com/storacha/go-ucanto/principal"
    24  	"github.com/storacha/go-ucanto/principal/ed25519/verifier"
    25  	"github.com/storacha/go-ucanto/server/transaction"
    26  	"github.com/storacha/go-ucanto/transport"
    27  	"github.com/storacha/go-ucanto/transport/car"
    28  	thttp "github.com/storacha/go-ucanto/transport/http"
    29  	"github.com/storacha/go-ucanto/ucan"
    30  	"github.com/storacha/go-ucanto/validator"
    31  )
    32  
// InvocationContext is the context provided to service methods. It combines
// the validator interfaces embedded below with the identity of the service
// and the list of alternative audiences the service accepts.
type InvocationContext interface {
	validator.RevocationChecker[any]
	validator.CanIssuer[any]
	validator.ProofResolver
	validator.PrincipalParser
	validator.PrincipalResolver
	validator.TimeBoundsValidator
	validator.AuthorityProver
	// ID is the signing principal of the service the invocation was sent to.
	ID() principal.Signer

	// AlternativeAudiences are other audiences the service will accept for
	// invocations.
	AlternativeAudiences() []ucan.Principal
}
    48  
// ServiceMethod is an invocation handler. It is called with the request
// context, the invocation being executed and the server's InvocationContext.
// It returns a transaction, whose output and effects are used to issue the
// receipt, or an error if the handler could not be executed.
type ServiceMethod[O ipld.Builder, X failure.IPLDBuilderFailure] func(
	context.Context,
	invocation.Invocation,
	InvocationContext,
) (transaction.Transaction[O, X], error)
    55  
// Service is a mapping of abilities to their handlers, used to define a
// service implementation. The handler for an invocation is looked up by the
// ability of its capability.
type Service = map[ucan.Ability]ServiceMethod[ipld.Builder, failure.IPLDBuilderFailure]
    59  
// ServiceInvocation is the invocation type executed by a server. It is an
// alias of invocation.IssuedInvocation.
type ServiceInvocation = invocation.IssuedInvocation
    61  
// Server describes everything needed to execute invocations against a service
// of type S: its identity, the inbound codec, the invocation context, the
// service itself and the hooks used to report errors and receipts.
type Server[S any] interface {
	// ID is the DID which will be used to verify that received invocation
	// audience matches it.
	ID() principal.Signer
	// Codec is the inbound codec used to accept requests, decode their payload
	// and encode the response.
	Codec() transport.InboundCodec
	// Context is the invocation context passed to every service method.
	Context() InvocationContext
	// Service is the actual service providing capability handlers.
	Service() S
	// Catch is called with errors raised while executing a handler or while
	// issuing or logging its receipt.
	Catch(err HandlerExecutionError[any])
	// LogReceipt is called with each receipt issued for a successfully
	// executed handler, together with the invocation it was issued for.
	LogReceipt(ctx context.Context, rcpt receipt.AnyReceipt, inv invocation.Invocation) error
}
    73  
// ServerView is a materialized service that is configured to use a specific
// transport channel. It has an invocation context which contains the DID of
// the service itself, among other things.
type ServerView[S any] interface {
	Server[S]
	transport.Channel
	// Run executes a single invocation and returns a receipt.
	Run(ctx context.Context, invocation ServiceInvocation) (receipt.AnyReceipt, error)
}
    83  
// ErrorHandlerFunc allows non-result errors generated during handler execution
// to be logged. When no handler is configured, NewServer installs one that
// writes the error to standard error.
type ErrorHandlerFunc func(err HandlerExecutionError[any])
    87  
// ReceiptLoggerFunc allows receipts generated during handler execution to be
// logged. The original invocation is also provided for reference.
//
// Returning an error from this function causes Run to report the error to the
// server's error handler and to return a failure receipt to the client in
// place of the receipt that was being logged, so use it judiciously.
type ReceiptLoggerFunc func(ctx context.Context, rcpt receipt.AnyReceipt, inv invocation.Invocation) error
    93  
    94  func NewServer(id principal.Signer, options ...Option) (ServerView[Service], error) {
    95  	cfg := srvConfig{service: Service{}}
    96  	for _, opt := range options {
    97  		if err := opt(&cfg); err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  
   102  	codec := cfg.codec
   103  	if codec == nil {
   104  		codec = car.NewInboundCodec()
   105  	}
   106  
   107  	canIssue := cfg.canIssue
   108  	if canIssue == nil {
   109  		canIssue = validator.IsSelfIssued
   110  	}
   111  
   112  	catch := cfg.catch
   113  	if catch == nil {
   114  		catch = func(err HandlerExecutionError[any]) {
   115  			fmt.Fprintf(os.Stderr, "error: %s\n", err.Error())
   116  		}
   117  	}
   118  
   119  	validateAuthorization := cfg.validateAuthorization
   120  	if validateAuthorization == nil {
   121  		validateAuthorization = func(context.Context, validator.Authorization[any]) validator.Revoked {
   122  			return nil
   123  		}
   124  	}
   125  
   126  	resolveProof := cfg.resolveProof
   127  	if resolveProof == nil {
   128  		resolveProof = validator.ProofUnavailable
   129  	}
   130  
   131  	parsePrincipal := cfg.parsePrincipal
   132  	if parsePrincipal == nil {
   133  		parsePrincipal = ParsePrincipal
   134  	}
   135  
   136  	resolveDIDKey := cfg.resolveDIDKey
   137  	if resolveDIDKey == nil {
   138  		resolveDIDKey = validator.FailDIDKeyResolution
   139  	}
   140  
   141  	validateTimeBounds := cfg.validateTimeBounds
   142  	if validateTimeBounds == nil {
   143  		validateTimeBounds = validator.NotExpiredNotTooEarly
   144  	}
   145  
   146  	ctx := serverContext{id, canIssue, validateAuthorization, resolveProof, parsePrincipal, resolveDIDKey, validateTimeBounds, cfg.authorityProofs, cfg.altAudiences}
   147  	svr := &server{id, cfg.service, ctx, codec, catch, cfg.logReceipt}
   148  	return svr, nil
   149  }
   150  
   151  func ParsePrincipal(str string) (principal.Verifier, error) {
   152  	// TODO: Ed or RSA
   153  	return verifier.Parse(str)
   154  }
   155  
   156  type serverContext struct {
   157  	id                    principal.Signer
   158  	canIssue              validator.CanIssueFunc[any]
   159  	validateAuthorization validator.RevocationCheckerFunc[any]
   160  	resolveProof          validator.ProofResolverFunc
   161  	parsePrincipal        validator.PrincipalParserFunc
   162  	resolveDIDKey         validator.PrincipalResolverFunc
   163  	validateTimeBounds    validator.TimeBoundsValidatorFunc
   164  	authorityProofs       []delegation.Delegation
   165  	altAudiences          []ucan.Principal
   166  }
   167  
   168  func (ctx serverContext) ID() principal.Signer {
   169  	return ctx.id
   170  }
   171  
   172  func (sctx serverContext) CanIssue(capability ucan.Capability[any], issuer did.DID) bool {
   173  	return sctx.canIssue(capability, issuer)
   174  }
   175  
   176  func (sctx serverContext) ValidateAuthorization(ctx context.Context, auth validator.Authorization[any]) validator.Revoked {
   177  	return sctx.validateAuthorization(ctx, auth)
   178  }
   179  
   180  func (sctx serverContext) ResolveProof(ctx context.Context, proof ucan.Link) (delegation.Delegation, validator.UnavailableProof) {
   181  	return sctx.resolveProof(ctx, proof)
   182  }
   183  
   184  func (sctx serverContext) ParsePrincipal(str string) (principal.Verifier, error) {
   185  	return sctx.parsePrincipal(str)
   186  }
   187  
   188  func (sctx serverContext) ResolveDIDKey(ctx context.Context, did did.DID) (did.DID, validator.UnresolvedDID) {
   189  	return sctx.resolveDIDKey(ctx, did)
   190  }
   191  
   192  func (sctx serverContext) ValidateTimeBounds(dlg delegation.Delegation) validator.InvalidProof {
   193  	return sctx.validateTimeBounds(dlg)
   194  }
   195  
   196  func (sctx serverContext) AuthorityProofs() []delegation.Delegation {
   197  	return sctx.authorityProofs
   198  }
   199  
   200  func (sctx serverContext) AlternativeAudiences() []ucan.Principal {
   201  	return sctx.altAudiences
   202  }
   203  
   204  type server struct {
   205  	id         principal.Signer
   206  	service    Service
   207  	context    InvocationContext
   208  	codec      transport.InboundCodec
   209  	catch      ErrorHandlerFunc
   210  	logReceipt ReceiptLoggerFunc
   211  }
   212  
   213  func (srv *server) ID() principal.Signer {
   214  	return srv.id
   215  }
   216  
   217  func (srv *server) Service() Service {
   218  	return srv.service
   219  }
   220  
   221  func (srv *server) Context() InvocationContext {
   222  	return srv.context
   223  }
   224  
   225  func (srv *server) Codec() transport.InboundCodec {
   226  	return srv.codec
   227  }
   228  
   229  func (srv *server) Request(ctx context.Context, request transport.HTTPRequest) (transport.HTTPResponse, error) {
   230  	return Handle(ctx, srv, request)
   231  }
   232  
   233  func (srv *server) Run(ctx context.Context, invocation ServiceInvocation) (receipt.AnyReceipt, error) {
   234  	return Run(ctx, srv, invocation)
   235  }
   236  
   237  func (srv *server) Catch(err HandlerExecutionError[any]) {
   238  	srv.catch(err)
   239  }
   240  
   241  func (srv *server) LogReceipt(ctx context.Context, rcpt receipt.AnyReceipt, inv invocation.Invocation) error {
   242  	if srv.logReceipt == nil {
   243  		return nil
   244  	}
   245  
   246  	return srv.logReceipt(ctx, rcpt, inv)
   247  }
   248  
   249  var _ transport.Channel = (*server)(nil)
   250  var _ ServerView[Service] = (*server)(nil)
   251  
   252  func Handle(ctx context.Context, server Server[Service], request transport.HTTPRequest) (transport.HTTPResponse, error) {
   253  	selection, aerr := server.Codec().Accept(request)
   254  	if aerr != nil {
   255  		return thttp.NewResponse(aerr.Status(), io.NopCloser(strings.NewReader(aerr.Error())), aerr.Headers()), nil
   256  	}
   257  
   258  	msg, err := selection.Decoder().Decode(request)
   259  	if err != nil {
   260  		return thttp.NewResponse(http.StatusBadRequest, io.NopCloser(strings.NewReader("The server failed to decode the request payload. Please format the payload according to the specified media type.")), nil), nil
   261  	}
   262  
   263  	result, err := Execute(ctx, server, msg)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  
   268  	return selection.Encoder().Encode(result)
   269  }
   270  
   271  func Execute(ctx context.Context, server Server[Service], msg message.AgentMessage) (message.AgentMessage, error) {
   272  	br, err := blockstore.NewBlockReader(blockstore.WithBlocksIterator(msg.Blocks()))
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  
   277  	var invs []invocation.Invocation
   278  	for _, invlnk := range msg.Invocations() {
   279  		inv, err := invocation.NewInvocationView(invlnk, br)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		invs = append(invs, inv)
   284  	}
   285  
   286  	var rcpts []receipt.AnyReceipt
   287  	var rerr error
   288  	var wg sync.WaitGroup
   289  	var lock sync.RWMutex
   290  	for _, inv := range invs {
   291  		wg.Add(1)
   292  		go func(inv invocation.Invocation) {
   293  			defer wg.Done()
   294  			rcpt, err := Run(ctx, server, inv)
   295  			if err != nil {
   296  				rerr = err
   297  				return
   298  			}
   299  
   300  			lock.Lock()
   301  			rcpts = append(rcpts, rcpt)
   302  			lock.Unlock()
   303  		}(inv)
   304  	}
   305  	wg.Wait()
   306  
   307  	if rerr != nil {
   308  		return nil, rerr
   309  	}
   310  
   311  	return message.Build(nil, rcpts)
   312  }
   313  
   314  func Run(ctx context.Context, server Server[Service], invocation ServiceInvocation) (receipt.AnyReceipt, error) {
   315  	caps := invocation.Capabilities()
   316  	// Invocation needs to have one single capability
   317  	if len(caps) != 1 {
   318  		err := NewInvocationCapabilityError(invocation.Capabilities())
   319  		return receipt.Issue(server.ID(), result.NewFailure(err), ran.FromInvocation(invocation))
   320  	}
   321  
   322  	cap := caps[0]
   323  	handle, ok := server.Service()[cap.Can()]
   324  	if !ok {
   325  		err := NewHandlerNotFoundError(cap)
   326  		return receipt.Issue(server.ID(), result.NewFailure(err), ran.FromInvocation(invocation))
   327  	}
   328  
   329  	tx, err := handle(ctx, invocation, server.Context())
   330  	if err != nil {
   331  		if errors.Is(err, context.Canceled) {
   332  			return nil, err
   333  		}
   334  		herr := NewHandlerExecutionError(err, cap)
   335  		server.Catch(herr)
   336  		return receipt.Issue(server.ID(), result.NewFailure(herr), ran.FromInvocation(invocation))
   337  	}
   338  
   339  	fx := tx.Fx()
   340  	var opts []receipt.Option
   341  	if fx != nil {
   342  		opts = append(opts, receipt.WithJoin(fx.Join()), receipt.WithFork(fx.Fork()...))
   343  	}
   344  
   345  	rcpt, err := receipt.Issue(server.ID(), tx.Out(), ran.FromInvocation(invocation), opts...)
   346  	if err != nil {
   347  		herr := NewHandlerExecutionError(err, cap)
   348  		server.Catch(herr)
   349  		return receipt.Issue(server.ID(), result.NewFailure(herr), ran.FromInvocation(invocation))
   350  	}
   351  
   352  	if err := server.LogReceipt(ctx, rcpt, invocation); err != nil {
   353  		herr := NewHandlerExecutionError(err, cap)
   354  		server.Catch(herr)
   355  		return receipt.Issue(server.ID(), result.NewFailure(herr), ran.FromInvocation(invocation))
   356  	}
   357  
   358  	return rcpt, nil
   359  }