github.com/v2fly/v2ray-core/v4@v4.45.2/app/dispatcher/default.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dispatcher 5 6 //go:generate go run github.com/v2fly/v2ray-core/v4/common/errors/errorgen 7 8 import ( 9 "context" 10 "strings" 11 "sync" 12 "time" 13 14 core "github.com/v2fly/v2ray-core/v4" 15 "github.com/v2fly/v2ray-core/v4/common" 16 "github.com/v2fly/v2ray-core/v4/common/buf" 17 "github.com/v2fly/v2ray-core/v4/common/log" 18 "github.com/v2fly/v2ray-core/v4/common/net" 19 "github.com/v2fly/v2ray-core/v4/common/protocol" 20 "github.com/v2fly/v2ray-core/v4/common/session" 21 "github.com/v2fly/v2ray-core/v4/features/outbound" 22 "github.com/v2fly/v2ray-core/v4/features/policy" 23 "github.com/v2fly/v2ray-core/v4/features/routing" 24 routing_session "github.com/v2fly/v2ray-core/v4/features/routing/session" 25 "github.com/v2fly/v2ray-core/v4/features/stats" 26 "github.com/v2fly/v2ray-core/v4/transport" 27 "github.com/v2fly/v2ray-core/v4/transport/pipe" 28 ) 29 30 var errSniffingTimeout = newError("timeout on sniffing") 31 32 type cachedReader struct { 33 sync.Mutex 34 reader *pipe.Reader 35 cache buf.MultiBuffer 36 } 37 38 func (r *cachedReader) Cache(b *buf.Buffer) { 39 mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100) 40 r.Lock() 41 if !mb.IsEmpty() { 42 r.cache, _ = buf.MergeMulti(r.cache, mb) 43 } 44 b.Clear() 45 rawBytes := b.Extend(buf.Size) 46 n := r.cache.Copy(rawBytes) 47 b.Resize(0, int32(n)) 48 r.Unlock() 49 } 50 51 func (r *cachedReader) readInternal() buf.MultiBuffer { 52 r.Lock() 53 defer r.Unlock() 54 55 if r.cache != nil && !r.cache.IsEmpty() { 56 mb := r.cache 57 r.cache = nil 58 return mb 59 } 60 61 return nil 62 } 63 64 func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 65 mb := r.readInternal() 66 if mb != nil { 67 return mb, nil 68 } 69 70 return r.reader.ReadMultiBuffer() 71 } 72 73 func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) { 74 mb := r.readInternal() 75 if mb != nil { 76 return mb, nil 77 } 78 79 return r.reader.ReadMultiBufferTimeout(timeout) 80 } 81 82 func (r *cachedReader) Interrupt() { 83 r.Lock() 84 if r.cache != nil { 85 r.cache = buf.ReleaseMulti(r.cache) 86 } 87 r.Unlock() 88 r.reader.Interrupt() 89 } 90 91 // DefaultDispatcher is a default implementation of Dispatcher. 92 type DefaultDispatcher struct { 93 ohm outbound.Manager 94 router routing.Router 95 policy policy.Manager 96 stats stats.Manager 97 } 98 99 func init() { 100 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 101 d := new(DefaultDispatcher) 102 if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 103 return d.Init(config.(*Config), om, router, pm, sm) 104 }); err != nil { 105 return nil, err 106 } 107 return d, nil 108 })) 109 } 110 111 // Init initializes DefaultDispatcher. 112 func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error { 113 d.ohm = om 114 d.router = router 115 d.policy = pm 116 d.stats = sm 117 return nil 118 } 119 120 // Type implements common.HasType. 121 func (*DefaultDispatcher) Type() interface{} { 122 return routing.DispatcherType() 123 } 124 125 // Start implements common.Runnable. 126 func (*DefaultDispatcher) Start() error { 127 return nil 128 } 129 130 // Close implements common.Closable. 131 func (*DefaultDispatcher) Close() error { return nil } 132 133 func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) { 134 opt := pipe.OptionsFromContext(ctx) 135 uplinkReader, uplinkWriter := pipe.New(opt...) 136 downlinkReader, downlinkWriter := pipe.New(opt...) 137 138 inboundLink := &transport.Link{ 139 Reader: downlinkReader, 140 Writer: uplinkWriter, 141 } 142 143 outboundLink := &transport.Link{ 144 Reader: uplinkReader, 145 Writer: downlinkWriter, 146 } 147 148 sessionInbound := session.InboundFromContext(ctx) 149 var user *protocol.MemoryUser 150 if sessionInbound != nil { 151 user = sessionInbound.User 152 } 153 154 if user != nil && len(user.Email) > 0 { 155 p := d.policy.ForLevel(user.Level) 156 if p.Stats.UserUplink { 157 name := "user>>>" + user.Email + ">>>traffic>>>uplink" 158 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 159 inboundLink.Writer = &SizeStatWriter{ 160 Counter: c, 161 Writer: inboundLink.Writer, 162 } 163 } 164 } 165 if p.Stats.UserDownlink { 166 name := "user>>>" + user.Email + ">>>traffic>>>downlink" 167 if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil { 168 outboundLink.Writer = &SizeStatWriter{ 169 Counter: c, 170 Writer: outboundLink.Writer, 171 } 172 } 173 } 174 } 175 176 return inboundLink, outboundLink 177 } 178 179 func shouldOverride(result SniffResult, domainOverride []string) bool { 180 protocolString := result.Protocol() 181 if resComp, ok := result.(SnifferResultComposite); ok { 182 protocolString = resComp.ProtocolForDomainResult() 183 } 184 for _, p := range domainOverride { 185 if strings.HasPrefix(protocolString, p) { 186 return true 187 } 188 if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok { 189 if resultSubset.IsProtoSubsetOf(p) { 190 return true 191 } 192 } 193 } 194 return false 195 } 196 197 // Dispatch implements routing.Dispatcher. 198 func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) { 199 if !destination.IsValid() { 200 panic("Dispatcher: Invalid destination.") 201 } 202 ob := &session.Outbound{ 203 Target: destination, 204 } 205 ctx = session.ContextWithOutbound(ctx, ob) 206 207 inbound, outbound := d.getLink(ctx) 208 content := session.ContentFromContext(ctx) 209 if content == nil { 210 content = new(session.Content) 211 ctx = session.ContextWithContent(ctx, content) 212 } 213 sniffingRequest := content.SniffingRequest 214 switch { 215 case !sniffingRequest.Enabled: 216 go d.routedDispatch(ctx, outbound, destination) 217 case destination.Network != net.Network_TCP: 218 // Only metadata sniff will be used for non tcp connection 219 result, err := sniffer(ctx, nil, true) 220 if err == nil { 221 content.Protocol = result.Protocol() 222 if shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { 223 domain := result.Domain() 224 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 225 destination.Address = net.ParseAddress(domain) 226 ob.Target = destination 227 } 228 } 229 go d.routedDispatch(ctx, outbound, destination) 230 default: 231 go func() { 232 cReader := &cachedReader{ 233 reader: outbound.Reader.(*pipe.Reader), 234 } 235 outbound.Reader = cReader 236 result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly) 237 if err == nil { 238 content.Protocol = result.Protocol() 239 } 240 if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) { 241 domain := result.Domain() 242 newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx)) 243 destination.Address = net.ParseAddress(domain) 244 ob.Target = destination 245 } 246 d.routedDispatch(ctx, outbound, destination) 247 }() 248 } 249 return inbound, nil 250 } 251 252 func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool) (SniffResult, error) { 253 payload := buf.New() 254 defer payload.Release() 255 256 sniffer := NewSniffer(ctx) 257 258 metaresult, metadataErr := sniffer.SniffMetadata(ctx) 259 260 if metadataOnly { 261 return metaresult, metadataErr 262 } 263 264 contentResult, contentErr := func() (SniffResult, error) { 265 totalAttempt := 0 266 for { 267 select { 268 case <-ctx.Done(): 269 return nil, ctx.Err() 270 default: 271 totalAttempt++ 272 if totalAttempt > 2 { 273 return nil, errSniffingTimeout 274 } 275 276 cReader.Cache(payload) 277 if !payload.IsEmpty() { 278 result, err := sniffer.Sniff(ctx, payload.Bytes()) 279 if err != common.ErrNoClue { 280 return result, err 281 } 282 } 283 if payload.IsFull() { 284 return nil, errUnknownContent 285 } 286 } 287 } 288 }() 289 if contentErr != nil && metadataErr == nil { 290 return metaresult, nil 291 } 292 if contentErr == nil && metadataErr == nil { 293 return CompositeResult(metaresult, contentResult), nil 294 } 295 return contentResult, contentErr 296 } 297 298 func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) { 299 var handler outbound.Handler 300 301 if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" { 302 ctx = session.SetForcedOutboundTagToContext(ctx, "") 303 if h := d.ohm.GetHandler(forcedOutboundTag); h != nil { 304 newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 305 handler = h 306 } else { 307 newError("non existing tag for platform initialized detour: ", forcedOutboundTag).AtError().WriteToLog(session.ExportIDToError(ctx)) 308 common.Close(link.Writer) 309 common.Interrupt(link.Reader) 310 return 311 } 312 } else if d.router != nil { 313 if route, err := d.router.PickRoute(routing_session.AsRoutingContext(ctx)); err == nil { 314 tag := route.GetOutboundTag() 315 if h := d.ohm.GetHandler(tag); h != nil { 316 newError("taking detour [", tag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx)) 317 handler = h 318 } else { 319 newError("non existing tag: ", tag).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 320 } 321 } else { 322 newError("default route for ", destination).WriteToLog(session.ExportIDToError(ctx)) 323 } 324 } 325 326 if handler == nil { 327 handler = d.ohm.GetDefaultHandler() 328 } 329 330 if handler == nil { 331 newError("default outbound handler not exist").WriteToLog(session.ExportIDToError(ctx)) 332 common.Close(link.Writer) 333 common.Interrupt(link.Reader) 334 return 335 } 336 337 if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil { 338 if tag := handler.Tag(); tag != "" { 339 accessMessage.Detour = tag 340 } 341 log.Record(accessMessage) 342 } 343 344 handler.Dispatch(ctx, link) 345 }