Skip to content

Commit 8dd1e59

Browse files
committed
Added command to generate a report about the model
1 parent 215a741 commit 8dd1e59

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
<?php
2+
3+
namespace App\Module\ML\Application\Model;
4+
5+
use App\Core\Application\Path\AppPathResolver;
6+
use Rubix\ML\CrossValidation\Reports\AggregateReport;
7+
use Rubix\ML\CrossValidation\Reports\ConfusionMatrix;
8+
use Rubix\ML\CrossValidation\Reports\MulticlassBreakdown;
9+
use Rubix\ML\Datasets\Labeled;
10+
use Rubix\ML\Extractors\CSV;
11+
use Rubix\ML\PersistentModel;
12+
use Rubix\ML\Persisters\Filesystem;
13+
use Rubix\ML\Report;
14+
15+
/**
16+
* @see https://docs.rubixml.com/latest/cross-validation.html
17+
* @see https://docs.rubixml.com/latest/cross-validation/reports/multiclass-breakdown.html
18+
* @see https://docs.rubixml.com/latest/cross-validation/reports/confusion-matrix.html
19+
*/
20+
readonly class SpamModelReport
21+
{
22+
public function __construct(
23+
private AppPathResolver $appPathResolver,
24+
) {
25+
}
26+
27+
public function generateReport(
28+
string $testingDatasetFilename,
29+
string $modelFilename,
30+
): Report {
31+
$dataset = Labeled::fromIterator(new CSV(
32+
$this->appPathResolver->getDatasetPath($testingDatasetFilename),
33+
header: true,
34+
));
35+
36+
$estimator = PersistentModel::load(new Filesystem(
37+
$this->appPathResolver->getModelPath($modelFilename)
38+
));
39+
40+
$predictions = $estimator->predict($dataset);
41+
42+
$report = new AggregateReport([
43+
new MulticlassBreakdown(),
44+
new ConfusionMatrix(),
45+
]);
46+
47+
return $report->generate($predictions, $dataset->labels());
48+
}
49+
}
+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<?php
2+
3+
namespace App\Module\ML\UI\CLI;
4+
5+
use App\Module\ML\Application\Model\SpamModelReport;
6+
use App\Module\ML\Domain\Constant;
7+
use Symfony\Component\Console\Attribute\AsCommand;
8+
use Symfony\Component\Console\Command\Command;
9+
use Symfony\Component\Console\Input\InputInterface;
10+
use Symfony\Component\Console\Output\OutputInterface;
11+
use Symfony\Component\Console\Style\SymfonyStyle;
12+
13+
#[AsCommand(
14+
name: 'app:ml:report',
15+
description: 'Command to generate a report about the model.',
16+
)]
17+
class MlReportCommand extends Command
18+
{
19+
public function __construct(
20+
private readonly SpamModelReport $spamModelReport,
21+
) {
22+
parent::__construct();
23+
}
24+
25+
protected function execute(InputInterface $input, OutputInterface $output): int
26+
{
27+
$io = new SymfonyStyle($input, $output);
28+
29+
$io->info('Generating report...');
30+
31+
$io->writeln($this->spamModelReport->generateReport(
32+
Constant::DEFAULT_SPAM_TESTING_DATASET_FILENAME,
33+
Constant::SPAM_MODEL_FILENAME,
34+
)->toJSON());
35+
36+
$io->success('The report has been generated!');
37+
38+
return Command::SUCCESS;
39+
}
40+
}

0 commit comments

Comments
 (0)