github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/warpfront/Server.go (about)

     1  package warpfront
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  )
    12  
// Server wraps around the packet server: it serves warpfront sessions over
// plain HTTP requests (it implements http.Handler through ServeHTTP) and hands
// each newly registered session out through Accept as a net.Conn.
// Create one with NewServer.
type Server struct {
	// sessions maps the client-supplied "id" query parameter to its session.
	// Guarded by the embedded Mutex.
	sessions map[string]*session

	// seshch passes freshly registered sessions from handleRegister to Accept.
	seshch chan *session

	// dedch is closed by Close so that Accept can report shutdown.
	dedch chan bool

	// once makes Close idempotent.
	once sync.Once

	// The embedded Mutex guards sessions. NOTE(review): embedding it also
	// exports Lock/Unlock as methods of Server.
	sync.Mutex
}
    22  
    23  // NewServer creates a http.Handler for warpfront.
    24  func NewServer() *Server {
    25  	return &Server{
    26  		sessions: make(map[string]*session),
    27  		seshch:   make(chan *session),
    28  	}
    29  }
    30  
    31  // Close destroys the warpfront context.
    32  func (srv *Server) Close() error {
    33  	srv.once.Do(func() {
    34  		close(srv.dedch)
    35  	})
    36  	return nil
    37  }
    38  
    39  // Accept accepts a warpfront session.
    40  func (srv *Server) Accept() (net.Conn, error) {
    41  	select {
    42  	case sesh := <-srv.seshch:
    43  		return sesh, nil
    44  		// TODO somehow do a big error thing?
    45  	case <-srv.dedch:
    46  		return nil, io.ErrClosedPipe
    47  	}
    48  }
    49  
    50  func (srv *Server) destroySession(key string) {
    51  	srv.Lock()
    52  	defer srv.Unlock()
    53  	chs, ok := srv.sessions[key]
    54  	if ok {
    55  		chs.Close()
    56  		delete(srv.sessions, key)
    57  	}
    58  }
    59  
    60  func (srv *Server) handleDelete(wr http.ResponseWriter, rq *http.Request) {
    61  	wr.Header().Set("cache-control", "no-cache")
    62  	sesh := rq.URL.Query().Get("id")
    63  	if sesh == "" {
    64  		wr.WriteHeader(http.StatusBadRequest)
    65  		return
    66  	}
    67  	srv.destroySession(sesh)
    68  }
    69  
    70  func (srv *Server) handleRegister(wr http.ResponseWriter, rq *http.Request) {
    71  	sesh := rq.URL.Query().Get("id")
    72  	if sesh == "" {
    73  		wr.WriteHeader(http.StatusBadRequest)
    74  		return
    75  	}
    76  	wr.Header().Set("cache-control", "no-cache")
    77  	srv.Lock()
    78  	_, ok := srv.sessions[sesh]
    79  	// reject if already exists
    80  	if ok {
    81  		wr.WriteHeader(http.StatusForbidden)
    82  		srv.Unlock()
    83  		return
    84  	}
    85  	// otherwise, we initialize
    86  	chs := newSession()
    87  	srv.sessions[sesh] = chs
    88  	srv.Unlock()
    89  	// now we feed into the big chan
    90  	select {
    91  	case srv.seshch <- chs:
    92  		wr.WriteHeader(http.StatusOK)
    93  		go func() {
    94  			<-chs.ded
    95  			srv.destroySession(sesh)
    96  		}()
    97  	case <-time.After(time.Second * 1):
    98  		srv.destroySession(sesh)
    99  		wr.WriteHeader(http.StatusInternalServerError)
   100  	}
   101  }
   102  
   103  // ServeHTTP implements the basic stuff for ppServ
   104  func (srv *Server) ServeHTTP(wr http.ResponseWriter, rq *http.Request) {
   105  	key := rq.URL.Path[1:]
   106  
   107  	if key == "register" {
   108  		srv.handleRegister(wr, rq)
   109  		return
   110  	}
   111  
   112  	if key == "delete" {
   113  		srv.handleDelete(wr, rq)
   114  		return
   115  	}
   116  
   117  	// query for the session
   118  	srv.Lock()
   119  	chs, ok := srv.sessions[key]
   120  	srv.Unlock()
   121  
   122  	if !ok {
   123  		wr.WriteHeader(http.StatusBadRequest)
   124  		return
   125  	}
   126  
   127  	up, dn, ded := chs.rx, chs.tx, chs.ded
   128  
   129  	wr.Header().Set("Content-Encoding", "application/octet-stream")
   130  	wr.Header().Set("Cache-Control", "no-cache, no-store")
   131  
   132  	// signal for continuing
   133  	contbuf := make([]byte, 4)
   134  	binary.BigEndian.PutUint32(contbuf, 0)
   135  
   136  	switch rq.Method {
   137  	case "GET":
   138  		ctr := 0
   139  		start := time.Now()
   140  		for ctr < 10*1024*1024 && time.Now().Sub(start) < time.Second*40 {
   141  			delay := time.Millisecond * 5000
   142  			select {
   143  			case bts := <-dn:
   144  				// write length, then bytes
   145  				buf := make([]byte, 4)
   146  				binary.BigEndian.PutUint32(buf, uint32(len(bts)))
   147  				_, err := wr.Write(append(buf, bts...))
   148  				if err != nil {
   149  					srv.destroySession(key)
   150  					return
   151  				}
   152  				ctr += len(bts)
   153  				wr.(http.Flusher).Flush()
   154  				delay = time.Millisecond * 5000
   155  			case <-time.After(delay):
   156  				wr.Write(contbuf)
   157  				wr.(http.Flusher).Flush()
   158  				delay = delay + time.Millisecond*50
   159  				if delay > time.Second*5 {
   160  					delay = time.Second * 5
   161  				}
   162  				return
   163  			case <-ded:
   164  				srv.destroySession(key)
   165  				return
   166  			}
   167  		}
   168  		wr.Write(contbuf)
   169  		wr.(http.Flusher).Flush()
   170  		time.Sleep(time.Second)
   171  	case "POST":
   172  		pkrd := new(bytes.Buffer)
   173  		_, err := io.Copy(pkrd, rq.Body)
   174  		if err != nil {
   175  			srv.destroySession(key)
   176  			return
   177  		}
   178  		select {
   179  		case up <- pkrd.Bytes(): // TODO potential deadlock; currently mitigated by a buffer
   180  			return
   181  		case <-time.After(time.Minute):
   182  			srv.destroySession(key)
   183  			return
   184  		case <-ded:
   185  			srv.destroySession(key)
   186  			return
   187  		}
   188  	}
   189  }