github.com/annchain/OG@v0.0.9/wserver/handler.go (about)

     1  // Copyright © 2019 Annchain Authors <EMAIL ADDRESS>
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package wserver
    15  
    16  import (
    17  	"encoding/json"
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"strings"
    23  
    24  	"github.com/gin-gonic/gin"
    25  	"github.com/gorilla/websocket"
    26  	"github.com/sirupsen/logrus"
    27  )
    28  
// websocketHandler handles requests that ask to be upgraded to a websocket
// connection, and subscribes each upgraded connection to the events the
// client registers for.
type websocketHandler struct {
	// upgrader is used to upgrade the HTTP request to a websocket connection.
	upgrader *websocket.Upgrader

	// event2Cons indexes websocket connections by event name; ServeHTTP adds
	// the connection when the client registers for an event and removes it
	// again when the connection goes away.
	event2Cons *event2Cons
}
    36  
// RegisterMessage is the JSON message a client sends over its websocket
// connection after connecting, to subscribe that connection to an event,
// e.g. {"event": "<event name>"}.
type RegisterMessage struct {
	//Token string
	// Event is the name of the event the client wants pushed messages for.
	Event string `json:"event"`
}
    43  
    44  func (wh *websocketHandler) Handle(ctx *gin.Context) {
    45  	wh.ServeHTTP(ctx.Writer, ctx.Request)
    46  }
    47  
    48  // First try to upgrade connection to websocket. If success, connection will
    49  // be kept until client send close message or server drop them.
    50  func (wh *websocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    51  	wsConn, err := wh.upgrader.Upgrade(w, r, nil)
    52  	if err != nil {
    53  		return
    54  	}
    55  	defer wsConn.Close()
    56  
    57  	// handle Websocket request
    58  	conn := NewConn(wsConn)
    59  	var eventType string
    60  	conn.AfterReadFunc = func(messageType int, r io.Reader) {
    61  		var rm RegisterMessage
    62  		decoder := json.NewDecoder(r)
    63  		if err := decoder.Decode(&rm); err != nil {
    64  			logrus.WithError(err).Debug("Failed to serve request")
    65  			return
    66  		}
    67  
    68  		wh.event2Cons.Add(rm.Event, conn)
    69  		eventType = rm.Event
    70  	}
    71  	conn.BeforeCloseFunc = func() {
    72  		wh.event2Cons.Remove(eventType, conn)
    73  	}
    74  
    75  	conn.Listen()
    76  }
    77  
    78  // ErrRequestIllegal describes error when data of the request is unaccepted.
    79  var ErrRequestIllegal = errors.New("request data illegal")
    80  
// pushHandler handles HTTP push requests: each accepted request names an
// event and carries a message, which is written to every websocket
// connection registered for that event.
type pushHandler struct {
	// authFunc authorizes a request; the request proceeds only when it
	// returns true. When authFunc is nil no authorization is performed.
	authFunc func(r *http.Request) bool

	// event2Cons indexes websocket connections by event name; push looks up
	// the recipients here and removes connections whose write fails.
	event2Cons *event2Cons
}
    89  
    90  func (s *pushHandler) Handle(ctx *gin.Context) {
    91  	s.ServeHTTP(ctx.Writer, ctx.Request)
    92  }
    93  
    94  // Authorize if needed. Then decode the request and push message to each
    95  // realted websocket connection.
    96  func (s *pushHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    97  	if r.Method != http.MethodPost {
    98  		w.WriteHeader(http.StatusMethodNotAllowed)
    99  		return
   100  	}
   101  
   102  	// authorize
   103  	if s.authFunc != nil {
   104  		if ok := s.authFunc(r); !ok {
   105  			w.WriteHeader(http.StatusUnauthorized)
   106  			return
   107  		}
   108  	}
   109  
   110  	// read request
   111  	var pm PushMessage
   112  	decoder := json.NewDecoder(r.Body)
   113  	if err := decoder.Decode(&pm); err != nil {
   114  		w.WriteHeader(http.StatusBadRequest)
   115  		w.Write([]byte(ErrRequestIllegal.Error()))
   116  		return
   117  	}
   118  
   119  	// validate the data
   120  	if pm.Event == "" || pm.Message == "" {
   121  		w.WriteHeader(http.StatusBadRequest)
   122  		w.Write([]byte(ErrRequestIllegal.Error()))
   123  		return
   124  	}
   125  
   126  	cnt, err := s.push(pm.Event, pm.Message)
   127  	if err != nil {
   128  		w.WriteHeader(http.StatusInternalServerError)
   129  		w.Write([]byte(err.Error()))
   130  		return
   131  	}
   132  
   133  	result := strings.NewReader(fmt.Sprintf("message sent to %d clients", cnt))
   134  	io.Copy(w, result)
   135  }
   136  
   137  func (s *pushHandler) push(event, message string) (int, error) {
   138  	if event == "" || message == "" {
   139  		return 0, errors.New("parameters(userId, event, message) can't be empty")
   140  	}
   141  
   142  	conns, err := s.event2Cons.Get(event)
   143  	if err != nil {
   144  		return 0, fmt.Errorf("Get conns with eventType: %s failed!\n", event)
   145  	}
   146  	cnt := 0
   147  	for i := range conns {
   148  		_, err := conns[i].Write([]byte(message))
   149  		if err != nil {
   150  			s.event2Cons.Remove(event, conns[i])
   151  			continue
   152  		}
   153  		cnt++
   154  	}
   155  
   156  	return cnt, nil
   157  }
   158  
// PushMessage is the JSON body of a push request, e.g.
// {"event": "<event name>", "message": "<text>"}. Message is written to every
// websocket client registered for Event; both fields must be non-empty or
// the request is rejected.
type PushMessage struct {
	//UserID  string `json:"userId"`
	Event   string `json:"event"`
	Message string `json:"message"`
}