github.com/annwntech/go-micro/v2@v2.9.5/tunnel/session.go (about) 1 package tunnel 2 3 import ( 4 "crypto/cipher" 5 "encoding/base32" 6 "io" 7 "sync" 8 "time" 9 10 "github.com/annwntech/go-micro/v2/logger" 11 "github.com/annwntech/go-micro/v2/transport" 12 ) 13 14 // session is our pseudo session for transport.Socket 15 type session struct { 16 // the tunnel id 17 tunnel string 18 // the channel name 19 channel string 20 // the session id based on Micro.Tunnel-Session 21 session string 22 // token is the session token 23 token string 24 // closed 25 closed chan bool 26 // remote addr 27 remote string 28 // local addr 29 local string 30 // send chan 31 send chan *message 32 // recv chan 33 recv chan *message 34 // if the discovery worked 35 discovered bool 36 // if the session was accepted 37 accepted bool 38 // outbound marks the session as outbound dialled connection 39 outbound bool 40 // lookback marks the session as a loopback on the inbound 41 loopback bool 42 // mode of the connection 43 mode Mode 44 // the dial timeout 45 dialTimeout time.Duration 46 // the read timeout 47 readTimeout time.Duration 48 // the link on which this message was received 49 link string 50 // the error response 51 errChan chan error 52 // key for session encryption 53 key []byte 54 // cipher for session 55 gcm cipher.AEAD 56 sync.RWMutex 57 } 58 59 // message is sent over the send channel 60 type message struct { 61 // type of message 62 typ string 63 // tunnel id 64 tunnel string 65 // channel name 66 channel string 67 // the session id 68 session string 69 // outbound marks the message as outbound 70 outbound bool 71 // loopback marks the message intended for loopback 72 loopback bool 73 // mode of the connection 74 mode Mode 75 // the link to send the message on 76 link string 77 // transport data 78 data *transport.Message 79 // the error channel 80 errChan chan error 81 } 82 83 func (s *session) Remote() string { 84 return s.remote 85 } 86 87 func (s *session) Local() string { 88 return s.local 89 } 90 91 func (s *session) Link() string { 92 return s.link 93 } 94 95 func (s *session) Id() string { 96 return s.session 97 } 98 99 func (s *session) Channel() string { 100 return s.channel 101 } 102 103 // newMessage creates a new message based on the session 104 func (s *session) newMessage(typ string) *message { 105 return &message{ 106 typ: typ, 107 tunnel: s.tunnel, 108 channel: s.channel, 109 session: s.session, 110 outbound: s.outbound, 111 loopback: s.loopback, 112 mode: s.mode, 113 link: s.link, 114 errChan: s.errChan, 115 } 116 } 117 118 func (s *session) sendMsg(msg *message) error { 119 select { 120 case <-s.closed: 121 return io.EOF 122 case s.send <- msg: 123 return nil 124 } 125 } 126 127 func (s *session) wait(msg *message) error { 128 // wait for an error response 129 select { 130 case err := <-msg.errChan: 131 if err != nil { 132 return err 133 } 134 case <-s.closed: 135 return io.EOF 136 } 137 138 return nil 139 } 140 141 // waitFor waits for the message type required until the timeout specified 142 func (s *session) waitFor(msgType string, timeout time.Duration) (*message, error) { 143 now := time.Now() 144 145 after := func(timeout time.Duration) <-chan time.Time { 146 if timeout < time.Duration(0) { 147 return nil 148 } 149 150 // get the delta 151 d := time.Since(now) 152 153 // dial timeout minus time since 154 wait := timeout - d 155 156 if wait < time.Duration(0) { 157 wait = time.Duration(0) 158 } 159 160 return time.After(wait) 161 } 162 163 // wait for the message type 164 for { 165 select { 166 case msg := <-s.recv: 167 // there may be no message type 168 if len(msgType) == 0 { 169 return msg, nil 170 } 171 172 // ignore what we don't want 173 if msg.typ != msgType { 174 if logger.V(logger.DebugLevel, log) { 175 log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) 176 } 177 continue 178 } 179 180 // got the message 181 return msg, nil 182 case <-after(timeout): 183 return nil, ErrReadTimeout 184 case <-s.closed: 185 // check pending message queue 186 select { 187 case msg := <-s.recv: 188 // there may be no message type 189 if len(msgType) == 0 { 190 return msg, nil 191 } 192 193 // ignore what we don't want 194 if msg.typ != msgType { 195 if logger.V(logger.DebugLevel, log) { 196 log.Debugf("Tunnel received non %s message in waiting for %s", msg.typ, msgType) 197 } 198 continue 199 } 200 201 // got the message 202 return msg, nil 203 default: 204 // non blocking 205 } 206 return nil, io.EOF 207 } 208 } 209 } 210 211 // Discover attempts to discover the link for a specific channel. 212 // This is only used by the tunnel.Dial when first connecting. 213 func (s *session) Discover() error { 214 // create a new discovery message for this channel 215 msg := s.newMessage("discover") 216 // broadcast the message to all links 217 msg.mode = Broadcast 218 // its an outbound connection since we're dialling 219 msg.outbound = true 220 // don't set the link since we don't know where it is 221 msg.link = "" 222 223 // if multicast then set that as session 224 if s.mode == Multicast { 225 msg.session = "multicast" 226 } 227 228 // send discover message 229 if err := s.sendMsg(msg); err != nil { 230 return err 231 } 232 233 // set time now 234 now := time.Now() 235 236 // after strips down the dial timeout 237 after := func() time.Duration { 238 d := time.Since(now) 239 // dial timeout minus time since 240 wait := s.dialTimeout - d 241 // make sure its always > 0 242 if wait < time.Duration(0) { 243 return time.Duration(0) 244 } 245 return wait 246 } 247 248 // the discover message is sent out, now 249 // wait to hear back about the sent message 250 select { 251 case <-time.After(after()): 252 return ErrDialTimeout 253 case err := <-s.errChan: 254 if err != nil { 255 return err 256 } 257 } 258 259 // bail early if its not unicast 260 // we don't need to wait for the announce 261 if s.mode != Unicast { 262 s.discovered = true 263 s.accepted = true 264 return nil 265 } 266 267 // wait for announce 268 _, err := s.waitFor("announce", after()) 269 if err != nil { 270 return err 271 } 272 273 // set discovered 274 s.discovered = true 275 276 return nil 277 } 278 279 // Open will fire the open message for the session. This is called by the dialler. 280 // This is to indicate that we want to create a new session. 281 func (s *session) Open() error { 282 // create a new message 283 msg := s.newMessage("open") 284 285 // send open message 286 if err := s.sendMsg(msg); err != nil { 287 return err 288 } 289 290 // wait for an error response for send 291 if err := s.wait(msg); err != nil { 292 return err 293 } 294 295 // now wait for the accept message to be returned 296 msg, err := s.waitFor("accept", s.dialTimeout) 297 if err != nil { 298 return err 299 } 300 301 // set to accepted 302 s.accepted = true 303 // set link 304 s.link = msg.link 305 306 return nil 307 } 308 309 // Accept sends the accept response to an open message from a dialled connection 310 func (s *session) Accept() error { 311 msg := s.newMessage("accept") 312 313 // send the accept message 314 if err := s.sendMsg(msg); err != nil { 315 return err 316 } 317 318 // wait for send response 319 return s.wait(msg) 320 } 321 322 // Announce sends an announcement to notify that this session exists. 323 // This is primarily used by the listener. 324 func (s *session) Announce() error { 325 msg := s.newMessage("announce") 326 // we don't need an error back 327 msg.errChan = nil 328 // announce to all 329 msg.mode = Broadcast 330 // we don't need the link 331 msg.link = "" 332 333 // send announce message 334 return s.sendMsg(msg) 335 } 336 337 // Send is used to send a message 338 func (s *session) Send(m *transport.Message) error { 339 var err error 340 341 s.RLock() 342 gcm := s.gcm 343 s.RUnlock() 344 345 if gcm == nil { 346 gcm, err = newCipher(s.key) 347 if err != nil { 348 return err 349 } 350 s.Lock() 351 s.gcm = gcm 352 s.Unlock() 353 } 354 // encrypt the transport message payload 355 body, err := Encrypt(gcm, m.Body) 356 if err != nil { 357 log.Debugf("failed to encrypt message body: %v", err) 358 return err 359 } 360 361 // make copy, without rehash and realloc 362 data := &transport.Message{ 363 Header: make(map[string]string, len(m.Header)), 364 Body: body, 365 } 366 367 // encrypt all the headers 368 for k, v := range m.Header { 369 // encrypt the transport message payload 370 val, err := Encrypt(s.gcm, []byte(v)) 371 if err != nil { 372 log.Debugf("failed to encrypt message header %s: %v", k, err) 373 return err 374 } 375 // add the encrypted header value 376 data.Header[k] = base32.StdEncoding.EncodeToString(val) 377 } 378 379 // create a new message 380 msg := s.newMessage("session") 381 // set the data 382 msg.data = data 383 384 // if multicast don't set the link 385 if s.mode != Unicast { 386 msg.link = "" 387 } 388 389 if logger.V(logger.TraceLevel, log) { 390 log.Tracef("Appending to send backlog: %v", msg) 391 } 392 // send the actual message 393 if err := s.sendMsg(msg); err != nil { 394 return err 395 } 396 397 // wait for an error response 398 return s.wait(msg) 399 } 400 401 // Recv is used to receive a message 402 func (s *session) Recv(m *transport.Message) error { 403 var msg *message 404 405 msg, err := s.waitFor("", s.readTimeout) 406 if err != nil { 407 return err 408 } 409 410 // check the error if one exists 411 select { 412 case err := <-msg.errChan: 413 return err 414 default: 415 } 416 417 if logger.V(logger.TraceLevel, log) { 418 log.Tracef("Received from recv backlog: %v", msg) 419 } 420 421 gcm, err := newCipher([]byte(s.token + s.channel + msg.session)) 422 if err != nil { 423 if logger.V(logger.ErrorLevel, log) { 424 log.Errorf("unable to create cipher: %v", err) 425 } 426 return err 427 } 428 429 // decrypt the received payload using the token 430 // we have to used msg.session because multicast has a shared 431 // session id of "multicast" in this session struct on 432 // the listener side 433 msg.data.Body, err = Decrypt(gcm, msg.data.Body) 434 if err != nil { 435 if logger.V(logger.DebugLevel, log) { 436 log.Debugf("failed to decrypt message body: %v", err) 437 } 438 return err 439 } 440 441 // dencrypt all the headers 442 for k, v := range msg.data.Header { 443 // decode the header values 444 h, err := base32.StdEncoding.DecodeString(v) 445 if err != nil { 446 if logger.V(logger.DebugLevel, log) { 447 log.Debugf("failed to decode message header %s: %v", k, err) 448 } 449 return err 450 } 451 452 // dencrypt the transport message payload 453 val, err := Decrypt(gcm, h) 454 if err != nil { 455 if logger.V(logger.DebugLevel, log) { 456 log.Debugf("failed to decrypt message header %s: %v", k, err) 457 } 458 return err 459 } 460 // add decrypted header value 461 msg.data.Header[k] = string(val) 462 } 463 464 // set the link 465 // TODO: decruft, this is only for multicast 466 // since the session is now a single session 467 // likely provide as part of message.Link() 468 msg.data.Header["Micro-Link"] = msg.link 469 470 // set message 471 *m = *msg.data 472 // return nil 473 return nil 474 } 475 476 // Close closes the session by sending a close message 477 func (s *session) Close() error { 478 select { 479 case <-s.closed: 480 // no op 481 default: 482 close(s.closed) 483 484 // don't send close on multicast or broadcast 485 if s.mode != Unicast { 486 return nil 487 } 488 489 // append to backlog 490 msg := s.newMessage("close") 491 // no error response on close 492 msg.errChan = nil 493 494 // send the close message 495 select { 496 case s.send <- msg: 497 case <-time.After(time.Millisecond * 10): 498 } 499 } 500 501 return nil 502 }