github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/features/routing/session/context.go (about)

     1  package session
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/xtls/xray-core/common/net"
     7  	"github.com/xtls/xray-core/common/session"
     8  	"github.com/xtls/xray-core/features/routing"
     9  )
    10  
    11  // Context is an implementation of routing.Context, which is a wrapper of context.context with session info.
    12  type Context struct {
    13  	Inbound  *session.Inbound
    14  	Outbound *session.Outbound
    15  	Content  *session.Content
    16  }
    17  
    18  // GetInboundTag implements routing.Context.
    19  func (ctx *Context) GetInboundTag() string {
    20  	if ctx.Inbound == nil {
    21  		return ""
    22  	}
    23  	return ctx.Inbound.Tag
    24  }
    25  
    26  // GetSourceIPs implements routing.Context.
    27  func (ctx *Context) GetSourceIPs() []net.IP {
    28  	if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
    29  		return nil
    30  	}
    31  	dest := ctx.Inbound.Source
    32  	if dest.Address.Family().IsDomain() {
    33  		return nil
    34  	}
    35  
    36  	return []net.IP{dest.Address.IP()}
    37  }
    38  
    39  // GetSourcePort implements routing.Context.
    40  func (ctx *Context) GetSourcePort() net.Port {
    41  	if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
    42  		return 0
    43  	}
    44  	return ctx.Inbound.Source.Port
    45  }
    46  
    47  // GetTargetIPs implements routing.Context.
    48  func (ctx *Context) GetTargetIPs() []net.IP {
    49  	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
    50  		return nil
    51  	}
    52  
    53  	if ctx.Outbound.Target.Address.Family().IsIP() {
    54  		return []net.IP{ctx.Outbound.Target.Address.IP()}
    55  	}
    56  
    57  	return nil
    58  }
    59  
    60  // GetTargetPort implements routing.Context.
    61  func (ctx *Context) GetTargetPort() net.Port {
    62  	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
    63  		return 0
    64  	}
    65  	return ctx.Outbound.Target.Port
    66  }
    67  
    68  // GetTargetDomain implements routing.Context.
    69  func (ctx *Context) GetTargetDomain() string {
    70  	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
    71  		return ""
    72  	}
    73  	dest := ctx.Outbound.RouteTarget
    74  	if dest.IsValid() && dest.Address.Family().IsDomain() {
    75  		return dest.Address.Domain()
    76  	}
    77  
    78  	dest = ctx.Outbound.Target
    79  	if !dest.Address.Family().IsDomain() {
    80  		return ""
    81  	}
    82  	return dest.Address.Domain()
    83  }
    84  
    85  // GetNetwork implements routing.Context.
    86  func (ctx *Context) GetNetwork() net.Network {
    87  	if ctx.Outbound == nil {
    88  		return net.Network_Unknown
    89  	}
    90  	return ctx.Outbound.Target.Network
    91  }
    92  
    93  // GetProtocol implements routing.Context.
    94  func (ctx *Context) GetProtocol() string {
    95  	if ctx.Content == nil {
    96  		return ""
    97  	}
    98  	return ctx.Content.Protocol
    99  }
   100  
   101  // GetUser implements routing.Context.
   102  func (ctx *Context) GetUser() string {
   103  	if ctx.Inbound == nil || ctx.Inbound.User == nil {
   104  		return ""
   105  	}
   106  	return ctx.Inbound.User.Email
   107  }
   108  
   109  // GetAttributes implements routing.Context.
   110  func (ctx *Context) GetAttributes() map[string]string {
   111  	if ctx.Content == nil {
   112  		return nil
   113  	}
   114  	return ctx.Content.Attributes
   115  }
   116  
   117  // GetSkipDNSResolve implements routing.Context.
   118  func (ctx *Context) GetSkipDNSResolve() bool {
   119  	if ctx.Content == nil {
   120  		return false
   121  	}
   122  	return ctx.Content.SkipDNSResolve
   123  }
   124  
   125  // AsRoutingContext creates a context from context.context with session info.
   126  func AsRoutingContext(ctx context.Context) routing.Context {
   127  	outbounds := session.OutboundsFromContext(ctx)
   128  	ob := outbounds[len(outbounds) - 1]
   129  	return &Context{
   130  		Inbound:  session.InboundFromContext(ctx),
   131  		Outbound: ob,
   132  		Content:  session.ContentFromContext(ctx),
   133  	}
   134  }