github.com/hupe1980/go-huggingface@v0.0.15/huggingface_test.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  // Mock HTTP Client for testing purposes
    15  type mockHTTPClient struct {
    16  	Response []byte
    17  	Err      error
    18  }
    19  
    20  func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
    21  	if c.Err != nil {
    22  		return nil, c.Err
    23  	}
    24  
    25  	return &http.Response{
    26  		StatusCode: http.StatusOK,
    27  		Body:       io.NopCloser(bytes.NewBuffer(c.Response)),
    28  	}, nil
    29  }
    30  
    31  func TestSummarization(t *testing.T) {
    32  	client := NewInferenceClient("your-token")
    33  	mockResponse := []byte(`[{"summary_text": "This is a summary"}]`)
    34  
    35  	t.Run("Successful Request", func(t *testing.T) {
    36  		// Mock HTTP Client with successful response
    37  		mockHTTP := &mockHTTPClient{Response: mockResponse}
    38  		client.httpClient = mockHTTP
    39  
    40  		req := &SummarizationRequest{
    41  			Inputs: []string{"This is a test input"},
    42  			Model:  "t5-base",
    43  		}
    44  
    45  		response, err := client.Summarization(context.Background(), req)
    46  		assert.NoError(t, err)
    47  		assert.NotNil(t, response)
    48  		assert.Equal(t, "This is a summary", response[0].SummaryText)
    49  	})
    50  
    51  	t.Run("Empty Inputs", func(t *testing.T) {
    52  		req := &SummarizationRequest{
    53  			Inputs: nil, // Empty inputs
    54  			Model:  "t5-base",
    55  		}
    56  
    57  		response, err := client.Summarization(context.Background(), req)
    58  		assert.Error(t, err)
    59  		assert.Nil(t, response)
    60  		assert.Equal(t, "inputs are required", err.Error())
    61  	})
    62  
    63  	t.Run("HTTP Request Error", func(t *testing.T) {
    64  		// Mock HTTP Client with error response
    65  		mockHTTP := &mockHTTPClient{Err: errors.New("request error")}
    66  		client.httpClient = mockHTTP
    67  
    68  		req := &SummarizationRequest{
    69  			Inputs: []string{"This is a test input"},
    70  			Model:  "t5-base",
    71  		}
    72  
    73  		response, err := client.Summarization(context.Background(), req)
    74  		assert.Error(t, err)
    75  		assert.Nil(t, response)
    76  		assert.Equal(t, "request error", err.Error())
    77  	})
    78  }
    79  
    80  func TestQuestionAnswering(t *testing.T) {
    81  	client := NewInferenceClient("your-token")
    82  
    83  	t.Run("Missing question input", func(t *testing.T) {
    84  		req := &QuestionAnsweringRequest{
    85  			Model: "distilbert-base-uncased-distilled-squad",
    86  			Inputs: QuestionAnsweringInputs{
    87  				Context: "Paris is the capital of France.",
    88  			},
    89  		}
    90  		_, err := client.QuestionAnswering(context.Background(), req)
    91  		assert.EqualError(t, err, "question is required")
    92  	})
    93  
    94  	t.Run("Missing context input", func(t *testing.T) {
    95  		req := &QuestionAnsweringRequest{
    96  			Model: "distilbert-base-uncased-distilled-squad",
    97  			Inputs: QuestionAnsweringInputs{
    98  				Question: "What is the capital of France?",
    99  			},
   100  		}
   101  		_, err := client.QuestionAnswering(context.Background(), req)
   102  		assert.EqualError(t, err, "context is required")
   103  	})
   104  }