github.com/annwntech/go-micro/v2@v2.9.5/transport/memory/memory.go (about)

// Package memory provides an in-memory transport whose connections are Go channels.
     2  package memory
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"math/rand"
     9  	"net"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/annwntech/go-micro/v2/transport"
    14  	maddr "github.com/annwntech/go-micro/v2/util/addr"
    15  	mnet "github.com/annwntech/go-micro/v2/util/net"
    16  )
    17  
    18  type memorySocket struct {
    19  	recv chan *transport.Message
    20  	send chan *transport.Message
    21  	// sock exit
    22  	exit chan bool
    23  	// listener exit
    24  	lexit chan bool
    25  
    26  	local  string
    27  	remote string
    28  
    29  	// for send/recv transport.Timeout
    30  	timeout time.Duration
    31  	ctx     context.Context
    32  	sync.RWMutex
    33  }
    34  
    35  type memoryClient struct {
    36  	*memorySocket
    37  	opts transport.DialOptions
    38  }
    39  
    40  type memoryListener struct {
    41  	addr  string
    42  	exit  chan bool
    43  	conn  chan *memorySocket
    44  	lopts transport.ListenOptions
    45  	topts transport.Options
    46  	sync.RWMutex
    47  	ctx context.Context
    48  }
    49  
    50  type memoryTransport struct {
    51  	opts transport.Options
    52  	sync.RWMutex
    53  	listeners map[string]*memoryListener
    54  }
    55  
    56  func (ms *memorySocket) Recv(m *transport.Message) error {
    57  	ms.RLock()
    58  	defer ms.RUnlock()
    59  
    60  	ctx := ms.ctx
    61  	if ms.timeout > 0 {
    62  		var cancel context.CancelFunc
    63  		ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
    64  		defer cancel()
    65  	}
    66  
    67  	select {
    68  	case <-ctx.Done():
    69  		return ctx.Err()
    70  	case <-ms.exit:
    71  		return errors.New("connection closed")
    72  	case <-ms.lexit:
    73  		return errors.New("server connection closed")
    74  	case cm := <-ms.recv:
    75  		*m = *cm
    76  	}
    77  	return nil
    78  }
    79  
    80  func (ms *memorySocket) Local() string {
    81  	return ms.local
    82  }
    83  
    84  func (ms *memorySocket) Remote() string {
    85  	return ms.remote
    86  }
    87  
    88  func (ms *memorySocket) Send(m *transport.Message) error {
    89  	ms.RLock()
    90  	defer ms.RUnlock()
    91  
    92  	ctx := ms.ctx
    93  	if ms.timeout > 0 {
    94  		var cancel context.CancelFunc
    95  		ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
    96  		defer cancel()
    97  	}
    98  
    99  	select {
   100  	case <-ctx.Done():
   101  		return ctx.Err()
   102  	case <-ms.exit:
   103  		return errors.New("connection closed")
   104  	case <-ms.lexit:
   105  		return errors.New("server connection closed")
   106  	case ms.send <- m:
   107  	}
   108  	return nil
   109  }
   110  
   111  func (ms *memorySocket) Close() error {
   112  	ms.Lock()
   113  	defer ms.Unlock()
   114  	select {
   115  	case <-ms.exit:
   116  		return nil
   117  	default:
   118  		close(ms.exit)
   119  	}
   120  	return nil
   121  }
   122  
   123  func (m *memoryListener) Addr() string {
   124  	return m.addr
   125  }
   126  
   127  func (m *memoryListener) Close() error {
   128  	m.Lock()
   129  	defer m.Unlock()
   130  	select {
   131  	case <-m.exit:
   132  		return nil
   133  	default:
   134  		close(m.exit)
   135  	}
   136  	return nil
   137  }
   138  
   139  func (m *memoryListener) Accept(fn func(transport.Socket)) error {
   140  	for {
   141  		select {
   142  		case <-m.exit:
   143  			return nil
   144  		case c := <-m.conn:
   145  			go fn(&memorySocket{
   146  				lexit:   c.lexit,
   147  				exit:    c.exit,
   148  				send:    c.recv,
   149  				recv:    c.send,
   150  				local:   c.Remote(),
   151  				remote:  c.Local(),
   152  				timeout: m.topts.Timeout,
   153  				ctx:     m.topts.Context,
   154  			})
   155  		}
   156  	}
   157  }
   158  
   159  func (m *memoryTransport) Dial(addr string, opts ...transport.DialOption) (transport.Client, error) {
   160  	m.RLock()
   161  	defer m.RUnlock()
   162  
   163  	listener, ok := m.listeners[addr]
   164  	if !ok {
   165  		return nil, errors.New("could not dial " + addr)
   166  	}
   167  
   168  	var options transport.DialOptions
   169  	for _, o := range opts {
   170  		o(&options)
   171  	}
   172  
   173  	client := &memoryClient{
   174  		&memorySocket{
   175  			send:    make(chan *transport.Message),
   176  			recv:    make(chan *transport.Message),
   177  			exit:    make(chan bool),
   178  			lexit:   listener.exit,
   179  			local:   addr,
   180  			remote:  addr,
   181  			timeout: m.opts.Timeout,
   182  			ctx:     m.opts.Context,
   183  		},
   184  		options,
   185  	}
   186  
   187  	// pseudo connect
   188  	select {
   189  	case <-listener.exit:
   190  		return nil, errors.New("connection error")
   191  	case listener.conn <- client.memorySocket:
   192  	}
   193  
   194  	return client, nil
   195  }
   196  
   197  func (m *memoryTransport) Listen(addr string, opts ...transport.ListenOption) (transport.Listener, error) {
   198  	m.Lock()
   199  	defer m.Unlock()
   200  
   201  	var options transport.ListenOptions
   202  	for _, o := range opts {
   203  		o(&options)
   204  	}
   205  
   206  	host, port, err := net.SplitHostPort(addr)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	addr, err = maddr.Extract(host)
   212  	if err != nil {
   213  		return nil, err
   214  	}
   215  
   216  	// if zero port then randomly assign one
   217  	if len(port) > 0 && port == "0" {
   218  		i := rand.Intn(20000)
   219  		port = fmt.Sprintf("%d", 10000+i)
   220  	}
   221  
   222  	// set addr with port
   223  	addr = mnet.HostPort(addr, port)
   224  
   225  	if _, ok := m.listeners[addr]; ok {
   226  		return nil, errors.New("already listening on " + addr)
   227  	}
   228  
   229  	listener := &memoryListener{
   230  		lopts: options,
   231  		topts: m.opts,
   232  		addr:  addr,
   233  		conn:  make(chan *memorySocket),
   234  		exit:  make(chan bool),
   235  		ctx:   m.opts.Context,
   236  	}
   237  
   238  	m.listeners[addr] = listener
   239  
   240  	return listener, nil
   241  }
   242  
   243  func (m *memoryTransport) Init(opts ...transport.Option) error {
   244  	for _, o := range opts {
   245  		o(&m.opts)
   246  	}
   247  	return nil
   248  }
   249  
   250  func (m *memoryTransport) Options() transport.Options {
   251  	return m.opts
   252  }
   253  
   254  func (m *memoryTransport) String() string {
   255  	return "memory"
   256  }
   257  
   258  func NewTransport(opts ...transport.Option) transport.Transport {
   259  	var options transport.Options
   260  
   261  	rand.Seed(time.Now().UnixNano())
   262  
   263  	for _, o := range opts {
   264  		o(&options)
   265  	}
   266  
   267  	if options.Context == nil {
   268  		options.Context = context.Background()
   269  	}
   270  
   271  	return &memoryTransport{
   272  		opts:      options,
   273  		listeners: make(map[string]*memoryListener),
   274  	}
   275  }