github.com/sagernet/sing-box@v1.9.0-rc.20/transport/vless/service.go (about)

     1  package vless
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"io"
     7  	"net"
     8  
     9  	"github.com/sagernet/sing-vmess"
    10  	"github.com/sagernet/sing/common/auth"
    11  	"github.com/sagernet/sing/common/buf"
    12  	"github.com/sagernet/sing/common/bufio"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	"github.com/sagernet/sing/common/logger"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  
    18  	"github.com/gofrs/uuid/v5"
    19  )
    20  
// Service is the server side of the VLESS protocol, generic over the user key
// type T. It authenticates incoming requests by UUID against the table built by
// UpdateUsers and hands accepted connections to handler.
type Service[T comparable] struct {
	// userMap maps a 16-byte user UUID to the user key; rebuilt by UpdateUsers.
	userMap  map[[16]byte]T
	// userFlow maps a user key to the flow configured for it ("" means none).
	userFlow map[T]string
	logger   logger.Logger
	handler  Handler
}
    27  
    28  type Handler interface {
    29  	N.TCPConnectionHandler
    30  	N.UDPConnectionHandler
    31  	E.Handler
    32  }
    33  
    34  func NewService[T comparable](logger logger.Logger, handler Handler) *Service[T] {
    35  	return &Service[T]{
    36  		logger:  logger,
    37  		handler: handler,
    38  	}
    39  }
    40  
    41  func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string, userFlowList []string) {
    42  	userMap := make(map[[16]byte]T)
    43  	userFlowMap := make(map[T]string)
    44  	for i, userName := range userList {
    45  		userID := uuid.FromStringOrNil(userUUIDList[i])
    46  		if userID == uuid.Nil {
    47  			userID = uuid.NewV5(uuid.Nil, userUUIDList[i])
    48  		}
    49  		userMap[userID] = userName
    50  		userFlowMap[userName] = userFlowList[i]
    51  	}
    52  	s.userMap = userMap
    53  	s.userFlow = userFlowMap
    54  }
    55  
    56  var _ N.TCPConnectionHandler = (*Service[int])(nil)
    57  
    58  func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    59  	request, err := ReadRequest(conn)
    60  	if err != nil {
    61  		return err
    62  	}
    63  	user, loaded := s.userMap[request.UUID]
    64  	if !loaded {
    65  		return E.New("unknown UUID: ", uuid.FromBytesOrNil(request.UUID[:]))
    66  	}
    67  	ctx = auth.ContextWithUser(ctx, user)
    68  	metadata.Destination = request.Destination
    69  
    70  	userFlow := s.userFlow[user]
    71  	if request.Flow == FlowVision && request.Command == vmess.NetworkUDP {
    72  		return E.New(FlowVision, " flow does not support UDP")
    73  	} else if request.Flow != userFlow {
    74  		return E.New("flow mismatch: expected ", flowName(userFlow), ", but got ", flowName(request.Flow))
    75  	}
    76  
    77  	if request.Command == vmess.CommandUDP {
    78  		return s.handler.NewPacketConnection(ctx, &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(conn), destination: request.Destination}, metadata)
    79  	}
    80  	responseConn := &serverConn{ExtendedConn: bufio.NewExtendedConn(conn), writer: bufio.NewVectorisedWriter(conn)}
    81  	switch userFlow {
    82  	case FlowVision:
    83  		conn, err = NewVisionConn(responseConn, conn, request.UUID, s.logger)
    84  		if err != nil {
    85  			return E.Cause(err, "initialize vision")
    86  		}
    87  	case "":
    88  		conn = responseConn
    89  	default:
    90  		return E.New("unknown flow: ", userFlow)
    91  	}
    92  	switch request.Command {
    93  	case vmess.CommandTCP:
    94  		return s.handler.NewConnection(ctx, conn, metadata)
    95  	case vmess.CommandMux:
    96  		return vmess.HandleMuxConnection(ctx, conn, s.handler)
    97  	default:
    98  		return E.New("unknown command: ", request.Command)
    99  	}
   100  }
   101  
   102  func flowName(value string) string {
   103  	if value == "" {
   104  		return "none"
   105  	}
   106  	return value
   107  }
   108  
// Compile-time check that serverConn implements N.VectorisedWriter.
var _ N.VectorisedWriter = (*serverConn)(nil)

// serverConn wraps an accepted VLESS stream so that the 2-byte response header
// {Version, 0} is sent in front of the first payload written to the client.
type serverConn struct {
	N.ExtendedConn
	// writer is a vectorised writer over the same underlying connection, used
	// to emit the header and the first payload in one write.
	writer          N.VectorisedWriter
	// responseWritten becomes true once the response header has been sent
	// (or its write attempted).
	responseWritten bool
}
   116  
   117  func (c *serverConn) Read(b []byte) (n int, err error) {
   118  	return c.ExtendedConn.Read(b)
   119  }
   120  
   121  func (c *serverConn) Write(b []byte) (n int, err error) {
   122  	if !c.responseWritten {
   123  		_, err = bufio.WriteVectorised(c.writer, [][]byte{{Version, 0}, b})
   124  		if err == nil {
   125  			n = len(b)
   126  		}
   127  		c.responseWritten = true
   128  		return
   129  	}
   130  	return c.ExtendedConn.Write(b)
   131  }
   132  
   133  func (c *serverConn) WriteBuffer(buffer *buf.Buffer) error {
   134  	if !c.responseWritten {
   135  		header := buffer.ExtendHeader(2)
   136  		header[0] = Version
   137  		header[1] = 0
   138  		c.responseWritten = true
   139  	}
   140  	return c.ExtendedConn.WriteBuffer(buffer)
   141  }
   142  
   143  func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error {
   144  	if !c.responseWritten {
   145  		err := c.writer.WriteVectorised(append([]*buf.Buffer{buf.As([]byte{Version, 0})}, buffers...))
   146  		c.responseWritten = true
   147  		return err
   148  	}
   149  	return c.writer.WriteVectorised(buffers)
   150  }
   151  
   152  func (c *serverConn) NeedAdditionalReadDeadline() bool {
   153  	return true
   154  }
   155  
   156  func (c *serverConn) FrontHeadroom() int {
   157  	if c.responseWritten {
   158  		return 0
   159  	}
   160  	return 2
   161  }
   162  
   163  func (c *serverConn) ReaderReplaceable() bool {
   164  	return true
   165  }
   166  
   167  func (c *serverConn) WriterReplaceable() bool {
   168  	return c.responseWritten
   169  }
   170  
   171  func (c *serverConn) Upstream() any {
   172  	return c.ExtendedConn
   173  }
   174  
// serverPacketConn carries UDP payloads for a single VLESS UDP request over the
// accepted stream. Each packet on the wire is a 2-byte big-endian length
// followed by the payload. The 2-byte response header {Version, 0} precedes
// the first packet sent to the client.
type serverPacketConn struct {
	N.ExtendedConn
	// responseWriter, when non-nil, receives the response header on its own
	// instead of the header being bundled with the first packet.
	// NewConnection in this file leaves it nil.
	responseWriter  io.Writer
	// responseWritten becomes true once the response header has been sent.
	responseWritten bool
	// destination is the target from the VLESS request; it is reported as the
	// address of every packet read.
	destination     M.Socksaddr
}
   181  
   182  func (c *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   183  	n, err = c.ExtendedConn.Read(p)
   184  	if err != nil {
   185  		return
   186  	}
   187  	if c.destination.IsFqdn() {
   188  		addr = c.destination
   189  	} else {
   190  		addr = c.destination.UDPAddr()
   191  	}
   192  	return
   193  }
   194  
   195  func (c *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   196  	if !c.responseWritten {
   197  		if c.responseWriter == nil {
   198  			var packetLen [2]byte
   199  			binary.BigEndian.PutUint16(packetLen[:], uint16(len(p)))
   200  			_, err = bufio.WriteVectorised(bufio.NewVectorisedWriter(c.ExtendedConn), [][]byte{{Version, 0}, packetLen[:], p})
   201  			if err == nil {
   202  				n = len(p)
   203  			}
   204  			c.responseWritten = true
   205  			return
   206  		} else {
   207  			_, err = c.responseWriter.Write([]byte{Version, 0})
   208  			if err != nil {
   209  				return
   210  			}
   211  			c.responseWritten = true
   212  		}
   213  	}
   214  	return c.ExtendedConn.Write(p)
   215  }
   216  
   217  func (c *serverPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   218  	var packetLen uint16
   219  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &packetLen)
   220  	if err != nil {
   221  		return
   222  	}
   223  
   224  	_, err = buffer.ReadFullFrom(c.ExtendedConn, int(packetLen))
   225  	if err != nil {
   226  		return
   227  	}
   228  
   229  	destination = c.destination
   230  	return
   231  }
   232  
   233  func (c *serverPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   234  	if !c.responseWritten {
   235  		if c.responseWriter == nil {
   236  			var packetLen [2]byte
   237  			binary.BigEndian.PutUint16(packetLen[:], uint16(buffer.Len()))
   238  			err := bufio.NewVectorisedWriter(c.ExtendedConn).WriteVectorised([]*buf.Buffer{buf.As([]byte{Version, 0}), buf.As(packetLen[:]), buffer})
   239  			c.responseWritten = true
   240  			return err
   241  		} else {
   242  			_, err := c.responseWriter.Write([]byte{Version, 0})
   243  			if err != nil {
   244  				return err
   245  			}
   246  			c.responseWritten = true
   247  		}
   248  	}
   249  	packetLen := buffer.Len()
   250  	binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(packetLen))
   251  	return c.ExtendedConn.WriteBuffer(buffer)
   252  }
   253  
   254  func (c *serverPacketConn) FrontHeadroom() int {
   255  	return 2
   256  }
   257  
   258  func (c *serverPacketConn) Upstream() any {
   259  	return c.ExtendedConn
   260  }