|
5 | 5 | use App\Core\Application\Path\AppPathResolver;
|
6 | 6 | use App\Core\Infrastructure\Bus\CommandBusInterface;
|
7 | 7 | use App\Module\ML\Application\Interaction\Command\SaveSpamDataset\SaveSpamDatasetCommand;
|
| 8 | +use App\Module\ML\Application\Utils\WordsUtils; |
8 | 9 | use App\Module\ML\Domain\Constant;
|
9 |
| -use Rubix\ML\Classifiers\ClassificationTree; |
10 |
| -use Rubix\ML\Classifiers\RandomForest; |
11 | 10 | use Rubix\ML\Datasets\Labeled;
|
12 | 11 | use Rubix\ML\Extractors\CSV;
|
13 | 12 | use Rubix\ML\PersistentModel;
|
14 | 13 | use Rubix\ML\Persisters\Filesystem;
|
15 |
| -use Rubix\ML\Pipeline; |
16 |
| -use Rubix\ML\Tokenizers\WordStemmer; |
17 |
| -use Rubix\ML\Transformers\MultibyteTextNormalizer; |
18 |
| -use Rubix\ML\Transformers\StopWordFilter; |
19 |
| -use Rubix\ML\Transformers\TfIdfTransformer; |
20 |
| -use Rubix\ML\Transformers\WordCountVectorizer; |
21 |
| -use Rubix\ML\Transformers\ZScaleStandardizer; |
22 | 14 | use Symfony\Component\Console\Style\SymfonyStyle;
|
23 | 15 |
|
24 |
| -use function Symfony\Component\String\u; |
25 |
| - |
26 |
| -/** |
27 |
| - * @see https://docs.rubixml.com/latest/classifiers/random-forest.html |
28 |
| - * @see https://docs.rubixml.com/latest/tokenizers/word-stemmer.html |
29 |
| - * @see https://docs.rubixml.com/latest/transformers/word-count-vectorizer.html |
30 |
| - * @see https://docs.rubixml.com/latest/transformers/stop-word-filter.html |
31 |
| - * @see https://docs.rubixml.com/latest/transformers/tf-idf-transformer.html |
32 |
| - * @see https://docs.rubixml.com/latest/transformers/z-scale-standardizer.html |
33 |
| - */ |
34 | 16 | class SpamModelTrainer
|
35 | 17 | {
|
36 | 18 | private static ?SymfonyStyle $io = null;
|
@@ -72,24 +54,19 @@ public function train(
|
72 | 54 | self::$io?->info(sprintf('The training dataset contains `%d` samples.', $training->numSamples()));
|
73 | 55 |
|
74 | 56 | $modelPath = $this->appPathResolver->getModelPath($outputModelFilename);
|
| 57 | + $uniqueWordsNum = WordsUtils::countUniqueWords($dataset->samples(), $minWordsCount); |
| 58 | + |
75 | 59 | $estimator = new PersistentModel(
|
76 |
| - new Pipeline([ |
77 |
| - new MultibyteTextNormalizer(), |
78 |
| - new StopWordFilter(Constant::STOP_WORDS), |
79 |
| - new WordCountVectorizer( |
80 |
| - maxVocabularySize: $this->countUniqueWords($dataset->samples(), $minWordsCount), |
81 |
| - minDocumentCount: $minDocumentCount, |
82 |
| - maxDocumentRatio: $maxDocumentRatio, |
83 |
| - tokenizer: new WordStemmer($language), |
84 |
| - ), |
85 |
| - new TfIdfTransformer(), |
86 |
| - new ZScaleStandardizer(), |
87 |
| - ], new RandomForest( |
88 |
| - new ClassificationTree($treeMaxHeight), |
89 |
| - estimators: $treeEstimators, |
90 |
| - ratio: $treeRatio, |
91 |
| - balanced: $treeBalanced, |
92 |
| - )), |
| 60 | + LearnerFactory::createLearner( |
| 61 | + uniqueWordsNum: $uniqueWordsNum, |
| 62 | + minDocumentCount: $minDocumentCount, |
| 63 | + maxDocumentRatio: $maxDocumentRatio, |
| 64 | + language: $language, |
| 65 | + treeMaxHeight: $treeMaxHeight, |
| 66 | + treeEstimators: $treeEstimators, |
| 67 | + treeRatio: $treeRatio, |
| 68 | + treeBalanced: $treeBalanced, |
| 69 | + ), |
93 | 70 | new Filesystem($modelPath, $history)
|
94 | 71 | );
|
95 | 72 |
|
@@ -124,31 +101,4 @@ private function saveDataset(
|
124 | 101 |
|
125 | 102 | self::$io?->info(sprintf('Saved dataset `%s`.', $outputDatasetFilename));
|
126 | 103 | }
|
127 |
| - |
128 |
| - /** |
129 |
| - * @param array<string[]> $samples |
130 |
| - */ |
131 |
| - private function countUniqueWords(array $samples, int $minCount): int |
132 |
| - { |
133 |
| - $words = []; |
134 |
| - |
135 |
| - foreach ($samples as $sample) { |
136 |
| - $items = array_filter( |
137 |
| - preg_split('/\s/', $sample[0]), |
138 |
| - fn (string $word) => !empty($word) |
139 |
| - ); |
140 |
| - |
141 |
| - foreach ($items as $item) { |
142 |
| - $word = u($item)->snake()->toString(); |
143 |
| - |
144 |
| - if (!isset($words[$word])) { |
145 |
| - $words[$word] = 1; |
146 |
| - } else { |
147 |
| - ++$words[$word]; |
148 |
| - } |
149 |
| - } |
150 |
| - } |
151 |
| - |
152 |
| - return count(array_filter($words, fn (int $count) => $count >= $minCount)); |
153 |
| - } |
154 | 104 | }
|
0 commit comments