Skip to content

Commit 215a741

Browse files
committed
Moved some methods from SpamModelTrainer to separate classes
1 parent 51c7e0b commit 215a741

File tree

3 files changed

+103
-63
lines changed

3 files changed

+103
-63
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
<?php
2+
3+
namespace App\Module\ML\Application\Model;
4+
5+
use App\Module\ML\Domain\Constant;
6+
use Rubix\ML\Classifiers\ClassificationTree;
7+
use Rubix\ML\Classifiers\RandomForest;
8+
use Rubix\ML\Learner;
9+
use Rubix\ML\Pipeline;
10+
use Rubix\ML\Tokenizers\WordStemmer;
11+
use Rubix\ML\Transformers\MultibyteTextNormalizer;
12+
use Rubix\ML\Transformers\StopWordFilter;
13+
use Rubix\ML\Transformers\TfIdfTransformer;
14+
use Rubix\ML\Transformers\WordCountVectorizer;
15+
use Rubix\ML\Transformers\ZScaleStandardizer;
16+
17+
/**
 * Factory for the Rubix ML learner: a text-preprocessing Pipeline that wraps
 * a RandomForest built on a ClassificationTree.
 *
 * @see https://docs.rubixml.com/latest/classifiers/random-forest.html
 * @see https://docs.rubixml.com/latest/tokenizers/word-stemmer.html
 * @see https://docs.rubixml.com/latest/transformers/word-count-vectorizer.html
 * @see https://docs.rubixml.com/latest/transformers/stop-word-filter.html
 * @see https://docs.rubixml.com/latest/transformers/tf-idf-transformer.html
 * @see https://docs.rubixml.com/latest/transformers/z-scale-standardizer.html
 */
class LearnerFactory
{
    /**
     * Assembles the transformer chain and the underlying estimator, then
     * returns them wrapped in a single Pipeline.
     *
     * @param int    $uniqueWordsNum   forwarded to WordCountVectorizer as `maxVocabularySize`
     * @param int    $minDocumentCount forwarded to WordCountVectorizer as `minDocumentCount`
     * @param float  $maxDocumentRatio forwarded to WordCountVectorizer as `maxDocumentRatio`
     * @param string $language         forwarded to the WordStemmer tokenizer
     * @param int    $treeMaxHeight    forwarded to ClassificationTree as its first argument
     * @param int    $treeEstimators   forwarded to RandomForest as `estimators`
     * @param float  $treeRatio        forwarded to RandomForest as `ratio`
     * @param bool   $treeBalanced     forwarded to RandomForest as `balanced`
     *
     * @return Learner the Pipeline wrapping the configured RandomForest
     */
    public static function createLearner(
        int $uniqueWordsNum,
        int $minDocumentCount = Constant::DEFAULT_MIN_DOCUMENT_COUNT,
        float $maxDocumentRatio = Constant::DEFAULT_MAX_DOCUMENT_RATIO,
        string $language = Constant::DEFAULT_LANGUAGE,
        int $treeMaxHeight = PHP_INT_MAX,
        int $treeEstimators = Constant::DEFAULT_TREE_ESTIMATORS,
        float $treeRatio = Constant::DEFAULT_TREE_RATIO,
        bool $treeBalanced = Constant::DEFAULT_TREE_BALANCED,
    ): Learner {
        // Objects are created in the same order in which the pipeline applies
        // them, so any constructor-time validation fires in a predictable order.
        $textNormalizer = new MultibyteTextNormalizer();
        $stopWordFilter = new StopWordFilter(Constant::STOP_WORDS);

        $stemmer = new WordStemmer($language);
        $vectorizer = new WordCountVectorizer(
            maxVocabularySize: $uniqueWordsNum,
            minDocumentCount: $minDocumentCount,
            maxDocumentRatio: $maxDocumentRatio,
            tokenizer: $stemmer,
        );

        $transformers = [
            $textNormalizer,
            $stopWordFilter,
            $vectorizer,
            new TfIdfTransformer(),
            new ZScaleStandardizer(),
        ];

        $baseTree = new ClassificationTree($treeMaxHeight);
        $forest = new RandomForest(
            $baseTree,
            estimators: $treeEstimators,
            ratio: $treeRatio,
            balanced: $treeBalanced,
        );

        return new Pipeline($transformers, $forest);
    }
}

src/Module/ML/Application/Model/SpamModelTrainer.php

+13-63
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,14 @@
55
use App\Core\Application\Path\AppPathResolver;
66
use App\Core\Infrastructure\Bus\CommandBusInterface;
77
use App\Module\ML\Application\Interaction\Command\SaveSpamDataset\SaveSpamDatasetCommand;
8+
use App\Module\ML\Application\Utils\WordsUtils;
89
use App\Module\ML\Domain\Constant;
9-
use Rubix\ML\Classifiers\ClassificationTree;
10-
use Rubix\ML\Classifiers\RandomForest;
1110
use Rubix\ML\Datasets\Labeled;
1211
use Rubix\ML\Extractors\CSV;
1312
use Rubix\ML\PersistentModel;
1413
use Rubix\ML\Persisters\Filesystem;
15-
use Rubix\ML\Pipeline;
16-
use Rubix\ML\Tokenizers\WordStemmer;
17-
use Rubix\ML\Transformers\MultibyteTextNormalizer;
18-
use Rubix\ML\Transformers\StopWordFilter;
19-
use Rubix\ML\Transformers\TfIdfTransformer;
20-
use Rubix\ML\Transformers\WordCountVectorizer;
21-
use Rubix\ML\Transformers\ZScaleStandardizer;
2214
use Symfony\Component\Console\Style\SymfonyStyle;
2315

24-
use function Symfony\Component\String\u;
25-
26-
/**
27-
* @see https://docs.rubixml.com/latest/classifiers/random-forest.html
28-
* @see https://docs.rubixml.com/latest/tokenizers/word-stemmer.html
29-
* @see https://docs.rubixml.com/latest/transformers/word-count-vectorizer.html
30-
* @see https://docs.rubixml.com/latest/transformers/stop-word-filter.html
31-
* @see https://docs.rubixml.com/latest/transformers/tf-idf-transformer.html
32-
* @see https://docs.rubixml.com/latest/transformers/z-scale-standardizer.html
33-
*/
3416
class SpamModelTrainer
3517
{
3618
private static ?SymfonyStyle $io = null;
@@ -72,24 +54,19 @@ public function train(
7254
self::$io?->info(sprintf('The training dataset contains `%d` samples.', $training->numSamples()));
7355

7456
$modelPath = $this->appPathResolver->getModelPath($outputModelFilename);
57+
$uniqueWordsNum = WordsUtils::countUniqueWords($dataset->samples(), $minWordsCount);
58+
7559
$estimator = new PersistentModel(
76-
new Pipeline([
77-
new MultibyteTextNormalizer(),
78-
new StopWordFilter(Constant::STOP_WORDS),
79-
new WordCountVectorizer(
80-
maxVocabularySize: $this->countUniqueWords($dataset->samples(), $minWordsCount),
81-
minDocumentCount: $minDocumentCount,
82-
maxDocumentRatio: $maxDocumentRatio,
83-
tokenizer: new WordStemmer($language),
84-
),
85-
new TfIdfTransformer(),
86-
new ZScaleStandardizer(),
87-
], new RandomForest(
88-
new ClassificationTree($treeMaxHeight),
89-
estimators: $treeEstimators,
90-
ratio: $treeRatio,
91-
balanced: $treeBalanced,
92-
)),
60+
LearnerFactory::createLearner(
61+
uniqueWordsNum: $uniqueWordsNum,
62+
minDocumentCount: $minDocumentCount,
63+
maxDocumentRatio: $maxDocumentRatio,
64+
language: $language,
65+
treeMaxHeight: $treeMaxHeight,
66+
treeEstimators: $treeEstimators,
67+
treeRatio: $treeRatio,
68+
treeBalanced: $treeBalanced,
69+
),
9370
new Filesystem($modelPath, $history)
9471
);
9572

@@ -124,31 +101,4 @@ private function saveDataset(
124101

125102
self::$io?->info(sprintf('Saved dataset `%s`.', $outputDatasetFilename));
126103
}
127-
128-
/**
129-
* @param array<string[]> $samples
130-
*/
131-
private function countUniqueWords(array $samples, int $minCount): int
132-
{
133-
$words = [];
134-
135-
foreach ($samples as $sample) {
136-
$items = array_filter(
137-
preg_split('/\s/', $sample[0]),
138-
fn (string $word) => !empty($word)
139-
);
140-
141-
foreach ($items as $item) {
142-
$word = u($item)->snake()->toString();
143-
144-
if (!isset($words[$word])) {
145-
$words[$word] = 1;
146-
} else {
147-
++$words[$word];
148-
}
149-
}
150-
}
151-
152-
return count(array_filter($words, fn (int $count) => $count >= $minCount));
153-
}
154104
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
<?php
2+
3+
namespace App\Module\ML\Application\Utils;
4+
5+
use function Symfony\Component\String\u;
6+
7+
class WordsUtils
{
    /**
     * Counts the distinct words that occur at least `$minCount` times across all samples.
     *
     * The first element of every sample is split on whitespace, each token is
     * normalised with `u()->snake()`, and occurrences are tallied per
     * normalised form.
     *
     * @param array<string[]> $samples  rows whose first element (`$sample[0]`) holds the text to scan
     * @param int             $minCount minimum number of occurrences for a word to be counted
     *
     * @return int number of normalised words whose tally is greater than or equal to `$minCount`
     *
     * @throws \RuntimeException when PCRE fails to split a sample's text
     */
    public static function countUniqueWords(array $samples, int $minCount): int
    {
        $words = [];

        foreach ($samples as $sample) {
            // PREG_SPLIT_NO_EMPTY removes only genuinely empty pieces. Filtering
            // with `empty()` instead would also discard the token "0", because
            // PHP considers the string "0" to be empty.
            $items = preg_split('/\s+/', $sample[0], -1, PREG_SPLIT_NO_EMPTY);

            if ($items === false) {
                // Surface the PCRE failure instead of passing `false` on as if it were a list.
                throw new \RuntimeException(
                    sprintf('Unable to split sample into words: %s.', preg_last_error_msg())
                );
            }

            foreach ($items as $item) {
                $word = u($item)->snake()->toString();
                $words[$word] = ($words[$word] ?? 0) + 1;
            }
        }

        return count(array_filter($words, static fn (int $count): bool => $count >= $minCount));
    }
}

0 commit comments

Comments
 (0)