github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/wsclient/wstestserver.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package wsclient
    18  
    19  import (
    20  	"fmt"
    21  	"net/http"
    22  	"net/http/httptest"
    23  
    24  	"github.com/gorilla/websocket"
    25  )
    26  
    27  // NewTestWSServer creates a little test server for packages (including wsclient itself) to use in unit tests
    28  func NewTestWSServer(testReq func(req *http.Request)) (toServer, fromServer chan string, url string, done func()) {
    29  	upgrader := &websocket.Upgrader{WriteBufferSize: 1024, ReadBufferSize: 1024}
    30  	toServer = make(chan string, 1)
    31  	fromServer = make(chan string, 1)
    32  	sendDone := make(chan struct{})
    33  	receiveDone := make(chan struct{})
    34  	connected := false
    35  	svr := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
    36  		if testReq != nil {
    37  			testReq(req)
    38  		}
    39  		ws, _ := upgrader.Upgrade(res, req, http.Header{})
    40  		go func() {
    41  			defer close(receiveDone)
    42  			for {
    43  				_, data, err := ws.ReadMessage()
    44  				if err != nil {
    45  					return
    46  				}
    47  				toServer <- string(data)
    48  			}
    49  		}()
    50  		go func() {
    51  			defer close(sendDone)
    52  			defer ws.Close()
    53  			for data := range fromServer {
    54  				_ = ws.WriteMessage(websocket.TextMessage, []byte(data))
    55  			}
    56  		}()
    57  		connected = true
    58  	}))
    59  	return toServer, fromServer, fmt.Sprintf("ws://%s", svr.Listener.Addr()), func() {
    60  		close(fromServer)
    61  		svr.Close()
    62  		if connected {
    63  			<-sendDone
    64  			<-receiveDone
    65  		}
    66  	}
    67  }