github.com/imannamdari/v2ray-core/v5@v5.0.5/app/dispatcher/default.go (about) 1 package dispatcher 2 3 //go:generate go run github.com/imannamdari/v2ray-core/v5/common/errors/errorgen 4 5 import ( 6 "context" 7 "strings" 8 "sync" 9 "time" 10 11 core "github.com/imannamdari/v2ray-core/v5" 12 "github.com/imannamdari/v2ray-core/v5/common" 13 "github.com/imannamdari/v2ray-core/v5/common/buf" 14 "github.com/imannamdari/v2ray-core/v5/common/log" 15 "github.com/imannamdari/v2ray-core/v5/common/net" 16 "github.com/imannamdari/v2ray-core/v5/common/protocol" 17 "github.com/imannamdari/v2ray-core/v5/common/session" 18 "github.com/imannamdari/v2ray-core/v5/common/strmatcher" 19 "github.com/imannamdari/v2ray-core/v5/features/outbound" 20 "github.com/imannamdari/v2ray-core/v5/features/policy" 21 "github.com/imannamdari/v2ray-core/v5/features/routing" 22 routing_session "github.com/imannamdari/v2ray-core/v5/features/routing/session" 23 "github.com/imannamdari/v2ray-core/v5/features/stats" 24 "github.com/imannamdari/v2ray-core/v5/transport" 25 "github.com/imannamdari/v2ray-core/v5/transport/pipe" 26 ) 27 28 var errSniffingTimeout = newError("timeout on sniffing") 29 30 type cachedReader struct { 31 sync.Mutex 32 reader *pipe.Reader 33 cache buf.MultiBuffer 34 } 35 36 func (r *cachedReader) Cache(b *buf.Buffer) { 37 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 38 r.Lock() 39 if !mb.IsEmpty() { 40 r.cache, _ = buf.MergeMulti(r.cache, mb) 41 } 42 b.Clear() 43 rawBytes := b.Extend(buf.Size) 44 n := r.cache.Copy(rawBytes) 45 b.Resize(0, int32(n)) 46 r.Unlock() 47 } 48 49 func (r *cachedReader) readInternal() buf.MultiBuffer { 50 r.Lock() 51 defer r.Unlock() 52 53 if r.cache != nil && !r.cache.IsEmpty() { 54 mb := r.cache 55 r.cache = nil 56 return mb 57 } 58 59 return nil 60 } 61 62 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 63 mb := r.readInternal() 64 if mb != nil { 65 return mb, nil 66 } 67 68 return r.reader.ReadMultiBuffer() 69 } 70 71 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 72 mb := r.readInternal() 73 if mb != nil { 74 return mb, nil 75 } 76 77 return r.reader.ReadMultiBufferTimeout(timeout) 78 } 79 80 func (r *cachedReader) Interrupt() { 81 r.Lock() 82 if r.cache != nil { 83 r.cache = buf.ReleaseMulti(r.cache) 84 } 85 r.Unlock() 86 r.reader.Interrupt() 87 } 88 89 // DefaultDispatcher is a default implementation of Dispatcher. 90 type DefaultDispatcher struct { 91 ohm outbound.Manager 92 router routing.Router 93 policy policy.Manager 94 stats stats.Manager 95 } 96 97 func init() { 98 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 99 d := new(DefaultDispatcher) 100 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 101 return d.Init(config.(*Config), om, router, pm, sm) 102 }); err != nil { 103 return nil, err 104 } 105 return d, nil 106 })) 107 } 108 109 // Init initializes DefaultDispatcher. 110 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 111 d.ohm = om 112 d.router = router 113 d.policy = pm 114 d.stats = sm 115 return nil 116 } 117 118 // Type implements common.HasType. 119 func (*DefaultDispatcher) Type() interface{} { 120 return routing.DispatcherType() 121 } 122 123 // Start implements common.Runnable. 124 func (*DefaultDispatcher) Start() error { 125 return nil 126 } 127 128 // Close implements common.Closable. 129 func (*DefaultDispatcher) Close() error { return nil } 130 131 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 132 opt := pipe.OptionsFromContext(ctx) 133 uplinkReader, uplinkWriter := pipe.New(opt...) 134 downlinkReader, downlinkWriter := pipe.New(opt...) 135 136 inboundLink := &transport.Link{ 137 Reader: downlinkReader, 138 Writer: uplinkWriter, 139 } 140 141 outboundLink := &transport.Link{ 142 Reader: uplinkReader, 143 Writer: downlinkWriter, 144 } 145 146 sessionInbound := session.InboundFromContext(ctx) 147 var user *protocol.MemoryUser 148 if sessionInbound != nil { 149 user = sessionInbound.User 150 } 151 152 if user != nil && len(user.Email) > 0 { 153 p := d.policy.ForLevel(user.Level) 154 if p.Stats.UserUplink { 155 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 156 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 157 inboundLink.Writer = &SizeStatWriter{ 158 Counter: c, 159 Writer: inboundLink.Writer, 160 } 161 } 162 } 163 if p.Stats.UserDownlink { 164 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 165 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 166 outboundLink.Writer = &SizeStatWriter{ 167 Counter: c, 168 Writer: outboundLink.Writer, 169 } 170 } 171 } 172 } 173 174 return inboundLink, outboundLink 175 } 176 177 func shouldOverride(result SniffResult, domainOverride []string) bool { 178 if result.Domain() == "" { 179 return false 180 } 181 protocolString := result.Protocol() 182 if resComp, ok := result.(SnifferResultComposite); ok { 183 protocolString = resComp.ProtocolForDomainResult() 184 } 185 for _, p := range domainOverride { 186 if strings.HasPrefix(protocolString, p) || strings.HasSuffix(protocolString, p) { 187 return true 188 } 189 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 190 if resultSubset.IsProtoSubsetOf(p) { 191 return true 192 } 193 } 194 } 195 return false 196 } 197 198 // Dispatch implements routing.Dispatcher. 199 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 200 if !destination.IsValid() { 201 panic("Dispatcher: Invalid destination.") 202 } 203 ob := &session.Outbound{ 204 Target: destination, 205 } 206 ctx = session.ContextWithOutbound(ctx, ob) 207 208 inbound, outbound := d.getLink(ctx) 209 content := session.ContentFromContext(ctx) 210 if content == nil { 211 content = new(session.Content) 212 ctx = session.ContextWithContent(ctx, content) 213 } 214 sniffingRequest := content.SniffingRequest 215 if !sniffingRequest.Enabled { 216 go d.routedDispatch(ctx, outbound, destination) 217 } else { 218 go func() { 219 cReader := &cachedReader{ 220 reader: outbound.Reader.(*pipe.Reader), 221 } 222 outbound.Reader = cReader 223 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) 224 if err == nil { 225 content.Protocol = result.Protocol() 226 } 227 if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { 228 if domain, err := strmatcher.ToDomain(result.Domain()); err == nil { 229 newError("sniffed domain: ", domain, " for ", destination).WriteToLog(session.ExportIDToError(ctx)) 230 destination.Address = net.ParseAddress(domain) 231 ob.Target = destination 232 } 233 } 234 d.routedDispatch(ctx, outbound, destination) 235 }() 236 } 237 238 return inbound, nil 239 } 240 241 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) { 242 payload := buf.New() 243 defer payload.Release() 244 245 sniffer := NewSniffer(ctx) 246 247 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 248 249 if metadataOnly { 250 return metaresult, metadataErr 251 } 252 253 contentResult, contentErr := func() (SniffResult, error) { 254 totalAttempt := 0 255 for { 256 select { 257 case <-ctx.Done(): 258 return nil, ctx.Err() 259 default: 260 totalAttempt++ 261 if totalAttempt > 2 { 262 return nil, errSniffingTimeout 263 } 264 265 cReader.Cache(payload) 266 if !payload.IsEmpty() { 267 result, err := sniffer.Sniff(ctx, payload.Bytes(), network) 268 if err != common.ErrNoClue { 269 return result, err 270 } 271 } 272 if payload.IsFull() { 273 return nil, errUnknownContent 274 } 275 } 276 } 277 }() 278 if contentErr != nil && metadataErr == nil { 279 return metaresult, nil 280 } 281 if contentErr == nil && metadataErr == nil { 282 return CompositeResult(metaresult, contentResult), nil 283 } 284 return contentResult, contentErr 285 } 286 287 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 288 var handler outbound.Handler 289 290 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 291 ctx = session.SetForcedOutboundTagToContext(ctx, "") 292 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 293 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 294 handler = h 295 } else { 296 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 297 common.Close(link.Writer) 298 common.Interrupt(link.Reader) 299 return 300 } 301 } else if d.router != nil { 302 if route, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil { 303 tag := route.GetOutboundTag() 304 if h := d.ohm.GetHandler(tag); h != nil { 305 newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 306 handler = h 307 } else { 308 newError("non existing tag: ", tag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 309 } 310 } else { 311 newError("default route for ", destination).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 312 } 313 } 314 315 if handler == nil { 316 handler = d.ohm.GetDefaultHandler() 317 } 318 319 if handler == nil { 320 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 321 common.Close(link.Writer) 322 common.Interrupt(link.Reader) 323 return 324 } 325 326 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 327 if tag := handler.Tag(); tag != "" { 328 accessMessage.Detour = tag 329 if d.policy.ForSystem().OverrideAccessLogDest { 330 accessMessage.To = destination 331 } 332 } 333 log.Record(accessMessage) 334 } 335 336 handler.Dispatch(ctx, link) 337 }