Skip to content

Commit 1eae096

Browse files
committed
Added command to predictions
1 parent 1efab86 commit 1eae096

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ composer.phar
44
/resources/training.csv
55
/resources/test.csv
66
/resources/model.rbx
7+
/resources/predictions.csv
78

89
###> symfony/framework-bundle ###
910
/.env.local

src/Domain/FileNames.php

+1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ final class FileNames
99
public const string TRAINING_SET = 'training.csv';
1010
public const string TEST_SET = 'test.csv';
1111
public const string MODEL_FILENAME = 'model.rbx';
12+
public const string PREDICTIONS_FILENAME = 'predictions.csv';
1213
}

src/UI/CLI/PredictCommand.php

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
<?php
2+
3+
namespace App\UI\CLI;
4+
5+
use App\Application\Path\AppPathResolver;
6+
use App\Domain\FileNames;
7+
use League\Csv\CannotInsertRecord;
8+
use League\Csv\Exception;
9+
use League\Csv\UnavailableStream;
10+
use League\Csv\Writer;
11+
use Rubix\ML\Datasets\Labeled;
12+
use Rubix\ML\Extractors\CSV;
13+
use Rubix\ML\PersistentModel;
14+
use Rubix\ML\Persisters\Filesystem;
15+
use Symfony\Component\Console\Attribute\AsCommand;
16+
use Symfony\Component\Console\Command\Command;
17+
use Symfony\Component\Console\Input\InputInterface;
18+
use Symfony\Component\Console\Output\OutputInterface;
19+
use Symfony\Component\Console\Style\SymfonyStyle;
20+
21+
#[AsCommand(
22+
name: 'app:predict',
23+
)]
24+
class PredictCommand extends Command
25+
{
26+
public function __construct(
27+
private readonly AppPathResolver $appPathResolver,
28+
) {
29+
parent::__construct();
30+
}
31+
32+
/**
33+
* @throws UnavailableStream
34+
* @throws CannotInsertRecord
35+
* @throws Exception
36+
*/
37+
protected function execute(InputInterface $input, OutputInterface $output): int
38+
{
39+
$io = new SymfonyStyle($input, $output);
40+
41+
$dataset = Labeled::fromIterator(new CSV(
42+
$this->appPathResolver->getResourcesPath(FileNames::TEST_SET),
43+
header: false,
44+
))->transformLabels(fn ($value) => (float) $value);
45+
46+
$model = PersistentModel::load(new Filesystem(
47+
$this->appPathResolver->getResourcesPath(FileNames::MODEL_FILENAME),
48+
));
49+
50+
$writer = Writer::createFromPath($this->appPathResolver->getResourcesPath(
51+
FileNames::PREDICTIONS_FILENAME,
52+
), 'w');
53+
54+
$predictions = $model->predict($dataset);
55+
56+
foreach ($predictions as $prediction) {
57+
$writer->insertOne([$prediction]);
58+
}
59+
60+
$io->success('Results saved to file');
61+
62+
return Command::SUCCESS;
63+
}
64+
}

0 commit comments

Comments
 (0)