go-micro.dev/v5@v5.12.0/transport/memory.go (about)

     1  package transport
     2  
     3  import (
     4  	"context"
     5  	"encoding/gob"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"sync"
    12  	"time"
    13  
    14  	maddr "go-micro.dev/v5/util/addr"
    15  	mnet "go-micro.dev/v5/util/net"
    16  )
    17  
    18  type memorySocket struct {
    19  	ctx context.Context
    20  	// Client receiver of io.Pipe with gob
    21  	crecv *gob.Decoder
    22  	// Client sender of the io.Pipe with gob
    23  	csend *gob.Encoder
    24  	// Server receiver of the io.Pip with gob
    25  	srecv *gob.Decoder
    26  	// Server sender of the io.Pip with gob
    27  	ssend *gob.Encoder
    28  	// sock exit
    29  	exit chan bool
    30  	// listener exit
    31  	lexit chan bool
    32  
    33  	local  string
    34  	remote string
    35  
    36  	// for send/recv Timeout
    37  	timeout time.Duration
    38  	// True server mode, False client mode
    39  	server bool
    40  }
    41  
    42  type memoryClient struct {
    43  	*memorySocket
    44  	opts DialOptions
    45  }
    46  
    47  type memoryListener struct {
    48  	lopts ListenOptions
    49  	ctx   context.Context
    50  	exit  chan bool
    51  	conn  chan *memorySocket
    52  	topts Options
    53  	addr  string
    54  	sync.RWMutex
    55  }
    56  
    57  type memoryTransport struct {
    58  	listeners map[string]*memoryListener
    59  	opts      Options
    60  	sync.RWMutex
    61  }
    62  
    63  func (ms *memorySocket) Recv(m *Message) error {
    64  	ctx := ms.ctx
    65  	if ms.timeout > 0 {
    66  		var cancel context.CancelFunc
    67  		ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
    68  		defer cancel()
    69  	}
    70  
    71  	select {
    72  	case <-ctx.Done():
    73  		return ctx.Err()
    74  	case <-ms.exit:
    75  		// connection closed
    76  		return io.EOF
    77  	case <-ms.lexit:
    78  		// Server connection closed
    79  		return io.EOF
    80  	default:
    81  		if ms.server {
    82  			if err := ms.srecv.Decode(m); err != nil {
    83  				return err
    84  			}
    85  		} else {
    86  			if err := ms.crecv.Decode(m); err != nil {
    87  				return err
    88  			}
    89  		}
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  func (ms *memorySocket) Local() string {
    96  	return ms.local
    97  }
    98  
    99  func (ms *memorySocket) Remote() string {
   100  	return ms.remote
   101  }
   102  
   103  func (ms *memorySocket) Send(m *Message) error {
   104  	ctx := ms.ctx
   105  	if ms.timeout > 0 {
   106  		var cancel context.CancelFunc
   107  		ctx, cancel = context.WithTimeout(ms.ctx, ms.timeout)
   108  		defer cancel()
   109  	}
   110  
   111  	select {
   112  	case <-ctx.Done():
   113  		return ctx.Err()
   114  	case <-ms.exit:
   115  		// connection closed
   116  		return io.EOF
   117  	case <-ms.lexit:
   118  		// Server connection closed
   119  		return io.EOF
   120  	default:
   121  		if ms.server {
   122  			if err := ms.ssend.Encode(m); err != nil {
   123  				return err
   124  			}
   125  		} else {
   126  			if err := ms.csend.Encode(m); err != nil {
   127  				return err
   128  			}
   129  		}
   130  	}
   131  
   132  	return nil
   133  }
   134  
   135  func (ms *memorySocket) Close() error {
   136  	select {
   137  	case <-ms.exit:
   138  		return nil
   139  	default:
   140  		close(ms.exit)
   141  	}
   142  	return nil
   143  }
   144  
   145  func (m *memoryListener) Addr() string {
   146  	return m.addr
   147  }
   148  
   149  func (m *memoryListener) Close() error {
   150  	m.Lock()
   151  	defer m.Unlock()
   152  	select {
   153  	case <-m.exit:
   154  		return nil
   155  	default:
   156  		close(m.exit)
   157  	}
   158  	return nil
   159  }
   160  
   161  func (m *memoryListener) Accept(fn func(Socket)) error {
   162  	for {
   163  		select {
   164  		case <-m.exit:
   165  			return nil
   166  		case c := <-m.conn:
   167  			go fn(&memorySocket{
   168  				server:  true,
   169  				lexit:   c.lexit,
   170  				exit:    c.exit,
   171  				ssend:   c.ssend,
   172  				srecv:   c.srecv,
   173  				local:   c.Remote(),
   174  				remote:  c.Local(),
   175  				timeout: m.topts.Timeout,
   176  				ctx:     m.topts.Context,
   177  			})
   178  		}
   179  	}
   180  }
   181  
   182  func (m *memoryTransport) Dial(addr string, opts ...DialOption) (Client, error) {
   183  	m.RLock()
   184  	defer m.RUnlock()
   185  
   186  	listener, ok := m.listeners[addr]
   187  	if !ok {
   188  		return nil, errors.New("could not dial " + addr)
   189  	}
   190  
   191  	var options DialOptions
   192  	for _, o := range opts {
   193  		o(&options)
   194  	}
   195  
   196  	creader, swriter := io.Pipe()
   197  	sreader, cwriter := io.Pipe()
   198  
   199  	client := &memoryClient{
   200  		&memorySocket{
   201  			server: false,
   202  			csend:  gob.NewEncoder(cwriter),
   203  			crecv:  gob.NewDecoder(creader),
   204  			ssend:  gob.NewEncoder(swriter),
   205  			srecv:  gob.NewDecoder(sreader), exit: make(chan bool),
   206  			lexit:   listener.exit,
   207  			local:   addr,
   208  			remote:  addr,
   209  			timeout: m.opts.Timeout,
   210  			ctx:     m.opts.Context,
   211  		},
   212  		options,
   213  	}
   214  
   215  	// pseudo connect
   216  	select {
   217  	case <-listener.exit:
   218  		return nil, errors.New("connection error")
   219  	case listener.conn <- client.memorySocket:
   220  	}
   221  
   222  	return client, nil
   223  }
   224  
   225  func (m *memoryTransport) Listen(addr string, opts ...ListenOption) (Listener, error) {
   226  	m.Lock()
   227  	defer m.Unlock()
   228  
   229  	var options ListenOptions
   230  	for _, o := range opts {
   231  		o(&options)
   232  	}
   233  
   234  	host, port, err := net.SplitHostPort(addr)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  
   239  	addr, err = maddr.Extract(host)
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  
   244  	// if zero port then randomly assign one
   245  	if len(port) > 0 && port == "0" {
   246  		i := rand.Intn(20000)
   247  		port = fmt.Sprintf("%d", 10000+i)
   248  	}
   249  
   250  	// set addr with port
   251  	addr = mnet.HostPort(addr, port)
   252  
   253  	if _, ok := m.listeners[addr]; ok {
   254  		return nil, errors.New("already listening on " + addr)
   255  	}
   256  
   257  	listener := &memoryListener{
   258  		lopts: options,
   259  		topts: m.opts,
   260  		addr:  addr,
   261  		conn:  make(chan *memorySocket),
   262  		exit:  make(chan bool),
   263  		ctx:   m.opts.Context,
   264  	}
   265  
   266  	m.listeners[addr] = listener
   267  
   268  	return listener, nil
   269  }
   270  
   271  func (m *memoryTransport) Init(opts ...Option) error {
   272  	for _, o := range opts {
   273  		o(&m.opts)
   274  	}
   275  	return nil
   276  }
   277  
   278  func (m *memoryTransport) Options() Options {
   279  	return m.opts
   280  }
   281  
   282  func (m *memoryTransport) String() string {
   283  	return "memory"
   284  }
   285  
   286  func NewMemoryTransport(opts ...Option) Transport {
   287  	var options Options
   288  
   289  	rand.Seed(time.Now().UnixNano())
   290  
   291  	for _, o := range opts {
   292  		o(&options)
   293  	}
   294  
   295  	if options.Context == nil {
   296  		options.Context = context.Background()
   297  	}
   298  
   299  	return &memoryTransport{
   300  		opts:      options,
   301  		listeners: make(map[string]*memoryListener),
   302  	}
   303  }