github.com/safing/portbase@v0.19.5/api/client/websocket.go (about)

     1  package client
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/gorilla/websocket"
     8  	"github.com/tevino/abool"
     9  
    10  	"github.com/safing/portbase/log"
    11  )
    12  
    13  type wsState struct {
    14  	wsConn     *websocket.Conn
    15  	wg         sync.WaitGroup
    16  	failing    *abool.AtomicBool
    17  	failSignal chan struct{}
    18  }
    19  
    20  func (c *Client) wsConnect() error {
    21  	state := &wsState{
    22  		failing:    abool.NewBool(false),
    23  		failSignal: make(chan struct{}),
    24  	}
    25  
    26  	var err error
    27  	state.wsConn, _, err = websocket.DefaultDialer.Dial(fmt.Sprintf("ws://%s/api/database/v1", c.server), nil)
    28  	if err != nil {
    29  		return err
    30  	}
    31  
    32  	c.signalOnline()
    33  
    34  	state.wg.Add(2)
    35  	go c.wsReader(state)
    36  	go c.wsWriter(state)
    37  
    38  	// wait for end of connection
    39  	select {
    40  	case <-state.failSignal:
    41  	case <-c.shutdownSignal:
    42  		state.Error("")
    43  	}
    44  	_ = state.wsConn.Close()
    45  	state.wg.Wait()
    46  
    47  	return nil
    48  }
    49  
    50  func (c *Client) wsReader(state *wsState) {
    51  	defer state.wg.Done()
    52  	for {
    53  		_, data, err := state.wsConn.ReadMessage()
    54  		log.Tracef("client: read message")
    55  		if err != nil {
    56  			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
    57  				state.Error(fmt.Sprintf("client: read error: %s", err))
    58  			} else {
    59  				state.Error("client: connection closed by server")
    60  			}
    61  			return
    62  		}
    63  		log.Tracef("client: received message: %s", string(data))
    64  		m, err := ParseMessage(data)
    65  		if err != nil {
    66  			log.Warningf("client: failed to parse message: %s", err)
    67  		} else {
    68  			select {
    69  			case c.recv <- m:
    70  			case <-state.failSignal:
    71  				return
    72  			}
    73  		}
    74  	}
    75  }
    76  
    77  func (c *Client) wsWriter(state *wsState) {
    78  	defer state.wg.Done()
    79  	for {
    80  		select {
    81  		case <-state.failSignal:
    82  			return
    83  		case m := <-c.resend:
    84  			data, err := m.Pack()
    85  			if err == nil {
    86  				err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
    87  			}
    88  			if err != nil {
    89  				state.Error(fmt.Sprintf("client: write error: %s", err))
    90  				return
    91  			}
    92  			log.Tracef("client: sent message: %s", string(data))
    93  			if m.sent != nil {
    94  				m.sent.Set()
    95  			}
    96  		case m := <-c.send:
    97  			data, err := m.Pack()
    98  			if err == nil {
    99  				err = state.wsConn.WriteMessage(websocket.BinaryMessage, data)
   100  			}
   101  			if err != nil {
   102  				c.resend <- m
   103  				state.Error(fmt.Sprintf("client: write error: %s", err))
   104  				return
   105  			}
   106  			log.Tracef("client: sent message: %s", string(data))
   107  			if m.sent != nil {
   108  				m.sent.Set()
   109  			}
   110  		}
   111  	}
   112  }
   113  
   114  func (state *wsState) Error(message string) {
   115  	if state.failing.SetToIf(false, true) {
   116  		close(state.failSignal)
   117  		if message != "" {
   118  			log.Warning(message)
   119  		}
   120  	}
   121  }