github.com/yaling888/clash@v1.53.0/constant/mitm.go (about)

     1  package constant
     2  
import (
	"bytes"
	"compress/gzip"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
	"time"

	"golang.org/x/text/encoding/charmap"
	"golang.org/x/text/transform"

	"github.com/yaling888/clash/common/cert"
)
    19  
    20  var (
    21  	ErrInvalidResponse = errors.New("invalid response")
    22  	ErrInvalidURL      = errors.New("invalid URL")
    23  )
    24  
    25  type RewriteHandler interface {
    26  	HandleRequest(*MitmSession) (*http.Request, *http.Response) // session.Response maybe nil
    27  	HandleResponse(*MitmSession) *http.Response
    28  	HandleApiRequest(*MitmSession) bool
    29  	HandleError(*MitmSession, error) // session maybe nil
    30  }
    31  
    32  type MitmOption struct {
    33  	ApiHost string
    34  
    35  	TLSConfig  *tls.Config
    36  	CertConfig *cert.Config
    37  
    38  	Handler RewriteHandler
    39  }
    40  
    41  type MitmSession struct {
    42  	Conn     net.Conn
    43  	Request  *http.Request
    44  	Response *http.Response
    45  
    46  	props map[string]any
    47  }
    48  
    49  func (s *MitmSession) GetProperties(key string) (any, bool) {
    50  	v, ok := s.props[key]
    51  	return v, ok
    52  }
    53  
    54  func (s *MitmSession) SetProperties(key string, val any) {
    55  	s.props[key] = val
    56  }
    57  
    58  func (s *MitmSession) NewResponse(code int, body io.Reader) *http.Response {
    59  	return NewResponse(code, body, s.Request)
    60  }
    61  
    62  func (s *MitmSession) NewErrorResponse(err error) *http.Response {
    63  	return NewErrorResponse(s.Request, err)
    64  }
    65  
    66  func (s *MitmSession) WriteResponse() (err error) {
    67  	if s.Response == nil {
    68  		return ErrInvalidResponse
    69  	}
    70  	err = s.Response.Write(s.Conn)
    71  	if s.Response.Body != nil {
    72  		_ = s.Response.Body.Close()
    73  	}
    74  	return
    75  }
    76  
    77  func NewMitmSession(conn net.Conn, request *http.Request, response *http.Response) *MitmSession {
    78  	return &MitmSession{
    79  		Conn:     conn,
    80  		Request:  request,
    81  		Response: response,
    82  		props:    map[string]any{},
    83  	}
    84  }
    85  
    86  func NewResponse(code int, body io.Reader, req *http.Request) *http.Response {
    87  	if body == nil {
    88  		body = &bytes.Buffer{}
    89  	}
    90  
    91  	rc, ok := body.(io.ReadCloser)
    92  	if !ok {
    93  		rc = io.NopCloser(body)
    94  	}
    95  
    96  	res := &http.Response{
    97  		StatusCode: code,
    98  		Status:     fmt.Sprintf("%d %s", code, http.StatusText(code)),
    99  		Proto:      "HTTP/1.1",
   100  		ProtoMajor: 1,
   101  		ProtoMinor: 1,
   102  		Header:     http.Header{},
   103  		Body:       rc,
   104  		Request:    req,
   105  	}
   106  
   107  	if req != nil {
   108  		res.Close = req.Close
   109  		res.Proto = req.Proto
   110  		res.ProtoMajor = req.ProtoMajor
   111  		res.ProtoMinor = req.ProtoMinor
   112  	}
   113  
   114  	return res
   115  }
   116  
   117  func NewErrorResponse(req *http.Request, err error) *http.Response {
   118  	res := NewResponse(http.StatusBadGateway, nil, req)
   119  	res.Close = true
   120  
   121  	date := res.Header.Get("Date")
   122  	if date == "" {
   123  		date = time.Now().Format(http.TimeFormat)
   124  	}
   125  
   126  	w := fmt.Sprintf(`199 "clash" %s %s`, err.Error(), date)
   127  	res.Header.Add("Warning", w)
   128  	return res
   129  }
   130  
   131  func ReadDecompressedBody(res *http.Response) ([]byte, error) {
   132  	rBody := res.Body
   133  	if res.Header.Get("Content-Encoding") == "gzip" {
   134  		gzReader, err := gzip.NewReader(rBody)
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  		rBody = gzReader
   139  
   140  		defer func(gzReader *gzip.Reader) {
   141  			_ = gzReader.Close()
   142  		}(gzReader)
   143  	}
   144  	return io.ReadAll(rBody)
   145  }
   146  
   147  func DecodeLatin1(reader io.Reader) (string, error) {
   148  	r := transform.NewReader(reader, charmap.ISO8859_1.NewDecoder())
   149  	b, err := io.ReadAll(r)
   150  	if err != nil {
   151  		return "", err
   152  	}
   153  
   154  	return string(b), nil
   155  }
   156  
   157  func EncodeLatin1(str string) ([]byte, error) {
   158  	return charmap.ISO8859_1.NewEncoder().Bytes([]byte(str))
   159  }