github.com/annchain/OG@v0.0.9/wserver/conn.go (about)

     1  // Copyright © 2019 Annchain Authors <EMAIL ADDRESS>
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package wserver
    15  
    16  import (
    17  	"errors"
    18  	"io"
    19  	"log"
    20  	"sync"
    21  	"time"
    22  
    23  	"fmt"
    24  	"github.com/google/uuid"
    25  	"github.com/gorilla/websocket"
    26  )
    27  
// Conn wraps a *websocket.Conn. It adds optional read/close callbacks, a
// lazily generated ID, and a stop channel that marks the Conn as closed.
type Conn struct {
	// Conn is the underlying websocket connection.
	Conn *websocket.Conn

	// AfterReadFunc, if non-nil, is called by Listen with the message type
	// and reader of every message obtained from the connection.
	AfterReadFunc   func(messageType int, r io.Reader)
	// BeforeCloseFunc, if non-nil, is called by the close handler that
	// Listen installs, before the Conn is closed.
	BeforeCloseFunc func()

	once   sync.Once     // guards the one-time generation of id in GetID
	id     string        // UUID string, set on the first GetID call
	stopCh chan struct{} // closed by Close; polled by Write, Listen and Close
}
    40  
    41  // Write write p to the websocket connection. The error returned will always
    42  // be nil if success.
    43  func (c *Conn) Write(p []byte) (n int, err error) {
    44  	select {
    45  	case <-c.stopCh:
    46  		return 0, errors.New("Conn is closed, can't be written")
    47  	default:
    48  		err = c.Conn.WriteMessage(websocket.TextMessage, p)
    49  		if err != nil {
    50  			return 0, err
    51  		}
    52  		return len(p), nil
    53  	}
    54  }
    55  
    56  // GetID returns the id generated using UUID algorithm.
    57  func (c *Conn) GetID() string {
    58  	c.once.Do(func() {
    59  		u := uuid.New()
    60  		c.id = u.String()
    61  	})
    62  
    63  	return c.id
    64  }
    65  
    66  // Listen listens for receive data from websocket connection. It blocks
    67  // until websocket connection is closed.
    68  func (c *Conn) Listen() {
    69  	c.Conn.SetCloseHandler(func(code int, text string) error {
    70  		if c.BeforeCloseFunc != nil {
    71  			c.BeforeCloseFunc()
    72  		}
    73  
    74  		if err := c.Close(); err != nil {
    75  			log.Println(err)
    76  		}
    77  
    78  		message := websocket.FormatCloseMessage(code, "")
    79  		c.Conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
    80  		return nil
    81  	})
    82  
    83  	// Keeps reading from Conn util get error.
    84  ReadLoop:
    85  	for {
    86  		select {
    87  		case <-c.stopCh:
    88  			break ReadLoop
    89  		default:
    90  			messageType, r, err := c.Conn.NextReader()
    91  			if err != nil {
    92  				// TODO: handle read error maybe
    93  				break ReadLoop
    94  			}
    95  
    96  			if c.AfterReadFunc != nil {
    97  				c.AfterReadFunc(messageType, r)
    98  			}
    99  		}
   100  	}
   101  }
   102  
   103  // Close close the connection.
   104  func (c *Conn) Close() error {
   105  	select {
   106  	case <-c.stopCh:
   107  		return errors.New("Conn already been closed")
   108  	default:
   109  		c.Conn.Close()
   110  		close(c.stopCh)
   111  		return nil
   112  	}
   113  }
   114  
   115  // NewConn wraps conn.
   116  func NewConn(conn *websocket.Conn) *Conn {
   117  	return &Conn{
   118  		Conn:   conn,
   119  		stopCh: make(chan struct{}),
   120  	}
   121  }
   122  
const (
	// EVENT_NEW_UNIT is the event type name "new_unit".
	// NOTE(review): presumably used as an event-type key for event2Cons —
	// confirm against callers.
	EVENT_NEW_UNIT = "new_unit"
)
   126  
// event2Cons indexes connections by event type.
// conns maps an event type to a second map, which in turn maps a Conn's ID
// to the Conn itself.
type event2Cons struct {
	conns map[string]map[string]*Conn
	// mu is taken by getFromMap (read lock) and setMap (write lock) around
	// their access to conns.
	mu    sync.RWMutex
}
   134  
   135  func (e *event2Cons) getFromMap(key string) (v map[string]*Conn, ok bool) {
   136  	e.mu.RLock()
   137  	defer e.mu.RUnlock()
   138  	v, ok = e.conns[key]
   139  	return
   140  }
   141  
   142  func (e *event2Cons) setMap(key string, v map[string]*Conn) {
   143  	e.mu.Lock()
   144  	defer e.mu.Unlock()
   145  	e.conns[key] = v
   146  }
   147  
   148  func NewEvent2Cons() *event2Cons {
   149  	return &event2Cons{
   150  		conns: make(map[string]map[string]*Conn),
   151  	}
   152  }
   153  func (e *event2Cons) Add(eventType string, conn *Conn) error {
   154  	conns, ok := e.getFromMap(eventType)
   155  	if !ok {
   156  		conns = make(map[string]*Conn)
   157  		e.setMap(eventType, conns)
   158  	}
   159  	thisID := conn.GetID()
   160  	if _, ok := e.getFromMap(thisID); !ok {
   161  		//not exist,add it
   162  		v, _ := e.getFromMap(eventType)
   163  		v[thisID] = conn
   164  		e.setMap(eventType, v)
   165  	} else {
   166  		return fmt.Errorf("Conn with ID: %s already exist!", thisID)
   167  	}
   168  	return nil
   169  }
   170  
   171  func (e *event2Cons) Remove(eventType string, conn *Conn) error {
   172  	conns, ok := e.getFromMap(eventType)
   173  	if !ok {
   174  		return fmt.Errorf("No Connection with eventType: %s\n", eventType)
   175  	}
   176  	thisID := conn.GetID()
   177  	if _, ok := e.getFromMap(thisID); !ok {
   178  		return fmt.Errorf("No connection with ID: %s\n", thisID)
   179  	} else {
   180  		delete(conns, thisID)
   181  	}
   182  	return nil
   183  }
   184  
   185  func (e *event2Cons) Get(eventType string) ([]*Conn, error) {
   186  	conns, ok := e.getFromMap(eventType)
   187  	if !ok {
   188  		return nil, fmt.Errorf("No Connection with eventType: %s\n", eventType)
   189  	}
   190  	var ret []*Conn
   191  	for _, c := range conns {
   192  		ret = append(ret, c)
   193  	}
   194  	return ret, nil
   195  }
   196  
   197  func (e *event2Cons) GetWithID(eventType string, ID string) (*Conn, error) {
   198  	conns, err := e.Get(eventType)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	for _, c := range conns {
   203  		if c.GetID() == ID {
   204  			return c, nil
   205  		}
   206  	}
   207  	return nil, fmt.Errorf("No Connection with eventType: %s ID: %s\n", eventType, ID)
   208  }