github.com/xraypb/xray-core@v1.6.6/common/session/context.go (about)

     1  package session
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/xraypb/xray-core/features/routing"
     7  )
     8  
     9  type sessionKey int
    10  
    11  const (
    12  	idSessionKey sessionKey = iota
    13  	inboundSessionKey
    14  	outboundSessionKey
    15  	contentSessionKey
    16  	muxPreferedSessionKey
    17  	sockoptSessionKey
    18  	trackedConnectionErrorKey
    19  	dispatcherKey
    20  )
    21  
    22  // ContextWithID returns a new context with the given ID.
    23  func ContextWithID(ctx context.Context, id ID) context.Context {
    24  	return context.WithValue(ctx, idSessionKey, id)
    25  }
    26  
    27  // IDFromContext returns ID in this context, or 0 if not contained.
    28  func IDFromContext(ctx context.Context) ID {
    29  	if id, ok := ctx.Value(idSessionKey).(ID); ok {
    30  		return id
    31  	}
    32  	return 0
    33  }
    34  
    35  func ContextWithInbound(ctx context.Context, inbound *Inbound) context.Context {
    36  	return context.WithValue(ctx, inboundSessionKey, inbound)
    37  }
    38  
    39  func InboundFromContext(ctx context.Context) *Inbound {
    40  	if inbound, ok := ctx.Value(inboundSessionKey).(*Inbound); ok {
    41  		return inbound
    42  	}
    43  	return nil
    44  }
    45  
    46  func ContextWithOutbound(ctx context.Context, outbound *Outbound) context.Context {
    47  	return context.WithValue(ctx, outboundSessionKey, outbound)
    48  }
    49  
    50  func OutboundFromContext(ctx context.Context) *Outbound {
    51  	if outbound, ok := ctx.Value(outboundSessionKey).(*Outbound); ok {
    52  		return outbound
    53  	}
    54  	return nil
    55  }
    56  
    57  func ContextWithContent(ctx context.Context, content *Content) context.Context {
    58  	return context.WithValue(ctx, contentSessionKey, content)
    59  }
    60  
    61  func ContentFromContext(ctx context.Context) *Content {
    62  	if content, ok := ctx.Value(contentSessionKey).(*Content); ok {
    63  		return content
    64  	}
    65  	return nil
    66  }
    67  
    68  // ContextWithMuxPrefered returns a new context with the given bool
    69  func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
    70  	return context.WithValue(ctx, muxPreferedSessionKey, forced)
    71  }
    72  
    73  // MuxPreferedFromContext returns value in this context, or false if not contained.
    74  func MuxPreferedFromContext(ctx context.Context) bool {
    75  	if val, ok := ctx.Value(muxPreferedSessionKey).(bool); ok {
    76  		return val
    77  	}
    78  	return false
    79  }
    80  
    81  // ContextWithSockopt returns a new context with Socket configs included
    82  func ContextWithSockopt(ctx context.Context, s *Sockopt) context.Context {
    83  	return context.WithValue(ctx, sockoptSessionKey, s)
    84  }
    85  
    86  // SockoptFromContext returns Socket configs in this context, or nil if not contained.
    87  func SockoptFromContext(ctx context.Context) *Sockopt {
    88  	if sockopt, ok := ctx.Value(sockoptSessionKey).(*Sockopt); ok {
    89  		return sockopt
    90  	}
    91  	return nil
    92  }
    93  
    94  func GetForcedOutboundTagFromContext(ctx context.Context) string {
    95  	if ContentFromContext(ctx) == nil {
    96  		return ""
    97  	}
    98  	return ContentFromContext(ctx).Attribute("forcedOutboundTag")
    99  }
   100  
   101  func SetForcedOutboundTagToContext(ctx context.Context, tag string) context.Context {
   102  	if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
   103  		ctx = ContextWithContent(ctx, &Content{})
   104  	}
   105  	ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag)
   106  	return ctx
   107  }
   108  
   109  type TrackedRequestErrorFeedback interface {
   110  	SubmitError(err error)
   111  }
   112  
   113  func SubmitOutboundErrorToOriginator(ctx context.Context, err error) {
   114  	if errorTracker := ctx.Value(trackedConnectionErrorKey); errorTracker != nil {
   115  		errorTracker := errorTracker.(TrackedRequestErrorFeedback)
   116  		errorTracker.SubmitError(err)
   117  	}
   118  }
   119  
   120  func TrackedConnectionError(ctx context.Context, tracker TrackedRequestErrorFeedback) context.Context {
   121  	return context.WithValue(ctx, trackedConnectionErrorKey, tracker)
   122  }
   123  
   124  func ContextWithDispatcher(ctx context.Context, dispatcher routing.Dispatcher) context.Context {
   125  	return context.WithValue(ctx, dispatcherKey, dispatcher)
   126  }
   127  
   128  func DispatcherFromContext(ctx context.Context) routing.Dispatcher {
   129  	if dispatcher, ok := ctx.Value(dispatcherKey).(routing.Dispatcher); ok {
   130  		return dispatcher
   131  	}
   132  	return nil
   133  }