github.com/cellofellow/gopkg@v0.0.0-20140722061823-eec0544a62ad/web/scgi.go (about)

     1  // Copyright 2013 <chaishushan{AT}gmail.com>. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package web
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"net/http/cgi"
    16  	"strconv"
    17  	"strings"
    18  )
    19  
    20  type scgiBody struct {
    21  	reader io.Reader
    22  	conn   io.ReadWriteCloser
    23  	closed bool
    24  }
    25  
    26  func (b *scgiBody) Read(p []byte) (n int, err error) {
    27  	if b.closed {
    28  		return 0, errors.New("SCGI read after close")
    29  	}
    30  	return b.reader.Read(p)
    31  }
    32  
    33  func (b *scgiBody) Close() error {
    34  	b.closed = true
    35  	return b.conn.Close()
    36  }
    37  
    38  type scgiConn struct {
    39  	fd           io.ReadWriteCloser
    40  	req          *http.Request
    41  	headers      http.Header
    42  	wroteHeaders bool
    43  }
    44  
    45  func (conn *scgiConn) WriteHeader(status int) {
    46  	if !conn.wroteHeaders {
    47  		conn.wroteHeaders = true
    48  
    49  		var buf bytes.Buffer
    50  		text := statusText[status]
    51  
    52  		fmt.Fprintf(&buf, "HTTP/1.1 %d %s\r\n", status, text)
    53  
    54  		for k, v := range conn.headers {
    55  			for _, i := range v {
    56  				buf.WriteString(k + ": " + i + "\r\n")
    57  			}
    58  		}
    59  
    60  		buf.WriteString("\r\n")
    61  		conn.fd.Write(buf.Bytes())
    62  	}
    63  }
    64  
    65  func (conn *scgiConn) Header() http.Header {
    66  	return conn.headers
    67  }
    68  
    69  func (conn *scgiConn) Write(data []byte) (n int, err error) {
    70  	if !conn.wroteHeaders {
    71  		conn.WriteHeader(200)
    72  	}
    73  
    74  	if conn.req.Method == "HEAD" {
    75  		return 0, errors.New("Body Not Allowed")
    76  	}
    77  
    78  	return conn.fd.Write(data)
    79  }
    80  
    81  func (conn *scgiConn) Close() { conn.fd.Close() }
    82  
    83  func (conn *scgiConn) finishRequest() error {
    84  	var buf bytes.Buffer
    85  	if !conn.wroteHeaders {
    86  		conn.wroteHeaders = true
    87  
    88  		for k, v := range conn.headers {
    89  			for _, i := range v {
    90  				buf.WriteString(k + ": " + i + "\r\n")
    91  			}
    92  		}
    93  
    94  		buf.WriteString("\r\n")
    95  		conn.fd.Write(buf.Bytes())
    96  	}
    97  	return nil
    98  }
    99  
   100  func (s *Server) readScgiRequest(fd io.ReadWriteCloser) (*http.Request, error) {
   101  	reader := bufio.NewReader(fd)
   102  	line, err := reader.ReadString(':')
   103  	if err != nil {
   104  		s.Logger.Println("Error during SCGI read: ", err.Error())
   105  	}
   106  	length, _ := strconv.Atoi(line[0 : len(line)-1])
   107  	if length > 16384 {
   108  		s.Logger.Println("Error: max header size is 16k")
   109  	}
   110  	headerData := make([]byte, length)
   111  	_, err = reader.Read(headerData)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	b, err := reader.ReadByte()
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	// discard the trailing comma
   121  	if b != ',' {
   122  		return nil, errors.New("SCGI protocol error: missing comma")
   123  	}
   124  	headerList := bytes.Split(headerData, []byte{0})
   125  	headers := map[string]string{}
   126  	for i := 0; i < len(headerList)-1; i += 2 {
   127  		headers[string(headerList[i])] = string(headerList[i+1])
   128  	}
   129  	httpReq, err := cgi.RequestFromMap(headers)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	if httpReq.ContentLength > 0 {
   134  		httpReq.Body = &scgiBody{
   135  			reader: io.LimitReader(reader, httpReq.ContentLength),
   136  			conn:   fd,
   137  		}
   138  	} else {
   139  		httpReq.Body = &scgiBody{reader: reader, conn: fd}
   140  	}
   141  	return httpReq, nil
   142  }
   143  
   144  func (s *Server) handleScgiRequest(fd io.ReadWriteCloser) {
   145  	req, err := s.readScgiRequest(fd)
   146  	if err != nil {
   147  		s.Logger.Println("SCGI error:", err.Error())
   148  	}
   149  	sc := scgiConn{fd, req, make(map[string][]string), false}
   150  	s.routeHandler(req, &sc)
   151  	sc.finishRequest()
   152  	fd.Close()
   153  }
   154  
   155  func (s *Server) listenAndServeScgi(addr string) error {
   156  
   157  	var l net.Listener
   158  	var err error
   159  
   160  	//if the path begins with a "/", assume it's a unix address
   161  	if strings.HasPrefix(addr, "/") {
   162  		l, err = net.Listen("unix", addr)
   163  	} else {
   164  		l, err = net.Listen("tcp", addr)
   165  	}
   166  
   167  	//save the listener so it can be closed
   168  	s.l = l
   169  
   170  	if err != nil {
   171  		s.Logger.Println("SCGI listen error", err.Error())
   172  		return err
   173  	}
   174  
   175  	for {
   176  		fd, err := l.Accept()
   177  		if err != nil {
   178  			s.Logger.Println("SCGI accept error", err.Error())
   179  			return err
   180  		}
   181  		go s.handleScgiRequest(fd)
   182  	}
   183  }