github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/request/assembler/simple/server.go (about)

     1  package simple
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"sync"
     7  
     8  	"github.com/v2fly/v2ray-core/v5/common"
     9  
    10  	"github.com/v2fly/v2ray-core/v5/transport/internet/request"
    11  )
    12  
    13  func newServer(config *ServerConfig) request.SessionAssemblerServer {
    14  	return &simpleAssemblerServer{}
    15  }
    16  
    17  type simpleAssemblerServer struct {
    18  	sessions sync.Map
    19  	assembly request.TransportServerAssembly
    20  }
    21  
    22  func (s *simpleAssemblerServer) OnTransportServerAssemblyReady(assembly request.TransportServerAssembly) {
    23  	s.assembly = assembly
    24  }
    25  
    26  func (s *simpleAssemblerServer) OnRoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption,
    27  ) (resp request.Response, err error) {
    28  	connectionID := req.ConnectionTag
    29  	session := newSimpleAssemblerServerSession(ctx)
    30  	loadedSession, loaded := s.sessions.LoadOrStore(string(connectionID), session)
    31  	if loaded {
    32  		session = loadedSession.(*simpleAssemblerServerSession)
    33  	} else {
    34  		if err := s.assembly.SessionReceiver().OnNewSession(ctx, session); err != nil {
    35  			return request.Response{}, newError("failed to create new session").Base(err)
    36  		}
    37  	}
    38  	return session.OnRoundTrip(ctx, req, opts...)
    39  }
    40  
    41  func newSimpleAssemblerServerSession(ctx context.Context) *simpleAssemblerServerSession {
    42  	sessionCtx, finish := context.WithCancel(ctx)
    43  	return &simpleAssemblerServerSession{
    44  		readBuffer:       bytes.NewBuffer(nil),
    45  		readChan:         make(chan []byte, 16),
    46  		requestProcessed: make(chan struct{}),
    47  		writeLock:        new(sync.Mutex),
    48  		writeBuffer:      bytes.NewBuffer(nil),
    49  		maxWriteSize:     4096,
    50  		ctx:              sessionCtx,
    51  		finish:           finish,
    52  	}
    53  }
    54  
    55  type simpleAssemblerServerSession struct {
    56  	maxWriteSize int
    57  
    58  	readBuffer       *bytes.Buffer
    59  	readChan         chan []byte
    60  	requestProcessed chan struct{}
    61  
    62  	writeLock   *sync.Mutex
    63  	writeBuffer *bytes.Buffer
    64  
    65  	ctx    context.Context
    66  	finish func()
    67  }
    68  
    69  func (s *simpleAssemblerServerSession) Read(p []byte) (n int, err error) {
    70  	if s.readBuffer.Len() == 0 {
    71  		select {
    72  		case <-s.ctx.Done():
    73  			return 0, s.ctx.Err()
    74  		case data := <-s.readChan:
    75  			s.readBuffer.Write(data)
    76  		}
    77  	}
    78  	return s.readBuffer.Read(p)
    79  }
    80  
    81  func (s *simpleAssemblerServerSession) Write(p []byte) (n int, err error) {
    82  	s.writeLock.Lock()
    83  
    84  	n, err = s.writeBuffer.Write(p)
    85  	length := s.writeBuffer.Len()
    86  	s.writeLock.Unlock()
    87  	if err != nil {
    88  		return 0, err
    89  	}
    90  	if length > s.maxWriteSize {
    91  		select {
    92  		case <-s.requestProcessed:
    93  		case <-s.ctx.Done():
    94  			return 0, s.ctx.Err()
    95  		}
    96  	}
    97  	return
    98  }
    99  
   100  func (s *simpleAssemblerServerSession) Close() error {
   101  	s.finish()
   102  	return nil
   103  }
   104  
   105  func (s *simpleAssemblerServerSession) OnRoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption,
   106  ) (resp request.Response, err error) {
   107  	if req.Data != nil && len(req.Data) > 0 {
   108  		select {
   109  		case <-s.ctx.Done():
   110  			return request.Response{}, s.ctx.Err()
   111  		case s.readChan <- req.Data:
   112  		}
   113  	}
   114  
   115  	s.writeLock.Lock()
   116  	nextWrite := s.writeBuffer.Next(s.maxWriteSize)
   117  	data := make([]byte, len(nextWrite))
   118  	copy(data, nextWrite)
   119  	s.writeLock.Unlock()
   120  	select {
   121  	case s.requestProcessed <- struct{}{}:
   122  	case <-s.ctx.Done():
   123  		return request.Response{}, s.ctx.Err()
   124  	default:
   125  	}
   126  	return request.Response{Data: data}, nil
   127  }
   128  
   129  func init() {
   130  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   131  		serverConfig, ok := config.(*ServerConfig)
   132  		if !ok {
   133  			return nil, newError("not a SimpleServerConfig")
   134  		}
   135  		return newServer(serverConfig), nil
   136  	}))
   137  }