github.com/hupe1980/go-huggingface@v0.0.15/huggingface_test.go (about) 1 package huggingface 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "io" 8 "net/http" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 ) 13 14 // Mock HTTP Client for testing purposes 15 type mockHTTPClient struct { 16 Response []byte 17 Err error 18 } 19 20 func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { 21 if c.Err != nil { 22 return nil, c.Err 23 } 24 25 return &http.Response{ 26 StatusCode: http.StatusOK, 27 Body: io.NopCloser(bytes.NewBuffer(c.Response)), 28 }, nil 29 } 30 31 func TestSummarization(t *testing.T) { 32 client := NewInferenceClient("your-token") 33 mockResponse := []byte(`[{"summary_text": "This is a summary"}]`) 34 35 t.Run("Successful Request", func(t *testing.T) { 36 // Mock HTTP Client with successful response 37 mockHTTP := &mockHTTPClient{Response: mockResponse} 38 client.httpClient = mockHTTP 39 40 req := &SummarizationRequest{ 41 Inputs: []string{"This is a test input"}, 42 Model: "t5-base", 43 } 44 45 response, err := client.Summarization(context.Background(), req) 46 assert.NoError(t, err) 47 assert.NotNil(t, response) 48 assert.Equal(t, "This is a summary", response[0].SummaryText) 49 }) 50 51 t.Run("Empty Inputs", func(t *testing.T) { 52 req := &SummarizationRequest{ 53 Inputs: nil, // Empty inputs 54 Model: "t5-base", 55 } 56 57 response, err := client.Summarization(context.Background(), req) 58 assert.Error(t, err) 59 assert.Nil(t, response) 60 assert.Equal(t, "inputs are required", err.Error()) 61 }) 62 63 t.Run("HTTP Request Error", func(t *testing.T) { 64 // Mock HTTP Client with error response 65 mockHTTP := &mockHTTPClient{Err: errors.New("request error")} 66 client.httpClient = mockHTTP 67 68 req := &SummarizationRequest{ 69 Inputs: []string{"This is a test input"}, 70 Model: "t5-base", 71 } 72 73 response, err := client.Summarization(context.Background(), req) 74 assert.Error(t, err) 75 assert.Nil(t, response) 76 assert.Equal(t, "request error", err.Error()) 77 }) 78 } 79 80 func TestQuestionAnswering(t *testing.T) { 81 client := NewInferenceClient("your-token") 82 83 t.Run("Missing question input", func(t *testing.T) { 84 req := &QuestionAnsweringRequest{ 85 Model: "distilbert-base-uncased-distilled-squad", 86 Inputs: QuestionAnsweringInputs{ 87 Context: "Paris is the capital of France.", 88 }, 89 } 90 _, err := client.QuestionAnswering(context.Background(), req) 91 assert.EqualError(t, err, "question is required") 92 }) 93 94 t.Run("Missing context input", func(t *testing.T) { 95 req := &QuestionAnsweringRequest{ 96 Model: "distilbert-base-uncased-distilled-squad", 97 Inputs: QuestionAnsweringInputs{ 98 Question: "What is the capital of France?", 99 }, 100 } 101 _, err := client.QuestionAnswering(context.Background(), req) 102 assert.EqualError(t, err, "context is required") 103 }) 104 }