github.com/lauslim12/expert-systems@v0.0.0-20221115131159-018513aad29c/pkg/inference/inference_test.go (about)

     1  package inference
     2  
     3  import (
     4  	"reflect"
     5  	"testing"
     6  )
     7  
     8  func TestInfer(t *testing.T) {
     9  	tests := []struct {
    10  		name                    string
    11  		input                   Input
    12  		expectedCertaintyFactor float64
    13  		expectedVerdict         bool
    14  	}{
    15  		{
    16  			name: "test_basic_input",
    17  			input: Input{
    18  				DiseaseID: "D01",
    19  				Locale:    "en",
    20  				Symptoms: []SymptomAndWeight{
    21  					{
    22  						SymptomID: "S1",
    23  						Weight:    0.5,
    24  					},
    25  				},
    26  			},
    27  			expectedCertaintyFactor: 0.2565,
    28  			expectedVerdict:         false,
    29  		},
    30  		{
    31  			name: "test_advanced_input_locale_en",
    32  			input: Input{
    33  				DiseaseID: "D01",
    34  				Locale:    "en",
    35  				Symptoms: []SymptomAndWeight{
    36  					{
    37  						SymptomID: "S1",
    38  						Weight:    0.5,
    39  					},
    40  					{
    41  						SymptomID: "S2",
    42  						Weight:    0.4,
    43  					},
    44  					{
    45  						SymptomID: "S3",
    46  						Weight:    0.2,
    47  					},
    48  					{
    49  						SymptomID: "S4",
    50  						Weight:    0.6,
    51  					},
    52  					{
    53  						SymptomID: "S5",
    54  						Weight:    0.2,
    55  					},
    56  					{
    57  						SymptomID: "S6",
    58  						Weight:    0.4,
    59  					},
    60  					{
    61  						SymptomID: "S7",
    62  						Weight:    0.8,
    63  					},
    64  					{
    65  						SymptomID: "S8",
    66  						Weight:    0.2,
    67  					},
    68  					{
    69  						SymptomID: "S9",
    70  						Weight:    0.2,
    71  					},
    72  					{
    73  						SymptomID: "S10",
    74  						Weight:    0.4,
    75  					},
    76  					{
    77  						SymptomID: "S11",
    78  						Weight:    0.2,
    79  					},
    80  					{
    81  						SymptomID: "S12",
    82  						Weight:    0.2,
    83  					},
    84  					{
    85  						SymptomID: "S13",
    86  						Weight:    0.8,
    87  					},
    88  				},
    89  			},
    90  			expectedCertaintyFactor: 0.9471713614230385,
    91  			expectedVerdict:         true,
    92  		},
    93  		{
    94  			name: "test_advanced_input_locale_en",
    95  			input: Input{
    96  				DiseaseID: "D01",
    97  				Locale:    "en",
    98  				Symptoms: []SymptomAndWeight{
    99  					{
   100  						SymptomID: "S1",
   101  						Weight:    0.5,
   102  					},
   103  					{
   104  						SymptomID: "S2",
   105  						Weight:    0.4,
   106  					},
   107  					{
   108  						SymptomID: "S3",
   109  						Weight:    0.2,
   110  					},
   111  					{
   112  						SymptomID: "S4",
   113  						Weight:    0.6,
   114  					},
   115  					{
   116  						SymptomID: "S5",
   117  						Weight:    0.2,
   118  					},
   119  					{
   120  						SymptomID: "S6",
   121  						Weight:    0.4,
   122  					},
   123  					{
   124  						SymptomID: "S7",
   125  						Weight:    0.8,
   126  					},
   127  					{
   128  						SymptomID: "S8",
   129  						Weight:    0.2,
   130  					},
   131  					{
   132  						SymptomID: "S9",
   133  						Weight:    0.2,
   134  					},
   135  					{
   136  						SymptomID: "S10",
   137  						Weight:    0.4,
   138  					},
   139  					{
   140  						SymptomID: "S11",
   141  						Weight:    0.2,
   142  					},
   143  					{
   144  						SymptomID: "S12",
   145  						Weight:    0.2,
   146  					},
   147  					{
   148  						SymptomID: "S13",
   149  						Weight:    0.8,
   150  					},
   151  				},
   152  			},
   153  			expectedCertaintyFactor: 0.9471713614230385,
   154  			expectedVerdict:         true,
   155  		},
   156  		{
   157  			name: "test_advanced_input_locale_en",
   158  			input: Input{
   159  				DiseaseID: "D01",
   160  				Locale:    "id",
   161  				Symptoms: []SymptomAndWeight{
   162  					{
   163  						SymptomID: "S1",
   164  						Weight:    0.25,
   165  					},
   166  					{
   167  						SymptomID: "S2",
   168  						Weight:    0,
   169  					},
   170  					{
   171  						SymptomID: "S3",
   172  						Weight:    0.25,
   173  					},
   174  					{
   175  						SymptomID: "S4",
   176  						Weight:    0,
   177  					},
   178  					{
   179  						SymptomID: "S5",
   180  						Weight:    0,
   181  					},
   182  					{
   183  						SymptomID: "S6",
   184  						Weight:    0,
   185  					},
   186  					{
   187  						SymptomID: "S7",
   188  						Weight:    0,
   189  					},
   190  					{
   191  						SymptomID: "S8",
   192  						Weight:    0,
   193  					},
   194  					{
   195  						SymptomID: "S9",
   196  						Weight:    0,
   197  					},
   198  					{
   199  						SymptomID: "S10",
   200  						Weight:    0,
   201  					},
   202  					{
   203  						SymptomID: "S11",
   204  						Weight:    0,
   205  					},
   206  					{
   207  						SymptomID: "S12",
   208  						Weight:    0.5,
   209  					},
   210  					{
   211  						SymptomID: "S13",
   212  						Weight:    0.2,
   213  					},
   214  				},
   215  			},
   216  			expectedCertaintyFactor: 0.47902158346120005,
   217  			expectedVerdict:         false,
   218  		},
   219  		{
   220  			name:                    "test_invalid_input",
   221  			input:                   Input{},
   222  			expectedCertaintyFactor: 0.0,
   223  			expectedVerdict:         false,
   224  		},
   225  	}
   226  
   227  	for _, tt := range tests {
   228  		t.Run(tt.name, func(t *testing.T) {
   229  			output := Infer(&tt.input)
   230  
   231  			if tt.expectedCertaintyFactor != output.Probability {
   232  				t.Errorf("Expected and actual certainty factor values are different! Expected: %v. Got: %v", tt.expectedCertaintyFactor, output.Probability)
   233  			}
   234  
   235  			if tt.expectedVerdict != output.Verdict {
   236  				t.Errorf("Expected and actual verdict values are different! Expected: %v. Got: %v", tt.expectedVerdict, output.Verdict)
   237  			}
   238  
   239  		})
   240  	}
   241  }
   242  
   243  func TestNewInput(t *testing.T) {
   244  	tests := []struct {
   245  		name           string
   246  		input          *Input
   247  		expectedOutput *Input
   248  	}{
   249  		{
   250  			name: "test_valid_input",
   251  			input: &Input{
   252  				DiseaseID: "D01",
   253  				Locale:    "id",
   254  				Symptoms: []SymptomAndWeight{
   255  					{
   256  						SymptomID: "S01",
   257  						Weight:    0.25,
   258  					},
   259  				},
   260  			},
   261  			expectedOutput: &Input{
   262  				DiseaseID: "D01",
   263  				Locale:    "id",
   264  				Symptoms: []SymptomAndWeight{
   265  					{
   266  						SymptomID: "S01",
   267  						Weight:    0.25,
   268  					},
   269  				},
   270  			},
   271  		},
   272  		{
   273  			name:  "test_invalid_input",
   274  			input: &Input{},
   275  			expectedOutput: &Input{
   276  				DiseaseID: "D01",
   277  				Locale:    "en",
   278  				Symptoms:  []SymptomAndWeight{},
   279  			},
   280  		},
   281  	}
   282  
   283  	for _, tt := range tests {
   284  		t.Run(tt.name, func(t *testing.T) {
   285  			output := NewInput(tt.input)
   286  
   287  			if !reflect.DeepEqual(&tt.expectedOutput, &output) {
   288  				t.Errorf("Expected and actual structs are not equal! Expected: %v. Got: %v", tt.expectedOutput, output)
   289  			}
   290  		})
   291  	}
   292  }
   293  
   294  func TestGetDiseaseByID(t *testing.T) {
   295  	diseases := getDiseases("en")
   296  
   297  	tests := []struct {
   298  		name           string
   299  		diseaseID      string
   300  		expectedOutput *Disease
   301  	}{
   302  		{
   303  			name:           "test_valid_disease_id",
   304  			diseaseID:      "D01",
   305  			expectedOutput: &diseases[0],
   306  		},
   307  		{
   308  			name:           "test_invalid_disease_id",
   309  			diseaseID:      "404",
   310  			expectedOutput: nil,
   311  		},
   312  	}
   313  
   314  	for _, tt := range tests {
   315  		t.Run(tt.name, func(t *testing.T) {
   316  			output := GetDiseaseByID(tt.diseaseID, diseases)
   317  
   318  			if !reflect.DeepEqual(&tt.expectedOutput, &output) {
   319  				t.Errorf("Expected and actual structs are not equal! Expected: %v. Got: %v", tt.expectedOutput, output)
   320  			}
   321  		})
   322  	}
   323  }
   324  
   325  func TestForwardChaining(t *testing.T) {
   326  	disease := getDiseases("en")[0]
   327  
   328  	tests := []struct {
   329  		name           string
   330  		input          Input
   331  		expectedOutput bool
   332  	}{
   333  		{
   334  			name: "test_forward_chaining_false",
   335  			input: *NewInput(&Input{
   336  				DiseaseID: "D01",
   337  				Locale:    "en",
   338  				Symptoms: []SymptomAndWeight{
   339  					{
   340  						SymptomID: "S1",
   341  						Weight:    0.25,
   342  					},
   343  					{
   344  						SymptomID: "S2",
   345  						Weight:    0,
   346  					},
   347  					{
   348  						SymptomID: "S3",
   349  						Weight:    0.25,
   350  					},
   351  					{
   352  						SymptomID: "S4",
   353  						Weight:    0,
   354  					},
   355  					{
   356  						SymptomID: "S5",
   357  						Weight:    0,
   358  					},
   359  					{
   360  						SymptomID: "S6",
   361  						Weight:    0,
   362  					},
   363  					{
   364  						SymptomID: "S7",
   365  						Weight:    0,
   366  					},
   367  					{
   368  						SymptomID: "S8",
   369  						Weight:    0,
   370  					},
   371  					{
   372  						SymptomID: "S9",
   373  						Weight:    0,
   374  					},
   375  					{
   376  						SymptomID: "S10",
   377  						Weight:    0,
   378  					},
   379  					{
   380  						SymptomID: "S11",
   381  						Weight:    0,
   382  					},
   383  					{
   384  						SymptomID: "S12",
   385  						Weight:    0.5,
   386  					},
   387  					{
   388  						SymptomID: "S13",
   389  						Weight:    0.2,
   390  					},
   391  				},
   392  			}),
   393  			expectedOutput: false,
   394  		},
   395  		{
   396  			name: "test_forward_chaining_true",
   397  			input: *NewInput(&Input{
   398  				DiseaseID: "D01",
   399  				Locale:    "en",
   400  				Symptoms: []SymptomAndWeight{
   401  					{
   402  						SymptomID: "S1",
   403  						Weight:    0.25,
   404  					},
   405  					{
   406  						SymptomID: "S2",
   407  						Weight:    0.25,
   408  					},
   409  					{
   410  						SymptomID: "S3",
   411  						Weight:    0.25,
   412  					},
   413  					{
   414  						SymptomID: "S4",
   415  						Weight:    0.25,
   416  					},
   417  					{
   418  						SymptomID: "S5",
   419  						Weight:    0.25,
   420  					},
   421  					{
   422  						SymptomID: "S6",
   423  						Weight:    0.25,
   424  					},
   425  					{
   426  						SymptomID: "S7",
   427  						Weight:    0.25,
   428  					},
   429  					{
   430  						SymptomID: "S8",
   431  						Weight:    0,
   432  					},
   433  					{
   434  						SymptomID: "S9",
   435  						Weight:    0,
   436  					},
   437  					{
   438  						SymptomID: "S10",
   439  						Weight:    0,
   440  					},
   441  					{
   442  						SymptomID: "S11",
   443  						Weight:    0,
   444  					},
   445  					{
   446  						SymptomID: "S12",
   447  						Weight:    0.5,
   448  					},
   449  					{
   450  						SymptomID: "S13",
   451  						Weight:    0.2,
   452  					},
   453  				},
   454  			}),
   455  			expectedOutput: true,
   456  		},
   457  	}
   458  
   459  	for _, tt := range tests {
   460  		t.Run(tt.name, func(t *testing.T) {
   461  			output := ForwardChaining(&tt.input, &disease)
   462  
   463  			if tt.expectedOutput != output {
   464  				t.Errorf("Expected and actual verdict values are different! Expected: %v. Got: %v", tt.expectedOutput, output)
   465  			}
   466  		})
   467  	}
   468  }
   469  
   470  func TestCertaintyFactor(t *testing.T) {
   471  	symptoms := getDiseases("en")[0].Symptoms
   472  
   473  	tests := []struct {
   474  		name           string
   475  		input          Input
   476  		expectedOutput float64
   477  	}{
   478  		{
   479  			name: "test_valid_certainty_factor",
   480  			input: Input{
   481  				DiseaseID: "D01",
   482  				Locale:    "en",
   483  				Symptoms: []SymptomAndWeight{
   484  					{
   485  						SymptomID: "S1",
   486  						Weight:    0.25,
   487  					},
   488  					{
   489  						SymptomID: "S2",
   490  						Weight:    0.25,
   491  					},
   492  					{
   493  						SymptomID: "S3",
   494  						Weight:    0.25,
   495  					},
   496  					{
   497  						SymptomID: "S4",
   498  						Weight:    0.25,
   499  					},
   500  					{
   501  						SymptomID: "S5",
   502  						Weight:    0.25,
   503  					},
   504  					{
   505  						SymptomID: "S6",
   506  						Weight:    0.25,
   507  					},
   508  					{
   509  						SymptomID: "S7",
   510  						Weight:    0.25,
   511  					},
   512  					{
   513  						SymptomID: "S8",
   514  						Weight:    0,
   515  					},
   516  					{
   517  						SymptomID: "S9",
   518  						Weight:    0,
   519  					},
   520  					{
   521  						SymptomID: "S10",
   522  						Weight:    0,
   523  					},
   524  					{
   525  						SymptomID: "S11",
   526  						Weight:    0,
   527  					},
   528  					{
   529  						SymptomID: "S12",
   530  						Weight:    0.5,
   531  					},
   532  					{
   533  						SymptomID: "S13",
   534  						Weight:    0.2,
   535  					},
   536  				},
   537  			},
   538  			expectedOutput: 0.7313435264022431,
   539  		},
   540  		{
   541  			name:           "test_invalid_certainty_factor",
   542  			input:          Input{},
   543  			expectedOutput: 0.0,
   544  		},
   545  	}
   546  
   547  	for _, tt := range tests {
   548  		t.Run(tt.name, func(t *testing.T) {
   549  			output := CertaintyFactor(&tt.input, symptoms)
   550  
   551  			if tt.expectedOutput != output {
   552  				t.Errorf("Expected and actual certainty factor values are different! Expected: %v. Got: %v", tt.expectedOutput, output)
   553  			}
   554  		})
   555  	}
   556  }