go-micro.dev/v5@v5.12.0/transport/nats/nats.go (about) 1 // Package nats provides a NATS transport 2 package nats 3 4 import ( 5 "context" 6 "errors" 7 "io" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/nats-io/nats.go" 13 "go-micro.dev/v5/codec/json" 14 "go-micro.dev/v5/server" 15 "go-micro.dev/v5/transport" 16 ) 17 18 type ntport struct { 19 addrs []string 20 opts transport.Options 21 nopts nats.Options 22 } 23 24 type ntportClient struct { 25 conn *nats.Conn 26 addr string 27 id string 28 local string 29 remote string 30 sub *nats.Subscription 31 opts transport.Options 32 } 33 34 type ntportSocket struct { 35 conn *nats.Conn 36 m *nats.Msg 37 r chan *nats.Msg 38 39 close chan bool 40 41 sync.Mutex 42 bl []*nats.Msg 43 44 opts transport.Options 45 local string 46 remote string 47 } 48 49 type ntportListener struct { 50 conn *nats.Conn 51 addr string 52 exit chan bool 53 54 sync.RWMutex 55 so map[string]*ntportSocket 56 57 opts transport.Options 58 } 59 60 var ( 61 DefaultTimeout = time.Minute 62 ) 63 64 func configure(n *ntport, opts ...transport.Option) { 65 for _, o := range opts { 66 o(&n.opts) 67 } 68 69 natsOptions := nats.GetDefaultOptions() 70 if n, ok := n.opts.Context.Value(optionsKey{}).(nats.Options); ok { 71 natsOptions = n 72 } 73 74 // transport.Options have higher priority than nats.Options 75 // only if Addrs, Secure or TLSConfig were not set through a transport.Option 76 // we read them from nats.Option 77 if len(n.opts.Addrs) == 0 { 78 n.opts.Addrs = natsOptions.Servers 79 } 80 81 if !n.opts.Secure { 82 n.opts.Secure = natsOptions.Secure 83 } 84 85 if n.opts.TLSConfig == nil { 86 n.opts.TLSConfig = natsOptions.TLSConfig 87 } 88 89 // check & add nats:// prefix (this makes also sure that the addresses 90 // stored in natsRegistry.addrs and options.Addrs are identical) 91 n.opts.Addrs = setAddrs(n.opts.Addrs) 92 n.nopts = natsOptions 93 n.addrs = n.opts.Addrs 94 } 95 96 func setAddrs(addrs []string) []string { 97 cAddrs := make([]string, 0, len(addrs)) 98 for _, addr := range addrs { 99 if len(addr) == 0 { 100 continue 101 } 102 if !strings.HasPrefix(addr, "nats://") { 103 addr = "nats://" + addr 104 } 105 cAddrs = append(cAddrs, addr) 106 } 107 if len(cAddrs) == 0 { 108 cAddrs = []string{nats.DefaultURL} 109 } 110 return cAddrs 111 } 112 113 func (n *ntportClient) Local() string { 114 return n.local 115 } 116 117 func (n *ntportClient) Remote() string { 118 return n.remote 119 } 120 121 func (n *ntportClient) Send(m *transport.Message) error { 122 b, err := n.opts.Codec.Marshal(m) 123 if err != nil { 124 return err 125 } 126 127 // no deadline 128 if n.opts.Timeout == time.Duration(0) { 129 return n.conn.PublishRequest(n.addr, n.id, b) 130 } 131 132 // use the deadline 133 ch := make(chan error, 1) 134 135 go func() { 136 ch <- n.conn.PublishRequest(n.addr, n.id, b) 137 }() 138 139 select { 140 case err := <-ch: 141 return err 142 case <-time.After(n.opts.Timeout): 143 return errors.New("deadline exceeded") 144 } 145 } 146 147 func (n *ntportClient) Recv(m *transport.Message) error { 148 timeout := time.Second * 10 149 if n.opts.Timeout > time.Duration(0) { 150 timeout = n.opts.Timeout 151 } 152 153 rsp, err := n.sub.NextMsg(timeout) 154 if err != nil { 155 return err 156 } 157 158 var mr transport.Message 159 if err := n.opts.Codec.Unmarshal(rsp.Data, &mr); err != nil { 160 return err 161 } 162 163 *m = mr 164 return nil 165 } 166 167 func (n *ntportClient) Close() error { 168 n.sub.Unsubscribe() 169 n.conn.Close() 170 return nil 171 } 172 173 func (n *ntportSocket) Local() string { 174 return n.local 175 } 176 177 func (n *ntportSocket) Remote() string { 178 return n.remote 179 } 180 181 func (n *ntportSocket) Recv(m *transport.Message) error { 182 if m == nil { 183 return errors.New("message passed in is nil") 184 } 185 186 var r *nats.Msg 187 var ok bool 188 189 // if there's a deadline we use it 190 if n.opts.Timeout > time.Duration(0) { 191 select { 192 case r, ok = <-n.r: 193 case <-time.After(n.opts.Timeout): 194 return errors.New("deadline exceeded") 195 } 196 } else { 197 r, ok = <-n.r 198 } 199 200 if !ok { 201 return io.EOF 202 } 203 204 n.Lock() 205 if len(n.bl) > 0 { 206 select { 207 case n.r <- n.bl[0]: 208 n.bl = n.bl[1:] 209 default: 210 } 211 } 212 n.Unlock() 213 214 if err := n.opts.Codec.Unmarshal(r.Data, m); err != nil { 215 return err 216 } 217 return nil 218 } 219 220 func (n *ntportSocket) Send(m *transport.Message) error { 221 b, err := n.opts.Codec.Marshal(m) 222 if err != nil { 223 return err 224 } 225 226 // no deadline 227 if n.opts.Timeout == time.Duration(0) { 228 return n.conn.Publish(n.m.Reply, b) 229 } 230 231 // use the deadline 232 ch := make(chan error, 1) 233 234 go func() { 235 ch <- n.conn.Publish(n.m.Reply, b) 236 }() 237 238 select { 239 case err := <-ch: 240 return err 241 case <-time.After(n.opts.Timeout): 242 return errors.New("deadline exceeded") 243 } 244 } 245 246 func (n *ntportSocket) Close() error { 247 select { 248 case <-n.close: 249 return nil 250 default: 251 close(n.close) 252 } 253 return nil 254 } 255 256 func (n *ntportListener) Addr() string { 257 return n.addr 258 } 259 260 func (n *ntportListener) Close() error { 261 n.exit <- true 262 n.conn.Close() 263 return nil 264 } 265 266 func (n *ntportListener) Accept(fn func(transport.Socket)) error { 267 s, err := n.conn.SubscribeSync(n.addr) 268 if err != nil { 269 return err 270 } 271 272 go func() { 273 <-n.exit 274 s.Unsubscribe() 275 }() 276 277 for { 278 m, err := s.NextMsg(time.Minute) 279 if err != nil && err == nats.ErrTimeout { 280 continue 281 } else if err != nil { 282 return err 283 } 284 285 n.RLock() 286 sock, ok := n.so[m.Reply] 287 n.RUnlock() 288 289 if !ok { 290 sock = &ntportSocket{ 291 conn: n.conn, 292 m: m, 293 r: make(chan *nats.Msg, 1), 294 close: make(chan bool), 295 opts: n.opts, 296 local: n.Addr(), 297 remote: m.Reply, 298 } 299 n.Lock() 300 n.so[m.Reply] = sock 301 n.Unlock() 302 303 go func() { 304 // TODO: think of a better error response strategy 305 defer func() { 306 if r := recover(); r != nil { 307 sock.Close() 308 } 309 }() 310 fn(sock) 311 }() 312 313 go func() { 314 <-sock.close 315 n.Lock() 316 delete(n.so, sock.m.Reply) 317 n.Unlock() 318 }() 319 } 320 321 select { 322 case <-sock.close: 323 continue 324 default: 325 } 326 327 sock.Lock() 328 sock.bl = append(sock.bl, m) 329 select { 330 case sock.r <- sock.bl[0]: 331 sock.bl = sock.bl[1:] 332 default: 333 } 334 sock.Unlock() 335 } 336 } 337 338 func (n *ntport) Dial(addr string, dialOpts ...transport.DialOption) (transport.Client, error) { 339 dopts := transport.DialOptions{ 340 Timeout: transport.DefaultDialTimeout, 341 } 342 343 for _, o := range dialOpts { 344 o(&dopts) 345 } 346 347 opts := n.nopts 348 opts.Servers = n.addrs 349 opts.Secure = n.opts.Secure 350 opts.TLSConfig = n.opts.TLSConfig 351 opts.Timeout = dopts.Timeout 352 353 // secure might not be set 354 if n.opts.TLSConfig != nil { 355 opts.Secure = true 356 } 357 358 c, err := opts.Connect() 359 if err != nil { 360 return nil, err 361 } 362 363 id := nats.NewInbox() 364 sub, err := c.SubscribeSync(id) 365 if err != nil { 366 return nil, err 367 } 368 369 return &ntportClient{ 370 conn: c, 371 addr: addr, 372 id: id, 373 sub: sub, 374 opts: n.opts, 375 local: id, 376 remote: addr, 377 }, nil 378 } 379 380 func (n *ntport) Listen(addr string, listenOpts ...transport.ListenOption) (transport.Listener, error) { 381 opts := n.nopts 382 opts.Servers = n.addrs 383 opts.Secure = n.opts.Secure 384 opts.TLSConfig = n.opts.TLSConfig 385 386 // secure might not be set 387 if n.opts.TLSConfig != nil { 388 opts.Secure = true 389 } 390 391 c, err := opts.Connect() 392 if err != nil { 393 return nil, err 394 } 395 396 // in case address has not been specifically set, create a new nats.Inbox() 397 if addr == server.DefaultAddress { 398 addr = nats.NewInbox() 399 } 400 401 // make sure addr subject is not empty 402 if len(addr) == 0 { 403 return nil, errors.New("addr (nats subject) must not be empty") 404 } 405 406 // since NATS implements a text based protocol, no space characters are 407 // admitted in the addr (subject name) 408 if strings.Contains(addr, " ") { 409 return nil, errors.New("addr (nats subject) must not contain space characters") 410 } 411 412 return &ntportListener{ 413 addr: addr, 414 conn: c, 415 exit: make(chan bool, 1), 416 so: make(map[string]*ntportSocket), 417 opts: n.opts, 418 }, nil 419 } 420 421 func (n *ntport) Init(opts ...transport.Option) error { 422 configure(n, opts...) 423 return nil 424 } 425 426 func (n *ntport) Options() transport.Options { 427 return n.opts 428 } 429 430 func (n *ntport) String() string { 431 return "nats" 432 } 433 434 func NewTransport(opts ...transport.Option) transport.Transport { 435 options := transport.Options{ 436 // Default codec 437 Codec: json.Marshaler{}, 438 Timeout: DefaultTimeout, 439 Context: context.Background(), 440 } 441 442 nt := &ntport{ 443 opts: options, 444 } 445 configure(nt, opts...) 446 return nt 447 }