github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/session/context.go (about)

     1  package session
     2  
     3  import (
     4  	"context"
     5  	_ "unsafe"
     6  
     7  	"github.com/xtls/xray-core/common/net"
     8  	"github.com/xtls/xray-core/features/routing"
     9  )
    10  
// IndependentCancelCtx has no Go body of its own: the go:linkname directive
// below binds it to the unexported standard-library function
// context.newCancelCtx. The directive is only accepted because this file
// imports "unsafe" (the blank import at the top of the file).
//
// NOTE(review): the apparent intent is a context that keeps parent's values
// but is not wired into parent's cancellation. That depends entirely on the
// implementation of context.newCancelCtx — confirm against the context
// package source for the Go version in use.
//
// NOTE(review): context.newCancelCtx is unexported, and its signature and
// return type have changed across Go releases. This declaration must match it
// exactly, or the program is ABI-unsafe. Go 1.23's linker also rejects
// pull-style linknames to standard-library symbols that are not explicitly
// marked for it, so this may stop building on newer toolchains. Verify this,
// and consider context.WithoutCancel (Go 1.21+) as a replacement (its Done
// and Deadline behave differently — check callers).
//go:linkname IndependentCancelCtx context.newCancelCtx
func IndependentCancelCtx(parent context.Context) context.Context
    13  
// sessionKey is the private key type for every value this package stores in a
// context.Context. Because the type is unexported, code in other packages
// cannot build a key that collides with these.
type sessionKey int

// Context keys, one per stored value, numbered by iota in declaration order.
// The trailing comments name the value type stored under each key by the
// accessors in this file.
const (
	idSessionKey              sessionKey = iota // ID
	inboundSessionKey                           // *Inbound
	outboundSessionKey                          // []*Outbound
	contentSessionKey                           // *Content
	muxPreferedSessionKey                       // bool
	sockoptSessionKey                           // *Sockopt
	trackedConnectionErrorKey                   // TrackedRequestErrorFeedback
	dispatcherKey                               // routing.Dispatcher
	timeoutOnlyKey                              // bool
	allowedNetworkKey                           // net.Network
	handlerSessionKey                           // NOTE(review): not referenced in this file; presumably used elsewhere in the package
)
    29  
    30  // ContextWithID returns a new context with the given ID.
    31  func ContextWithID(ctx context.Context, id ID) context.Context {
    32  	return context.WithValue(ctx, idSessionKey, id)
    33  }
    34  
    35  // IDFromContext returns ID in this context, or 0 if not contained.
    36  func IDFromContext(ctx context.Context) ID {
    37  	if id, ok := ctx.Value(idSessionKey).(ID); ok {
    38  		return id
    39  	}
    40  	return 0
    41  }
    42  
    43  func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context {
    44  	return context.WithValue(ctx, inboundSessionKey, inbound)
    45  }
    46  
    47  func InboundFromContext(ctx context.Context) *Inbound {
    48  	if inbound, ok := ctx.Value(inboundSessionKey).(*Inbound); ok {
    49  		return inbound
    50  	}
    51  	return nil
    52  }
    53  
    54  func ContextWithOutbounds(ctx context.Context, outbounds []*Outbound) context.Context {
    55  	return context.WithValue(ctx, outboundSessionKey, outbounds)
    56  }
    57  
    58  func OutboundsFromContext(ctx context.Context) []*Outbound {
    59  	if outbounds, ok := ctx.Value(outboundSessionKey).([]*Outbound); ok {
    60  		return outbounds
    61  	}
    62  	return nil
    63  }
    64  
    65  func ContextWithContent(ctx context.Context, content *Content) context.Context {
    66  	return context.WithValue(ctx, contentSessionKey, content)
    67  }
    68  
    69  func ContentFromContext(ctx context.Context) *Content {
    70  	if content, ok := ctx.Value(contentSessionKey).(*Content); ok {
    71  		return content
    72  	}
    73  	return nil
    74  }
    75  
    76  // ContextWithMuxPrefered returns a new context with the given bool
    77  func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
    78  	return context.WithValue(ctx, muxPreferedSessionKey, forced)
    79  }
    80  
    81  // MuxPreferedFromContext returns value in this context, or false if not contained.
    82  func MuxPreferedFromContext(ctx context.Context) bool {
    83  	if val, ok := ctx.Value(muxPreferedSessionKey).(bool); ok {
    84  		return val
    85  	}
    86  	return false
    87  }
    88  
    89  // ContextWithSockopt returns a new context with Socket configs included
    90  func ContextWithSockopt(ctx context.Context, s *Sockopt) context.Context {
    91  	return context.WithValue(ctx, sockoptSessionKey, s)
    92  }
    93  
    94  // SockoptFromContext returns Socket configs in this context, or nil if not contained.
    95  func SockoptFromContext(ctx context.Context) *Sockopt {
    96  	if sockopt, ok := ctx.Value(sockoptSessionKey).(*Sockopt); ok {
    97  		return sockopt
    98  	}
    99  	return nil
   100  }
   101  
   102  func GetForcedOutboundTagFromContext(ctx context.Context) string {
   103  	if ContentFromContext(ctx) == nil {
   104  		return ""
   105  	}
   106  	return ContentFromContext(ctx).Attribute("forcedOutboundTag")
   107  }
   108  
   109  func SetForcedOutboundTagToContext(ctx context.Context, tag string) context.Context {
   110  	if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
   111  		ctx = ContextWithContent(ctx, &Content{})
   112  	}
   113  	ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag)
   114  	return ctx
   115  }
   116  
// TrackedRequestErrorFeedback receives errors reported for a tracked request.
// An implementation is attached to a context with TrackedConnectionError, and
// errors are delivered to it by SubmitOutboundErrorToOriginator.
type TrackedRequestErrorFeedback interface {
	// SubmitError is handed each error submitted through the context.
	SubmitError(err error)
}
   120  
   121  func SubmitOutboundErrorToOriginator(ctx context.Context, err error) {
   122  	if errorTracker := ctx.Value(trackedConnectionErrorKey); errorTracker != nil {
   123  		errorTracker := errorTracker.(TrackedRequestErrorFeedback)
   124  		errorTracker.SubmitError(err)
   125  	}
   126  }
   127  
   128  func TrackedConnectionError(ctx context.Context, tracker TrackedRequestErrorFeedback) context.Context {
   129  	return context.WithValue(ctx, trackedConnectionErrorKey, tracker)
   130  }
   131  
   132  func ContextWithDispatcher(ctx context.Context, dispatcher routing.Dispatcher) context.Context {
   133  	return context.WithValue(ctx, dispatcherKey, dispatcher)
   134  }
   135  
   136  func DispatcherFromContext(ctx context.Context) routing.Dispatcher {
   137  	if dispatcher, ok := ctx.Value(dispatcherKey).(routing.Dispatcher); ok {
   138  		return dispatcher
   139  	}
   140  	return nil
   141  }
   142  
   143  func ContextWithTimeoutOnly(ctx context.Context, only bool) context.Context {
   144  	return context.WithValue(ctx, timeoutOnlyKey, only)
   145  }
   146  
   147  func TimeoutOnlyFromContext(ctx context.Context) bool {
   148  	if val, ok := ctx.Value(timeoutOnlyKey).(bool); ok {
   149  		return val
   150  	}
   151  	return false
   152  }
   153  
   154  func ContextWithAllowedNetwork(ctx context.Context, network net.Network) context.Context {
   155  	return context.WithValue(ctx, allowedNetworkKey, network)
   156  }
   157  
   158  func AllowedNetworkFromContext(ctx context.Context) net.Network {
   159  	if val, ok := ctx.Value(allowedNetworkKey).(net.Network); ok {
   160  		return val
   161  	}
   162  	return net.Network_Unknown
   163  }