github.com/iDigitalFlame/xmt@v0.5.4/com/wc2/listener.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
// Package wc2 contains an HTTP/Web based communication channel, which follows
// the Golang 'net.Conn' interface and is very configurable.
    19  package wc2
    20  
    21  import (
    22  	"context"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"sync/atomic"
    27  	"time"
    28  
    29  	"github.com/iDigitalFlame/xmt/util/bugtrack"
    30  )
    31  
    32  var done = new(complete)
    33  
    34  type addr string
    35  type conn struct {
    36  	_ [0]func()
    37  	net.Conn
    38  	ch   chan complete
    39  	done uint32
    40  }
    41  type complete struct{}
    42  type listener struct {
    43  	err error
    44  	ctx context.Context
    45  
    46  	p       *Server
    47  	ch, pch chan complete
    48  	new     chan *conn
    49  	*http.Server
    50  
    51  	rules []Rule
    52  	read  time.Duration
    53  }
    54  
    55  func (c *conn) Close() error {
    56  	if atomic.LoadUint32(&c.done) == 1 {
    57  		return nil
    58  	}
    59  	atomic.StoreUint32(&c.done, 1)
    60  	err := c.Conn.Close()
    61  	close(c.ch)
    62  	return err
    63  }
    64  func (a addr) String() string {
    65  	return string(a)
    66  }
    67  func (complete) Timeout() bool {
    68  	return true
    69  }
    70  func (complete) Error() string {
    71  	return context.DeadlineExceeded.Error()
    72  }
    73  func (complete) Temporary() bool {
    74  	return true
    75  }
    76  func (l *listener) Close() error {
    77  	if l.p == nil {
    78  		return nil
    79  	}
    80  	err := l.Server.Close()
    81  	close(l.new)
    82  	close(l.ch)
    83  	if l.rules, l.p = nil, nil; err != nil {
    84  		return err
    85  	}
    86  	if l.err == http.ErrServerClosed {
    87  		return nil
    88  	}
    89  	return l.err
    90  }
    91  func (l *listener) Addr() net.Addr {
    92  	return addr(l.Server.Addr)
    93  }
    94  func (l *listener) listen(x net.Listener) {
    95  	l.err = l.Serve(x)
    96  	l.Close()
    97  }
    98  func (l *listener) Accept() (net.Conn, error) {
    99  	if l.err != nil {
   100  		return nil, l.err
   101  	}
   102  	if l.read > 0 {
   103  		var (
   104  			t   = time.NewTimer(l.read)
   105  			n   *conn
   106  			err error
   107  		)
   108  		select {
   109  		case <-t.C:
   110  			err = done
   111  		case <-l.ch:
   112  			err = io.ErrClosedPipe
   113  		case <-l.pch:
   114  			err = io.ErrClosedPipe
   115  		case n = <-l.new:
   116  		case <-l.ctx.Done():
   117  			err = io.ErrClosedPipe
   118  		}
   119  		t.Stop()
   120  		return n, err
   121  	}
   122  	select {
   123  	case <-l.ch:
   124  	case <-l.pch:
   125  	case n := <-l.new:
   126  		return n, nil
   127  	case <-l.ctx.Done():
   128  	}
   129  	return nil, io.ErrClosedPipe
   130  }
   131  func (l *listener) context(_ net.Listener) context.Context {
   132  	return l.ctx
   133  }
   134  func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   135  	if !matchAll(r, l.rules) {
   136  		if bugtrack.Enabled {
   137  			bugtrack.Track("wc2.(*listener).ServeHTTP(): Connection from %s passed to parent as it does not match rules.", r.RemoteAddr)
   138  		}
   139  		l.p.handler.ServeHTTP(w, r)
   140  		r.Body.Close()
   141  		return
   142  	}
   143  	h, ok := w.(http.Hijacker)
   144  	if !ok {
   145  		if bugtrack.Enabled {
   146  			bugtrack.Track("wc2.(*listener).ServeHTTP(): Connection from %s cannot be hijacked, ignoring!", r.RemoteAddr)
   147  		}
   148  		w.WriteHeader(http.StatusNotAcceptable)
   149  		return
   150  	}
   151  	modHeaders(w.Header())
   152  	w.WriteHeader(http.StatusSwitchingProtocols)
   153  	c, _, err := h.Hijack()
   154  	if err != nil {
   155  		if bugtrack.Enabled {
   156  			bugtrack.Track("wc2.(*listener).ServeHTTP(): Connection from %s cannot be hijacked err=%s!", r.RemoteAddr, err.Error())
   157  		}
   158  		return
   159  	}
   160  	if l.p == nil {
   161  		c.Close()
   162  		return
   163  	}
   164  	if bugtrack.Enabled {
   165  		bugtrack.Track("wc2.(*listener).ServeHTTP(): Adding tracked connection from %s", r.RemoteAddr)
   166  	}
   167  	v := &conn{ch: make(chan complete, 1), Conn: c}
   168  	l.new <- v
   169  	select {
   170  	case <-v.ch:
   171  	case <-l.ch:
   172  	case <-l.ctx.Done():
   173  	}
   174  	v.Close()
   175  }