github.com/oinume/lekcije@v0.0.0-20231017100347-5b4c5eb6ab24/backend/internal/mock/mock_transport.go (about)

     1  package mock
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  	"sync"
    10  )
    11  
    12  type HTMLTransport struct {
    13  	sync.Mutex
    14  	NumCalled int
    15  	content   string
    16  }
    17  
    18  func NewHTMLTransport(path string) (*HTMLTransport, error) {
    19  	file, err := os.Open(path)
    20  	if err != nil {
    21  		return nil, fmt.Errorf("os.Open failed: path=%v, err=%v", path, err)
    22  	}
    23  	b, err := io.ReadAll(file)
    24  	if err != nil {
    25  		return nil, fmt.Errorf("read file failed: err=%v", err)
    26  	}
    27  	return &HTMLTransport{
    28  		content: string(b),
    29  	}, nil
    30  }
    31  
    32  func NewHTMLTransportFromString(content string) *HTMLTransport {
    33  	return &HTMLTransport{
    34  		content: content,
    35  	}
    36  }
    37  
    38  func (t *HTMLTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    39  	t.Lock()
    40  	t.NumCalled++
    41  	t.Unlock()
    42  	resp := &http.Response{
    43  		Header:     make(http.Header),
    44  		Request:    req,
    45  		StatusCode: http.StatusOK,
    46  		Status:     "200 OK",
    47  	}
    48  	resp.Header.Set("Content-Type", "text/html; charset=UTF-8")
    49  	resp.Body = io.NopCloser(strings.NewReader(t.content))
    50  	return resp, nil
    51  }
    52  
    53  type ResponseTransport struct {
    54  	sync.Mutex
    55  	NumCalled    int
    56  	responseFunc func(*ResponseTransport, *http.Request) *http.Response
    57  }
    58  
    59  func NewResponseTransport(f func(*ResponseTransport, *http.Request) *http.Response) *ResponseTransport {
    60  	return &ResponseTransport{responseFunc: f}
    61  }
    62  
    63  func (t *ResponseTransport) RoundTrip(req *http.Request) (*http.Response, error) {
    64  	t.Lock()
    65  	t.NumCalled++
    66  	t.Unlock()
    67  	resp := t.responseFunc(t, req)
    68  	if resp.StatusCode == 0 {
    69  		resp.StatusCode = http.StatusOK
    70  		resp.Status = http.StatusText(http.StatusOK)
    71  	}
    72  	if ct := resp.Header.Get("Content-Type"); ct == "" {
    73  		resp.Header.Set("Content-Type", "text/plain; charset=UTF-8")
    74  	}
    75  	if resp.Body == nil {
    76  		resp.Body = io.NopCloser(strings.NewReader(""))
    77  	}
    78  	return resp, nil
    79  }