Skip to content

Commit f17ad2c

Browse files
authored
Update pytorch-partial-tagger (#35)
* Remove unnecessary code * Update dependencies * Update incompatible code * Bump version * Update requirements.txt
1 parent 2adbeab commit f17ad2c

File tree

6 files changed

+54
-32
lines changed

6 files changed

+54
-32
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ requires-python = ">=3.8"
88

99
[tool.poetry]
1010
name = "spacy-partial-tagger"
11-
version = "0.14.0"
11+
version = "0.15.0"
1212
description = "Sequence Tagger for Partially Annotated Dataset in spaCy"
1313
authors = ["yasufumi <yasufumi.taniguchi@gmail.com>"]
1414
license = "MIT"
@@ -27,7 +27,7 @@ transformers = {extras = ["ja"], version = "^4.25.1"}
2727
torch = "^2.0.1"
2828
spacy = {extras = ["transformers"], version = "^3.3.1"}
2929
spacy-alignments = "^0.8.5"
30-
pytorch-partial-tagger = "^0.1.6"
30+
pytorch-partial-tagger = "^0.1.7"
3131

3232
[tool.poetry.group.dev.dependencies]
3333
mypy = "^1.3.0"

requirements.txt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ confection==0.0.4 ; python_version >= "3.8" and python_version < "4.0"
99
coverage[toml]==7.2.7 ; python_version >= "3.8" and python_version < "4.0"
1010
cymem==2.0.7 ; python_version >= "3.8" and python_version < "4.0"
1111
exceptiongroup==1.1.1 ; python_version >= "3.8" and python_version < "3.11"
12-
filelock==3.12.0 ; python_version >= "3.8" and python_version < "4.0"
12+
filelock==3.12.1 ; python_version >= "3.8" and python_version < "4.0"
1313
flake8==4.0.1 ; python_version >= "3.8" and python_version < "4.0"
14-
fsspec==2023.5.0 ; python_version >= "3.8" and python_version < "4.0"
14+
fsspec==2023.6.0 ; python_version >= "3.8" and python_version < "4.0"
1515
fugashi==1.2.1 ; python_version >= "3.8" and python_version < "4.0"
1616
huggingface-hub==0.15.1 ; python_version >= "3.8" and python_version < "4.0"
1717
idna==3.4 ; python_version >= "3.8" and python_version < "4.0"
@@ -32,15 +32,15 @@ packaging==23.1 ; python_version >= "3.8" and python_version < "4.0"
3232
pathspec==0.11.1 ; python_version >= "3.8" and python_version < "4.0"
3333
pathy==0.10.1 ; python_version >= "3.8" and python_version < "4.0"
3434
plac==1.3.5 ; python_version >= "3.8" and python_version < "4.0"
35-
platformdirs==3.5.1 ; python_version >= "3.8" and python_version < "4.0"
35+
platformdirs==3.5.3 ; python_version >= "3.8" and python_version < "4.0"
3636
pluggy==1.0.0 ; python_version >= "3.8" and python_version < "4.0"
3737
preshed==3.0.8 ; python_version >= "3.8" and python_version < "4.0"
3838
pycodestyle==2.8.0 ; python_version >= "3.8" and python_version < "4.0"
39-
pydantic==1.10.8 ; python_version >= "3.8" and python_version < "4.0"
39+
pydantic==1.10.9 ; python_version >= "3.8" and python_version < "4.0"
4040
pyflakes==2.4.0 ; python_version >= "3.8" and python_version < "4.0"
4141
pytest-cov==3.0.0 ; python_version >= "3.8" and python_version < "4.0"
42-
pytest==7.3.1 ; python_version >= "3.8" and python_version < "4.0"
43-
pytorch-partial-tagger==0.1.6 ; python_version >= "3.8" and python_version < "4.0"
42+
pytest==7.3.2 ; python_version >= "3.8" and python_version < "4.0"
43+
pytorch-partial-tagger==0.1.7 ; python_version >= "3.8" and python_version < "4.0"
4444
pyyaml==6.0 ; python_version >= "3.8" and python_version < "4.0"
4545
regex==2023.6.3 ; python_version >= "3.8" and python_version < "4.0"
4646
requests==2.31.0 ; python_version >= "3.8" and python_version < "4.0"
@@ -51,6 +51,7 @@ smart-open==6.3.0 ; python_version >= "3.8" and python_version < "4.0"
5151
spacy-alignments==0.8.6 ; python_version >= "3.8" and python_version < "4.0"
5252
spacy-legacy==3.0.12 ; python_version >= "3.8" and python_version < "4.0"
5353
spacy-loggers==1.0.4 ; python_version >= "3.8" and python_version < "4.0"
54+
spacy-transformers==1.2.4 ; python_version >= "3.8" and python_version < "4.0"
5455
spacy==3.5.3 ; python_version >= "3.8" and python_version < "4.0"
5556
spacy[transformers]==3.5.3 ; python_version >= "3.8" and python_version < "4.0"
5657
srsly==2.4.6 ; python_version >= "3.8" and python_version < "4.0"
@@ -68,5 +69,5 @@ typer==0.7.0 ; python_version >= "3.8" and python_version < "4.0"
6869
typing-extensions==4.6.3 ; python_version >= "3.8" and python_version < "4.0"
6970
unidic-lite==1.0.8 ; python_version >= "3.8" and python_version < "4.0"
7071
unidic==1.1.0 ; python_version >= "3.8" and python_version < "4.0"
71-
urllib3==2.0.2 ; python_version >= "3.8" and python_version < "4.0"
72+
urllib3==2.0.3 ; python_version >= "3.8" and python_version < "4.0"
7273
wasabi==0.10.1 ; python_version >= "3.8" and python_version < "4.0"

spacy_partial_tagger/pipeline.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import srsly
44
import torch
55
from partial_tagger.data import CharBasedTags, LabelSet
6-
from partial_tagger.data.batch.tag import TagFactory
6+
from partial_tagger.data.batch.tag import TagsBatch
7+
from partial_tagger.data.batch.text import create_token_based_tags
78
from partial_tagger.training import expected_entity_ratio_loss
89
from partial_tagger.utils import create_tag
910
from spacy import util
@@ -34,6 +35,8 @@ def __init__(
3435
self.model = model
3536
self.name = name
3637
self.scorer = scorer
38+
self.padding_index = padding_index
39+
self.unknown_index = unknown_index
3740
self.cfg: Dict[str, List[str]] = {"labels": []}
3841

3942
@property
@@ -50,9 +53,10 @@ def set_annotations(
5053
tag_indices: Floats2d,
5154
) -> None:
5255
tokenized_texts = [doc.user_data["tokenized_text"] for doc in docs]
53-
tag_factory = TagFactory(tokenized_texts, self.label_set)
5456

55-
tags_batch = tag_factory.create_char_based_tags(tag_indices)
57+
tags_batch = create_token_based_tags(
58+
tokenized_texts, tag_indices, self.label_set, self.padding_index
59+
)
5660

5761
for doc, tags in zip(docs, tags_batch):
5862
ents = []
@@ -110,28 +114,31 @@ def get_loss(
110114
) -> Tuple[float, Floats4d]:
111115
scores_pt = xp2torch(scores, requires_grad=True)
112116

113-
tokenized_texts = [
114-
example.x.user_data["tokenized_text"] for example in examples
115-
]
116-
tag_factory = TagFactory(tokenized_texts, self.label_set)
117-
118-
tags_batch = []
117+
token_based_tags = []
118+
lengths = []
119119
for example in examples:
120120
tags = tuple(
121121
create_tag(ent.start_char, len(ent.text), ent.label_)
122122
for ent in example.y.ents
123123
)
124-
tags_batch.append(CharBasedTags(tags, example.y.text))
124+
tokenized_text = example.x.user_data["tokenized_text"]
125+
token_based_tags.append(
126+
CharBasedTags(tags, example.x.text).convert_to_token_based(
127+
tokenized_text
128+
)
129+
)
130+
lengths.append(tokenized_text.num_tokens)
131+
132+
tags_batch = TagsBatch(tuple(token_based_tags), self.label_set)
133+
tags_batch.to(scores_pt.device)
134+
tag_bitmap = tags_batch.get_tag_bitmap()
125135

126-
lengths = [text.num_tokens for text in tokenized_texts]
127136
max_length = max(lengths)
128137
mask = torch.tensor(
129138
[[True] * length + [False] * (max_length - length) for length in lengths],
130139
device=scores_pt.device,
131140
)
132141

133-
tag_bitmap = tag_factory.create_tag_bitmap(tuple(tags_batch), scores_pt.device)
134-
135142
loss = expected_entity_ratio_loss(
136143
scores_pt, tag_bitmap, mask, self.label_set.get_outside_index()
137144
)

spacy_partial_tagger/tagger.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from partial_tagger.data import LabelSet
55
from partial_tagger.data.batch.text import BaseTokenizer
6-
from partial_tagger.utils import create_tagger
76
from spacy.tokens import Doc
87
from spacy.util import registry
98
from thinc.api import Model, get_torch_default_device, torch2xp, xp2torch
@@ -12,6 +11,7 @@
1211
from thinc.util import convert_recursive, is_torch_array, is_xp_array
1312

1413
from .tokenizer import get_tokenizer
14+
from .util import create_tagger
1515

1616

1717
@registry.architectures.register("spacy-partial-tagger.PartialTagger.v1")
@@ -51,9 +51,10 @@ def forward(
5151
doc.user_data["tokenized_text"] = text
5252

5353
device = get_torch_default_device()
54+
text_batch.to(device)
5455

5556
(log_potentials, tag_indices), backward = model.layers[0](
56-
[text_batch.get_tagger_inputs(device), text_batch.get_mask(device)],
57+
[text_batch.tagger_inputs, text_batch.mask],
5758
is_train,
5859
)
5960

spacy_partial_tagger/tokenizer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
44
from partial_tagger.data import Span, TokenizedText
55
from partial_tagger.data.batch.text import (
66
BaseTokenizer,
77
TextBatch,
8-
Texts,
98
TransformerTokenizer,
109
)
1110
from transformers import AutoTokenizer
@@ -31,7 +30,7 @@ def __init__(
3130
}
3231
self.__tokenizer_args["return_offsets_mapping"] = True
3332

34-
def __call__(self, texts: Texts) -> TextBatch:
33+
def __call__(self, texts: Tuple[str]) -> TextBatch:
3534
batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
3635

3736
pad_token_id = self.__tokenizer.pad_token_id

spacy_partial_tagger/util.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
from typing import List, Tuple
22

3-
import catalogue
43
import spacy_alignments as tokenizations
5-
from spacy.util import registry
4+
from partial_tagger.data import LabelSet
5+
from partial_tagger.decoders.viterbi import Contrainer, ViterbiDecoder
6+
from partial_tagger.encoders.transformer import TransformerModelEncoderFactory
7+
from partial_tagger.tagger import SequenceTagger
68
from transformers import PreTrainedTokenizer
79

8-
registry.label_indexers = catalogue.create( # type:ignore
9-
"spacy", "label_indexers", entry_points=True
10-
)
10+
11+
def create_tagger(
12+
model_name: str, label_set: LabelSet, padding_index: int
13+
) -> SequenceTagger:
14+
return SequenceTagger(
15+
TransformerModelEncoderFactory(model_name).create(label_set),
16+
ViterbiDecoder(
17+
padding_index,
18+
Contrainer(
19+
label_set.get_start_states(),
20+
label_set.get_end_states(),
21+
label_set.get_transitions(),
22+
),
23+
),
24+
)
1125

1226

1327
def get_alignments(

0 commit comments

Comments
 (0)