github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/service/context.go (about)

     1  package service
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/sagernet/sing/common"
     7  )
     8  
     9  func ContextWithRegistry(ctx context.Context, registry Registry) context.Context {
    10  	return context.WithValue(ctx, common.DefaultValue[*Registry](), registry)
    11  }
    12  
    13  func ContextWithDefaultRegistry(ctx context.Context) context.Context {
    14  	if RegistryFromContext(ctx) != nil {
    15  		return ctx
    16  	}
    17  	return context.WithValue(ctx, common.DefaultValue[*Registry](), NewRegistry())
    18  }
    19  
    20  func RegistryFromContext(ctx context.Context) Registry {
    21  	registry := ctx.Value(common.DefaultValue[*Registry]())
    22  	if registry == nil {
    23  		return nil
    24  	}
    25  	return registry.(Registry)
    26  }
    27  
    28  func FromContext[T any](ctx context.Context) T {
    29  	registry := RegistryFromContext(ctx)
    30  	if registry == nil {
    31  		return common.DefaultValue[T]()
    32  	}
    33  	service := registry.Get(common.DefaultValue[*T]())
    34  	if service == nil {
    35  		return common.DefaultValue[T]()
    36  	}
    37  	return service.(T)
    38  }
    39  
    40  func PtrFromContext[T any](ctx context.Context) *T {
    41  	registry := RegistryFromContext(ctx)
    42  	if registry == nil {
    43  		return nil
    44  	}
    45  	servicePtr := registry.Get(common.DefaultValue[*T]())
    46  	if servicePtr == nil {
    47  		return nil
    48  	}
    49  	return servicePtr.(*T)
    50  }
    51  
    52  func ContextWith[T any](ctx context.Context, service T) context.Context {
    53  	registry := RegistryFromContext(ctx)
    54  	if registry == nil {
    55  		registry = NewRegistry()
    56  		ctx = ContextWithRegistry(ctx, registry)
    57  	}
    58  	registry.Register(common.DefaultValue[*T](), service)
    59  	return ctx
    60  }
    61  
    62  func ContextWithPtr[T any](ctx context.Context, servicePtr *T) context.Context {
    63  	registry := RegistryFromContext(ctx)
    64  	if registry == nil {
    65  		registry = NewRegistry()
    66  		ctx = ContextWithRegistry(ctx, registry)
    67  	}
    68  	registry.Register(common.DefaultValue[*T](), servicePtr)
    69  	return ctx
    70  }
    71  
    72  func MustRegister[T any](ctx context.Context, service T) {
    73  	registry := RegistryFromContext(ctx)
    74  	if registry == nil {
    75  		panic("missing service registry in context")
    76  	}
    77  	registry.Register(common.DefaultValue[*T](), service)
    78  }
    79  
    80  func MustRegisterPtr[T any](ctx context.Context, servicePtr *T) {
    81  	registry := RegistryFromContext(ctx)
    82  	if registry == nil {
    83  		panic("missing service registry in context")
    84  	}
    85  	registry.Register(common.DefaultValue[*T](), servicePtr)
    86  }