go-micro.dev/v5@v5.12.0/transport/nats/nats.go (about)

     1  // Package nats provides a NATS transport
     2  package nats
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/nats-io/nats.go"
    13  	"go-micro.dev/v5/codec/json"
    14  	"go-micro.dev/v5/server"
    15  	"go-micro.dev/v5/transport"
    16  )
    17  
    18  type ntport struct {
    19  	addrs []string
    20  	opts  transport.Options
    21  	nopts nats.Options
    22  }
    23  
    24  type ntportClient struct {
    25  	conn   *nats.Conn
    26  	addr   string
    27  	id     string
    28  	local  string
    29  	remote string
    30  	sub    *nats.Subscription
    31  	opts   transport.Options
    32  }
    33  
    34  type ntportSocket struct {
    35  	conn *nats.Conn
    36  	m    *nats.Msg
    37  	r    chan *nats.Msg
    38  
    39  	close chan bool
    40  
    41  	sync.Mutex
    42  	bl []*nats.Msg
    43  
    44  	opts   transport.Options
    45  	local  string
    46  	remote string
    47  }
    48  
    49  type ntportListener struct {
    50  	conn *nats.Conn
    51  	addr string
    52  	exit chan bool
    53  
    54  	sync.RWMutex
    55  	so map[string]*ntportSocket
    56  
    57  	opts transport.Options
    58  }
    59  
    60  var (
    61  	DefaultTimeout = time.Minute
    62  )
    63  
    64  func configure(n *ntport, opts ...transport.Option) {
    65  	for _, o := range opts {
    66  		o(&n.opts)
    67  	}
    68  
    69  	natsOptions := nats.GetDefaultOptions()
    70  	if n, ok := n.opts.Context.Value(optionsKey{}).(nats.Options); ok {
    71  		natsOptions = n
    72  	}
    73  
    74  	// transport.Options have higher priority than nats.Options
    75  	// only if Addrs, Secure or TLSConfig were not set through a transport.Option
    76  	// we read them from nats.Option
    77  	if len(n.opts.Addrs) == 0 {
    78  		n.opts.Addrs = natsOptions.Servers
    79  	}
    80  
    81  	if !n.opts.Secure {
    82  		n.opts.Secure = natsOptions.Secure
    83  	}
    84  
    85  	if n.opts.TLSConfig == nil {
    86  		n.opts.TLSConfig = natsOptions.TLSConfig
    87  	}
    88  
    89  	// check & add nats:// prefix (this makes also sure that the addresses
    90  	// stored in natsRegistry.addrs and options.Addrs are identical)
    91  	n.opts.Addrs = setAddrs(n.opts.Addrs)
    92  	n.nopts = natsOptions
    93  	n.addrs = n.opts.Addrs
    94  }
    95  
    96  func setAddrs(addrs []string) []string {
    97  	cAddrs := make([]string, 0, len(addrs))
    98  	for _, addr := range addrs {
    99  		if len(addr) == 0 {
   100  			continue
   101  		}
   102  		if !strings.HasPrefix(addr, "nats://") {
   103  			addr = "nats://" + addr
   104  		}
   105  		cAddrs = append(cAddrs, addr)
   106  	}
   107  	if len(cAddrs) == 0 {
   108  		cAddrs = []string{nats.DefaultURL}
   109  	}
   110  	return cAddrs
   111  }
   112  
   113  func (n *ntportClient) Local() string {
   114  	return n.local
   115  }
   116  
   117  func (n *ntportClient) Remote() string {
   118  	return n.remote
   119  }
   120  
   121  func (n *ntportClient) Send(m *transport.Message) error {
   122  	b, err := n.opts.Codec.Marshal(m)
   123  	if err != nil {
   124  		return err
   125  	}
   126  
   127  	// no deadline
   128  	if n.opts.Timeout == time.Duration(0) {
   129  		return n.conn.PublishRequest(n.addr, n.id, b)
   130  	}
   131  
   132  	// use the deadline
   133  	ch := make(chan error, 1)
   134  
   135  	go func() {
   136  		ch <- n.conn.PublishRequest(n.addr, n.id, b)
   137  	}()
   138  
   139  	select {
   140  	case err := <-ch:
   141  		return err
   142  	case <-time.After(n.opts.Timeout):
   143  		return errors.New("deadline exceeded")
   144  	}
   145  }
   146  
   147  func (n *ntportClient) Recv(m *transport.Message) error {
   148  	timeout := time.Second * 10
   149  	if n.opts.Timeout > time.Duration(0) {
   150  		timeout = n.opts.Timeout
   151  	}
   152  
   153  	rsp, err := n.sub.NextMsg(timeout)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	var mr transport.Message
   159  	if err := n.opts.Codec.Unmarshal(rsp.Data, &mr); err != nil {
   160  		return err
   161  	}
   162  
   163  	*m = mr
   164  	return nil
   165  }
   166  
   167  func (n *ntportClient) Close() error {
   168  	n.sub.Unsubscribe()
   169  	n.conn.Close()
   170  	return nil
   171  }
   172  
   173  func (n *ntportSocket) Local() string {
   174  	return n.local
   175  }
   176  
   177  func (n *ntportSocket) Remote() string {
   178  	return n.remote
   179  }
   180  
   181  func (n *ntportSocket) Recv(m *transport.Message) error {
   182  	if m == nil {
   183  		return errors.New("message passed in is nil")
   184  	}
   185  
   186  	var r *nats.Msg
   187  	var ok bool
   188  
   189  	// if there's a deadline we use it
   190  	if n.opts.Timeout > time.Duration(0) {
   191  		select {
   192  		case r, ok = <-n.r:
   193  		case <-time.After(n.opts.Timeout):
   194  			return errors.New("deadline exceeded")
   195  		}
   196  	} else {
   197  		r, ok = <-n.r
   198  	}
   199  
   200  	if !ok {
   201  		return io.EOF
   202  	}
   203  
   204  	n.Lock()
   205  	if len(n.bl) > 0 {
   206  		select {
   207  		case n.r <- n.bl[0]:
   208  			n.bl = n.bl[1:]
   209  		default:
   210  		}
   211  	}
   212  	n.Unlock()
   213  
   214  	if err := n.opts.Codec.Unmarshal(r.Data, m); err != nil {
   215  		return err
   216  	}
   217  	return nil
   218  }
   219  
   220  func (n *ntportSocket) Send(m *transport.Message) error {
   221  	b, err := n.opts.Codec.Marshal(m)
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	// no deadline
   227  	if n.opts.Timeout == time.Duration(0) {
   228  		return n.conn.Publish(n.m.Reply, b)
   229  	}
   230  
   231  	// use the deadline
   232  	ch := make(chan error, 1)
   233  
   234  	go func() {
   235  		ch <- n.conn.Publish(n.m.Reply, b)
   236  	}()
   237  
   238  	select {
   239  	case err := <-ch:
   240  		return err
   241  	case <-time.After(n.opts.Timeout):
   242  		return errors.New("deadline exceeded")
   243  	}
   244  }
   245  
   246  func (n *ntportSocket) Close() error {
   247  	select {
   248  	case <-n.close:
   249  		return nil
   250  	default:
   251  		close(n.close)
   252  	}
   253  	return nil
   254  }
   255  
   256  func (n *ntportListener) Addr() string {
   257  	return n.addr
   258  }
   259  
   260  func (n *ntportListener) Close() error {
   261  	n.exit <- true
   262  	n.conn.Close()
   263  	return nil
   264  }
   265  
   266  func (n *ntportListener) Accept(fn func(transport.Socket)) error {
   267  	s, err := n.conn.SubscribeSync(n.addr)
   268  	if err != nil {
   269  		return err
   270  	}
   271  
   272  	go func() {
   273  		<-n.exit
   274  		s.Unsubscribe()
   275  	}()
   276  
   277  	for {
   278  		m, err := s.NextMsg(time.Minute)
   279  		if err != nil && err == nats.ErrTimeout {
   280  			continue
   281  		} else if err != nil {
   282  			return err
   283  		}
   284  
   285  		n.RLock()
   286  		sock, ok := n.so[m.Reply]
   287  		n.RUnlock()
   288  
   289  		if !ok {
   290  			sock = &ntportSocket{
   291  				conn:   n.conn,
   292  				m:      m,
   293  				r:      make(chan *nats.Msg, 1),
   294  				close:  make(chan bool),
   295  				opts:   n.opts,
   296  				local:  n.Addr(),
   297  				remote: m.Reply,
   298  			}
   299  			n.Lock()
   300  			n.so[m.Reply] = sock
   301  			n.Unlock()
   302  
   303  			go func() {
   304  				// TODO: think of a better error response strategy
   305  				defer func() {
   306  					if r := recover(); r != nil {
   307  						sock.Close()
   308  					}
   309  				}()
   310  				fn(sock)
   311  			}()
   312  
   313  			go func() {
   314  				<-sock.close
   315  				n.Lock()
   316  				delete(n.so, sock.m.Reply)
   317  				n.Unlock()
   318  			}()
   319  		}
   320  
   321  		select {
   322  		case <-sock.close:
   323  			continue
   324  		default:
   325  		}
   326  
   327  		sock.Lock()
   328  		sock.bl = append(sock.bl, m)
   329  		select {
   330  		case sock.r <- sock.bl[0]:
   331  			sock.bl = sock.bl[1:]
   332  		default:
   333  		}
   334  		sock.Unlock()
   335  	}
   336  }
   337  
   338  func (n *ntport) Dial(addr string, dialOpts ...transport.DialOption) (transport.Client, error) {
   339  	dopts := transport.DialOptions{
   340  		Timeout: transport.DefaultDialTimeout,
   341  	}
   342  
   343  	for _, o := range dialOpts {
   344  		o(&dopts)
   345  	}
   346  
   347  	opts := n.nopts
   348  	opts.Servers = n.addrs
   349  	opts.Secure = n.opts.Secure
   350  	opts.TLSConfig = n.opts.TLSConfig
   351  	opts.Timeout = dopts.Timeout
   352  
   353  	// secure might not be set
   354  	if n.opts.TLSConfig != nil {
   355  		opts.Secure = true
   356  	}
   357  
   358  	c, err := opts.Connect()
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  
   363  	id := nats.NewInbox()
   364  	sub, err := c.SubscribeSync(id)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	return &ntportClient{
   370  		conn:   c,
   371  		addr:   addr,
   372  		id:     id,
   373  		sub:    sub,
   374  		opts:   n.opts,
   375  		local:  id,
   376  		remote: addr,
   377  	}, nil
   378  }
   379  
   380  func (n *ntport) Listen(addr string, listenOpts ...transport.ListenOption) (transport.Listener, error) {
   381  	opts := n.nopts
   382  	opts.Servers = n.addrs
   383  	opts.Secure = n.opts.Secure
   384  	opts.TLSConfig = n.opts.TLSConfig
   385  
   386  	// secure might not be set
   387  	if n.opts.TLSConfig != nil {
   388  		opts.Secure = true
   389  	}
   390  
   391  	c, err := opts.Connect()
   392  	if err != nil {
   393  		return nil, err
   394  	}
   395  
   396  	// in case address has not been specifically set, create a new nats.Inbox()
   397  	if addr == server.DefaultAddress {
   398  		addr = nats.NewInbox()
   399  	}
   400  
   401  	// make sure addr subject is not empty
   402  	if len(addr) == 0 {
   403  		return nil, errors.New("addr (nats subject) must not be empty")
   404  	}
   405  
   406  	// since NATS implements a text based protocol, no space characters are
   407  	// admitted in the addr (subject name)
   408  	if strings.Contains(addr, " ") {
   409  		return nil, errors.New("addr (nats subject) must not contain space characters")
   410  	}
   411  
   412  	return &ntportListener{
   413  		addr: addr,
   414  		conn: c,
   415  		exit: make(chan bool, 1),
   416  		so:   make(map[string]*ntportSocket),
   417  		opts: n.opts,
   418  	}, nil
   419  }
   420  
   421  func (n *ntport) Init(opts ...transport.Option) error {
   422  	configure(n, opts...)
   423  	return nil
   424  }
   425  
   426  func (n *ntport) Options() transport.Options {
   427  	return n.opts
   428  }
   429  
   430  func (n *ntport) String() string {
   431  	return "nats"
   432  }
   433  
   434  func NewTransport(opts ...transport.Option) transport.Transport {
   435  	options := transport.Options{
   436  		// Default codec
   437  		Codec:   json.Marshaler{},
   438  		Timeout: DefaultTimeout,
   439  		Context: context.Background(),
   440  	}
   441  
   442  	nt := &ntport{
   443  		opts: options,
   444  	}
   445  	configure(nt, opts...)
   446  	return nt
   447  }