gitee.com/sy_183/go-common@v1.0.5-0.20231205030221-958cfe129b47/ipc/cgo/shm-server/server.go (about)

     1  package shmServer
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"gitee.com/sy_183/go-common/container"
     7  	semPkg "gitee.com/sy_183/go-common/ipc/cgo/sem"
     8  	"gitee.com/sy_183/go-common/ipc/cgo/sem/queue"
     9  	"gitee.com/sy_183/go-common/lifecycle"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"sync"
    14  	"sync/atomic"
    15  	"time"
    16  	"unsafe"
    17  )
    18  
// selector is a snapshot of the channels that currently have a queue loaded.
// updateSelectChannel builds a new snapshot and publishes it through
// Server.selector rather than modifying a published one; the goroutine started
// by run reads it.
type selector struct {
	// qcs pairs each loaded queue with the channel it was loaded from.
	qcs []struct {
		queue32 *queue.Queue32
		channel *Channel
	}
}
    25  
// Server is the shmServer server. It receives control packets (open, alloc,
// keepalive and close requests) from peers over a UDP socket and keeps one
// Channel per peer address.
type Server struct {
	// Lifecycle is embedded so its methods are promoted onto Server.
	lifecycle.Lifecycle
	// NOTE(review): runner is not referenced in the code shown here; presumably
	// the concrete lifecycle backing Lifecycle — TODO confirm.
	runner *lifecycle.DefaultLifecycle
	// closed signals that the server is shutting down; read consults it when a
	// read deadline expires to decide whether the receive loop should stop.
	closed atomic.Bool

	// addr is the local UDP address the control socket listens on; start also
	// embeds it in the semaphore names.
	addr     *net.UDPAddr
	// conn is the UDP control socket opened by start.
	conn     *net.UDPConn
	// allocSem and useSem are the named semaphores opened by start
	// ("go-shm-server_<addr>_alloc" and "go-shm-server_<addr>_use").
	allocSem semPkg.NamedSem
	useSem   semPkg.NamedSem

	// channels maps a peer's AddrId (built from its UDP IP and port) to its
	// Channel.
	channels     container.SyncMap[AddrId, *Channel]
	// selector is the latest snapshot published by updateSelectChannel.
	selector     atomic.Pointer[selector]
	// selectorLock serializes updateSelectChannel rebuilds.
	selectorLock sync.Mutex

	// errorHandlers holds the registered handlers, indexed as
	// [typ>>8][typ&0xff]; see locateErrorHandler.
	errorHandlers [][]ErrorHandler
}
    42  
    43  func NewServer(addr *net.UDPAddr) *Server {
    44  	net.ListenUnix()
    45  }
    46  
    47  func loadErrorHandler(p *ErrorHandler) ErrorHandler {
    48  	up := atomic.LoadUintptr((*uintptr)(unsafe.Pointer(p)))
    49  	return *(*ErrorHandler)(unsafe.Pointer(&up))
    50  }
    51  
    52  func storeErrorHandler(p *ErrorHandler, handler ErrorHandler) {
    53  	up := *(*uintptr)(unsafe.Pointer(&handler))
    54  	atomic.StoreUintptr((*uintptr)(unsafe.Pointer(p)), up)
    55  }
    56  
    57  func (s *Server) locateErrorHandler(typ ErrType) (main, sub *ErrorHandler) {
    58  	mt := int(typ >> 8)
    59  	st := int(typ & 0xff)
    60  	if mt < len(s.errorHandlers) {
    61  		if mhs := s.errorHandlers[mt]; mt == 0 {
    62  			return &mhs[0], nil
    63  		} else if mt < len(mhs) {
    64  			return &mhs[st], &mhs[0]
    65  		}
    66  	}
    67  	return nil, nil
    68  }
    69  
    70  func (s *Server) getErrorHandler(typ ErrType) (main, sub ErrorHandler) {
    71  	mp, sp := s.locateErrorHandler(typ)
    72  	if mp != nil {
    73  		main = loadErrorHandler(mp)
    74  	}
    75  	if sp != nil {
    76  		sub = loadErrorHandler(sp)
    77  	}
    78  	return
    79  }
    80  
    81  func (s *Server) handleError(typ ErrType, err error, args ...any) {
    82  	main, sub := s.getErrorHandler(typ)
    83  	if sub != nil {
    84  		sub(typ, err, args...)
    85  		return
    86  	}
    87  	if main != nil {
    88  		main(typ, err, args...)
    89  	}
    90  	return
    91  }
    92  
    93  func (s *Server) ErrorHandler(typ ErrType) ErrorHandler {
    94  	if mp, sp := s.locateErrorHandler(typ); sp != nil {
    95  		return loadErrorHandler(sp)
    96  	} else if mp != nil {
    97  		return loadErrorHandler(mp)
    98  	}
    99  	return nil
   100  }
   101  
   102  func (s *Server) SetErrorHandler(typ ErrType, handler ErrorHandler) {
   103  	if mp, sp := s.locateErrorHandler(typ); sp != nil {
   104  		storeErrorHandler(sp, handler)
   105  	} else if mp != nil {
   106  		storeErrorHandler(mp, handler)
   107  	}
   108  }
   109  
   110  func (s *Server) updateSelectChannel() {
   111  	qcs := s.selector.Load().qcs[:0:0]
   112  	s.selectorLock.Lock()
   113  	defer s.selectorLock.Unlock()
   114  	s.channels.Range(func(id AddrId, ch *Channel) bool {
   115  		if q := ch.loadQueue(); q != nil {
   116  			qcs = append(qcs, struct {
   117  				queue32 *queue.Queue32
   118  				channel *Channel
   119  			}{queue32: q, channel: ch})
   120  		}
   121  		return true
   122  	})
   123  	s.selector.Store(&selector{qcs: qcs})
   124  }
   125  
// start is the lifecycle start step. It opens the UDP control socket on s.addr
// and then the two named semaphores, whose names are derived from that address.
// Each failure is reported through handleError and returned. The deferred
// rollbacks undo the steps that already succeeded; they check the named result
// err, so they act only when start returns an error.
func (s *Server) start(lifecycle.Lifecycle) (err error) {
	// Listen on UDP.
	conn, err := net.ListenUDP("udp", s.addr)
	if err != nil {
		s.handleError(ErrTypeUdpListen, err, s)
		return err
	}
	s.conn = conn
	// Rollback: close the socket if a later step fails.
	defer func() {
		if err != nil {
			if err := s.conn.Close(); err != nil {
				s.handleError(ErrTypeUdpClose, err, s)
			}
			s.conn = nil
		}
	}()

	ads := addrToString(s.addr)

	// Open the semaphore for space that can be allocated ("alloc").
	// NOTE(review): O_CREATE|O_EXCL presumably maps to sem_open's
	// O_CREAT|O_EXCL, so this fails if a semaphore with this name is left over
	// from an earlier run — confirm.
	allocName := fmt.Sprintf("go-shm-server_%s_alloc", ads)
	allocSem, err := semPkg.Open(allocName, os.O_CREATE|os.O_EXCL, 0644, 0)
	if err != nil {
		s.handleError(ErrTypeSemOpen, err, s, allocName)
		return err
	}
	s.allocSem = allocSem
	// Rollback: unlink the alloc semaphore if a later step fails.
	// NOTE(review): only Unlink is called here; confirm whether the handle also
	// needs to be closed.
	defer func() {
		if err != nil {
			if err := s.allocSem.Unlink(); err != nil {
				s.handleError(ErrTypeSemUnlink, err, s, s.allocSem.Name())
			}
			s.allocSem = semPkg.NamedSem{}
		}
	}()

	// Open the semaphore for space that is ready to be used ("use").
	useName := fmt.Sprintf("go-shm-server_%s_use", ads)
	useSem, err := semPkg.Open(useName, os.O_CREATE|os.O_EXCL, 0644, 0)
	if err != nil {
		s.handleError(ErrTypeSemOpen, err, s, useName)
		return err
	}
	s.useSem = useSem
	return nil
}
   172  
   173  func (s *Server) read(buf []byte) (data []byte, addr *net.UDPAddr, close bool, err error) {
   174  	if err := s.conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
   175  		s.handleError(ErrTypeUdpSetReadDeadline, err, s)
   176  	}
   177  
   178  	n, addr, err := s.conn.ReadFromUDP(buf)
   179  	if err != nil {
   180  		if netErr, is := err.(net.Error); is && netErr.Timeout() {
   181  			if s.closed.Load() {
   182  				return nil, nil, true, err
   183  			}
   184  			return nil, nil, false, nil
   185  		} else if err == io.EOF {
   186  			return nil, nil, true, nil
   187  		} else if opErr, is := err.(*net.OpError); is && errors.Is(opErr.Err, net.ErrClosed) {
   188  			return nil, nil, true, nil
   189  		}
   190  		s.handleError(ErrTypeUdpRead, err, s)
   191  		return nil, nil, true, err
   192  	}
   193  
   194  	return buf[:n], addr, false, nil
   195  }
   196  
   197  func (s *Server) run(lifecycle.Lifecycle) error {
   198  	go func() {
   199  		if err := s.useSem.Wait(); err != nil {
   200  			s.handleError(ErrTypeSemWait, err, s, s.useSem.Name())
   201  			return
   202  		}
   203  
   204  		sel := s.selector.Load()
   205  		for _, qc := range sel.qcs {
   206  			q := qc.queue32
   207  			view, ok := q.WantTryUse(q.Cap())
   208  			if ok {
   209  
   210  			}
   211  		}
   212  	}()
   213  
   214  	var buf [maxPacketLength]byte
   215  	var basePacket BasePacket
   216  	var addrId AddrId
   217  
   218  	for {
   219  		data, addr, closed, err := s.read(buf[:])
   220  		if closed {
   221  			return err
   222  		}
   223  
   224  		if err, et := basePacket.Unmarshal(data); err != nil {
   225  			s.handleError(et, err, s)
   226  			continue
   227  		}
   228  		addrId.Unmarshal(addr.IP, addr.Port)
   229  		channel, _ := s.channels.Load(addrId)
   230  
   231  		switch basePacket.command {
   232  		case CommandOpenRequest:
   233  			if channel == nil {
   234  				channel = newChannel(addrId)
   235  			}
   236  			s.conn.
   237  				channel.handelOpen(OpenRequest{basePacket})
   238  		case CommandAllocRequest:
   239  			packet := AllocRequest{BasePacket: basePacket}
   240  			if err, et := packet.Unmarshal(data); err != nil {
   241  				s.handleError(et, err, s)
   242  				continue
   243  			}
   244  			channel.handleAlloc(packet)
   245  		case CommandKeepaliveRequest:
   246  			channel.handleKeepalive(KeepaliveRequest{basePacket})
   247  		case CommandCloseRequest:
   248  			channel.handleClose(CloseRequest{basePacket})
   249  		case CommandOpenResponse,
   250  			CommandAllocResponse,
   251  			CommandKeepaliveResponse,
   252  			CommandCloseResponse:
   253  			continue
   254  		}
   255  	}
   256  }
   257  
   258  func (s *Server) close(lifecycle.Lifecycle) error {
   259  
   260  }