github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/handler.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rafttransport 5 6 import ( 7 "fmt" 8 "net" 9 "net/http" 10 ) 11 12 // Handler is an http.Handler suitable for serving an endpoint that 13 // upgrades to raft transport connections. 14 type Handler struct { 15 connections chan<- net.Conn 16 abort <-chan struct{} 17 } 18 19 // NewHandler returns a new Handler that sends connections to the 20 // given connections channel, and stops accepting connections after 21 // the abort channel is closed. 22 func NewHandler( 23 connections chan<- net.Conn, 24 abort <-chan struct{}, 25 ) *Handler { 26 return &Handler{ 27 connections: connections, 28 abort: abort, 29 } 30 } 31 32 // ServeHTTP is part of the http.Handler interface. 33 // 34 // ServeHTTP checks for "raft" upgrade requests, and hijacks 35 // those connections for use as a raw connection for raft 36 // communications. 37 // 38 // Based on code from https://github.com/CanonicalLtd/raft-http. 39 func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 40 // Fail immediately if we've been closed. 41 select { 42 case <-h.abort: 43 http.Error(w, "raft transport closed", http.StatusForbidden) 44 return 45 default: 46 } 47 48 if r.Header.Get("Upgrade") != "raft" { 49 http.Error(w, "missing or invalid upgrade header", http.StatusBadRequest) 50 return 51 } 52 53 hijacker, ok := w.(http.Hijacker) 54 if !ok { 55 http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) 56 return 57 } 58 59 conn, _, err := hijacker.Hijack() 60 if err != nil { 61 message := fmt.Sprintf("failed to hijack connection: %s", err) 62 http.Error(w, message, http.StatusInternalServerError) 63 return 64 } 65 66 // Write the status line and upgrade header by hand since w.WriteHeader() 67 // would fail after Hijack() 68 data := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: raft\r\n\r\n") 69 if n, err := conn.Write(data); err != nil || n != len(data) { 70 conn.Close() 71 return 72 } 73 74 select { 75 case h.connections <- conn: 76 case <-r.Context().Done(): 77 conn.Close() 78 } 79 }