github.com/lauslim12/expert-systems@v0.0.0-20221115131159-018513aad29c/pkg/inference/inference_test.go (about) 1 package inference 2 3 import ( 4 "reflect" 5 "testing" 6 ) 7 8 func TestInfer(t *testing.T) { 9 tests := []struct { 10 name string 11 input Input 12 expectedCertaintyFactor float64 13 expectedVerdict bool 14 }{ 15 { 16 name: "test_basic_input", 17 input: Input{ 18 DiseaseID: "D01", 19 Locale: "en", 20 Symptoms: []SymptomAndWeight{ 21 { 22 SymptomID: "S1", 23 Weight: 0.5, 24 }, 25 }, 26 }, 27 expectedCertaintyFactor: 0.2565, 28 expectedVerdict: false, 29 }, 30 { 31 name: "test_advanced_input_locale_en", 32 input: Input{ 33 DiseaseID: "D01", 34 Locale: "en", 35 Symptoms: []SymptomAndWeight{ 36 { 37 SymptomID: "S1", 38 Weight: 0.5, 39 }, 40 { 41 SymptomID: "S2", 42 Weight: 0.4, 43 }, 44 { 45 SymptomID: "S3", 46 Weight: 0.2, 47 }, 48 { 49 SymptomID: "S4", 50 Weight: 0.6, 51 }, 52 { 53 SymptomID: "S5", 54 Weight: 0.2, 55 }, 56 { 57 SymptomID: "S6", 58 Weight: 0.4, 59 }, 60 { 61 SymptomID: "S7", 62 Weight: 0.8, 63 }, 64 { 65 SymptomID: "S8", 66 Weight: 0.2, 67 }, 68 { 69 SymptomID: "S9", 70 Weight: 0.2, 71 }, 72 { 73 SymptomID: "S10", 74 Weight: 0.4, 75 }, 76 { 77 SymptomID: "S11", 78 Weight: 0.2, 79 }, 80 { 81 SymptomID: "S12", 82 Weight: 0.2, 83 }, 84 { 85 SymptomID: "S13", 86 Weight: 0.8, 87 }, 88 }, 89 }, 90 expectedCertaintyFactor: 0.9471713614230385, 91 expectedVerdict: true, 92 }, 93 { 94 name: "test_advanced_input_locale_en", 95 input: Input{ 96 DiseaseID: "D01", 97 Locale: "en", 98 Symptoms: []SymptomAndWeight{ 99 { 100 SymptomID: "S1", 101 Weight: 0.5, 102 }, 103 { 104 SymptomID: "S2", 105 Weight: 0.4, 106 }, 107 { 108 SymptomID: "S3", 109 Weight: 0.2, 110 }, 111 { 112 SymptomID: "S4", 113 Weight: 0.6, 114 }, 115 { 116 SymptomID: "S5", 117 Weight: 0.2, 118 }, 119 { 120 SymptomID: "S6", 121 Weight: 0.4, 122 }, 123 { 124 SymptomID: "S7", 125 Weight: 0.8, 126 }, 127 { 128 SymptomID: "S8", 129 Weight: 0.2, 130 }, 131 { 132 SymptomID: "S9", 133 Weight: 0.2, 134 }, 135 { 136 SymptomID: "S10", 137 Weight: 0.4, 138 }, 139 { 140 SymptomID: "S11", 141 Weight: 0.2, 142 }, 143 { 144 SymptomID: "S12", 145 Weight: 0.2, 146 }, 147 { 148 SymptomID: "S13", 149 Weight: 0.8, 150 }, 151 }, 152 }, 153 expectedCertaintyFactor: 0.9471713614230385, 154 expectedVerdict: true, 155 }, 156 { 157 name: "test_advanced_input_locale_en", 158 input: Input{ 159 DiseaseID: "D01", 160 Locale: "id", 161 Symptoms: []SymptomAndWeight{ 162 { 163 SymptomID: "S1", 164 Weight: 0.25, 165 }, 166 { 167 SymptomID: "S2", 168 Weight: 0, 169 }, 170 { 171 SymptomID: "S3", 172 Weight: 0.25, 173 }, 174 { 175 SymptomID: "S4", 176 Weight: 0, 177 }, 178 { 179 SymptomID: "S5", 180 Weight: 0, 181 }, 182 { 183 SymptomID: "S6", 184 Weight: 0, 185 }, 186 { 187 SymptomID: "S7", 188 Weight: 0, 189 }, 190 { 191 SymptomID: "S8", 192 Weight: 0, 193 }, 194 { 195 SymptomID: "S9", 196 Weight: 0, 197 }, 198 { 199 SymptomID: "S10", 200 Weight: 0, 201 }, 202 { 203 SymptomID: "S11", 204 Weight: 0, 205 }, 206 { 207 SymptomID: "S12", 208 Weight: 0.5, 209 }, 210 { 211 SymptomID: "S13", 212 Weight: 0.2, 213 }, 214 }, 215 }, 216 expectedCertaintyFactor: 0.47902158346120005, 217 expectedVerdict: false, 218 }, 219 { 220 name: "test_invalid_input", 221 input: Input{}, 222 expectedCertaintyFactor: 0.0, 223 expectedVerdict: false, 224 }, 225 } 226 227 for _, tt := range tests { 228 t.Run(tt.name, func(t *testing.T) { 229 output := Infer(&tt.input) 230 231 if tt.expectedCertaintyFactor != output.Probability { 232 t.Errorf("Expected and actual certainty factor values are different! Expected: %v. Got: %v", tt.expectedCertaintyFactor, output.Probability) 233 } 234 235 if tt.expectedVerdict != output.Verdict { 236 t.Errorf("Expected and actual verdict values are different! Expected: %v. Got: %v", tt.expectedVerdict, output.Verdict) 237 } 238 239 }) 240 } 241 } 242 243 func TestNewInput(t *testing.T) { 244 tests := []struct { 245 name string 246 input *Input 247 expectedOutput *Input 248 }{ 249 { 250 name: "test_valid_input", 251 input: &Input{ 252 DiseaseID: "D01", 253 Locale: "id", 254 Symptoms: []SymptomAndWeight{ 255 { 256 SymptomID: "S01", 257 Weight: 0.25, 258 }, 259 }, 260 }, 261 expectedOutput: &Input{ 262 DiseaseID: "D01", 263 Locale: "id", 264 Symptoms: []SymptomAndWeight{ 265 { 266 SymptomID: "S01", 267 Weight: 0.25, 268 }, 269 }, 270 }, 271 }, 272 { 273 name: "test_invalid_input", 274 input: &Input{}, 275 expectedOutput: &Input{ 276 DiseaseID: "D01", 277 Locale: "en", 278 Symptoms: []SymptomAndWeight{}, 279 }, 280 }, 281 } 282 283 for _, tt := range tests { 284 t.Run(tt.name, func(t *testing.T) { 285 output := NewInput(tt.input) 286 287 if !reflect.DeepEqual(&tt.expectedOutput, &output) { 288 t.Errorf("Expected and actual structs are not equal! Expected: %v. Got: %v", tt.expectedOutput, output) 289 } 290 }) 291 } 292 } 293 294 func TestGetDiseaseByID(t *testing.T) { 295 diseases := getDiseases("en") 296 297 tests := []struct { 298 name string 299 diseaseID string 300 expectedOutput *Disease 301 }{ 302 { 303 name: "test_valid_disease_id", 304 diseaseID: "D01", 305 expectedOutput: &diseases[0], 306 }, 307 { 308 name: "test_invalid_disease_id", 309 diseaseID: "404", 310 expectedOutput: nil, 311 }, 312 } 313 314 for _, tt := range tests { 315 t.Run(tt.name, func(t *testing.T) { 316 output := GetDiseaseByID(tt.diseaseID, diseases) 317 318 if !reflect.DeepEqual(&tt.expectedOutput, &output) { 319 t.Errorf("Expected and actual structs are not equal! Expected: %v. Got: %v", tt.expectedOutput, output) 320 } 321 }) 322 } 323 } 324 325 func TestForwardChaining(t *testing.T) { 326 disease := getDiseases("en")[0] 327 328 tests := []struct { 329 name string 330 input Input 331 expectedOutput bool 332 }{ 333 { 334 name: "test_forward_chaining_false", 335 input: *NewInput(&Input{ 336 DiseaseID: "D01", 337 Locale: "en", 338 Symptoms: []SymptomAndWeight{ 339 { 340 SymptomID: "S1", 341 Weight: 0.25, 342 }, 343 { 344 SymptomID: "S2", 345 Weight: 0, 346 }, 347 { 348 SymptomID: "S3", 349 Weight: 0.25, 350 }, 351 { 352 SymptomID: "S4", 353 Weight: 0, 354 }, 355 { 356 SymptomID: "S5", 357 Weight: 0, 358 }, 359 { 360 SymptomID: "S6", 361 Weight: 0, 362 }, 363 { 364 SymptomID: "S7", 365 Weight: 0, 366 }, 367 { 368 SymptomID: "S8", 369 Weight: 0, 370 }, 371 { 372 SymptomID: "S9", 373 Weight: 0, 374 }, 375 { 376 SymptomID: "S10", 377 Weight: 0, 378 }, 379 { 380 SymptomID: "S11", 381 Weight: 0, 382 }, 383 { 384 SymptomID: "S12", 385 Weight: 0.5, 386 }, 387 { 388 SymptomID: "S13", 389 Weight: 0.2, 390 }, 391 }, 392 }), 393 expectedOutput: false, 394 }, 395 { 396 name: "test_forward_chaining_true", 397 input: *NewInput(&Input{ 398 DiseaseID: "D01", 399 Locale: "en", 400 Symptoms: []SymptomAndWeight{ 401 { 402 SymptomID: "S1", 403 Weight: 0.25, 404 }, 405 { 406 SymptomID: "S2", 407 Weight: 0.25, 408 }, 409 { 410 SymptomID: "S3", 411 Weight: 0.25, 412 }, 413 { 414 SymptomID: "S4", 415 Weight: 0.25, 416 }, 417 { 418 SymptomID: "S5", 419 Weight: 0.25, 420 }, 421 { 422 SymptomID: "S6", 423 Weight: 0.25, 424 }, 425 { 426 SymptomID: "S7", 427 Weight: 0.25, 428 }, 429 { 430 SymptomID: "S8", 431 Weight: 0, 432 }, 433 { 434 SymptomID: "S9", 435 Weight: 0, 436 }, 437 { 438 SymptomID: "S10", 439 Weight: 0, 440 }, 441 { 442 SymptomID: "S11", 443 Weight: 0, 444 }, 445 { 446 SymptomID: "S12", 447 Weight: 0.5, 448 }, 449 { 450 SymptomID: "S13", 451 Weight: 0.2, 452 }, 453 }, 454 }), 455 expectedOutput: true, 456 }, 457 } 458 459 for _, tt := range tests { 460 t.Run(tt.name, func(t *testing.T) { 461 output := ForwardChaining(&tt.input, &disease) 462 463 if tt.expectedOutput != output { 464 t.Errorf("Expected and actual verdict values are different! Expected: %v. Got: %v", tt.expectedOutput, output) 465 } 466 }) 467 } 468 } 469 470 func TestCertaintyFactor(t *testing.T) { 471 symptoms := getDiseases("en")[0].Symptoms 472 473 tests := []struct { 474 name string 475 input Input 476 expectedOutput float64 477 }{ 478 { 479 name: "test_valid_certainty_factor", 480 input: Input{ 481 DiseaseID: "D01", 482 Locale: "en", 483 Symptoms: []SymptomAndWeight{ 484 { 485 SymptomID: "S1", 486 Weight: 0.25, 487 }, 488 { 489 SymptomID: "S2", 490 Weight: 0.25, 491 }, 492 { 493 SymptomID: "S3", 494 Weight: 0.25, 495 }, 496 { 497 SymptomID: "S4", 498 Weight: 0.25, 499 }, 500 { 501 SymptomID: "S5", 502 Weight: 0.25, 503 }, 504 { 505 SymptomID: "S6", 506 Weight: 0.25, 507 }, 508 { 509 SymptomID: "S7", 510 Weight: 0.25, 511 }, 512 { 513 SymptomID: "S8", 514 Weight: 0, 515 }, 516 { 517 SymptomID: "S9", 518 Weight: 0, 519 }, 520 { 521 SymptomID: "S10", 522 Weight: 0, 523 }, 524 { 525 SymptomID: "S11", 526 Weight: 0, 527 }, 528 { 529 SymptomID: "S12", 530 Weight: 0.5, 531 }, 532 { 533 SymptomID: "S13", 534 Weight: 0.2, 535 }, 536 }, 537 }, 538 expectedOutput: 0.7313435264022431, 539 }, 540 { 541 name: "test_invalid_certainty_factor", 542 input: Input{}, 543 expectedOutput: 0.0, 544 }, 545 } 546 547 for _, tt := range tests { 548 t.Run(tt.name, func(t *testing.T) { 549 output := CertaintyFactor(&tt.input, symptoms) 550 551 if tt.expectedOutput != output { 552 t.Errorf("Expected and actual certainty factor values are different! Expected: %v. Got: %v", tt.expectedOutput, output) 553 } 554 }) 555 } 556 }