
Commit d6e470f

Update dependency (#39)
* Update pytorch-partial-tagger
* Use the Alignments class in pytorch-partial-tagger
* Fix an unnecessary alias
1 parent 22df87c commit d6e470f

4 files changed: +38 additions, -46 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ dependencies = [
     "torch<3.0.0,>=2.0.1",
     "spacy[transformers]<4.0.0,>=3.3.1",
     "spacy-alignments<1.0.0,>=0.8.5",
-    "pytorch-partial-tagger<1.0.0,>=0.1.12",
+    "pytorch-partial-tagger<1.0.0,>=0.1.14",
 ]
 dynamic = ["version"]

spacy_partial_tagger/tokenizer.py renamed to spacy_partial_tagger/collator.py

Lines changed: 13 additions & 19 deletions
@@ -1,23 +1,17 @@
 from typing import Optional, Tuple
 
-from partial_tagger.data import Alignment, Span
-from partial_tagger.data.batch.text import (
-    BaseTokenizer,
-    TextBatch,
-    TransformerTokenizer,
-)
+from partial_tagger.data import Alignment, Alignments, Span
+from partial_tagger.data.collators import BaseCollator, Batch, TransformerCollator
 from transformers import AutoTokenizer
-from transformers.models.bert_japanese import (
-    BertJapaneseTokenizer as _BertJapaneseTokenizer,
-)
+from transformers.models.bert_japanese import BertJapaneseTokenizer
 
 from .util import get_alignments
 
 
-class BertJapaneseTokenizer(BaseTokenizer):
+class BertJapaneseCollator(BaseCollator):
     def __init__(
         self,
-        tokenizer: _BertJapaneseTokenizer,
+        tokenizer: BertJapaneseTokenizer,
         tokenizer_args: Optional[dict] = None,
     ):
         self.__tokenizer = tokenizer
@@ -29,7 +23,7 @@ def __init__(
         }
         self.__tokenizer_args["return_offsets_mapping"] = True
 
-    def __call__(self, texts: Tuple[str]) -> TextBatch:
+    def __call__(self, texts: Tuple[str]) -> Tuple[Batch, Alignments]:
         batch_encoding = self.__tokenizer(texts, **self.__tokenizer_args)
 
         pad_token_id = self.__tokenizer.pad_token_id
@@ -54,16 +48,16 @@ def __call__(self, texts: Tuple[str]) -> TextBatch:
 
             alignments.append(Alignment(text, char_spans, tuple(token_indices)))
 
-        return TextBatch(
-            tagger_inputs=batch_encoding, mask=mask, alignments=tuple(alignments)
+        return Batch(tagger_inputs=batch_encoding, mask=mask), Alignments(
+            tuple(alignments)
         )
 
 
-def get_tokenizer(
+def get_collator(
     transformer_model_name: str, tokenizer_args: Optional[dict] = None
-) -> BaseTokenizer:
+) -> BaseCollator:
     tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
-    if isinstance(tokenizer, _BertJapaneseTokenizer):
-        return BertJapaneseTokenizer(tokenizer, tokenizer_args)
+    if isinstance(tokenizer, BertJapaneseTokenizer):
+        return BertJapaneseCollator(tokenizer, tokenizer_args)
     else:
-        return TransformerTokenizer(tokenizer, tokenizer_args)
+        return TransformerCollator(tokenizer, tokenizer_args)
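
For context, a minimal usage sketch of the renamed helper after this change; the model name and input text below are illustrative placeholders, not values taken from the commit:

from spacy_partial_tagger.collator import get_collator

# get_collator replaces get_tokenizer; the returned callable now yields a
# (Batch, Alignments) pair instead of a single TextBatch.
collator = get_collator("bert-base-cased")
batch, alignments = collator(("Tokyo is the capital of Japan.",))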

spacy_partial_tagger/pipeline.py

Lines changed: 16 additions & 16 deletions
@@ -2,8 +2,7 @@
 
 import srsly
 import torch
-from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.tag import TagsBatch
+from partial_tagger.data import Alignments, LabelSet
 from partial_tagger.training import compute_partially_supervised_loss
 from partial_tagger.utils import create_tag
 from spacy import util
@@ -51,14 +50,16 @@ def set_annotations(
         docs: List[Doc],
         tag_indices: Floats2d,
     ) -> None:
-        for doc, indices in zip(docs, tag_indices.tolist()):
-            indices = [index for index in indices if index != self.padding_index]
-            alignment = doc.user_data["alignment"]
+        alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))
+        tags_batch = alignments.create_char_based_tags(
+            tag_indices.tolist(),
+            label_set=self.label_set,
+            padding_index=self.padding_index,
+        )
+
+        for doc, tags in zip(docs, tags_batch):
             ents = []
-            for tag in alignment.create_char_based_tags(
-                tag_indices=indices,
-                label_set=self.label_set,
-            ):
+            for tag in tags:
                 span = doc.char_span(tag.start, tag.start + tag.length, tag.label)
                 if span:
                     ents.append(span)
@@ -113,7 +114,7 @@ def get_loss(
         scores_pt = xp2torch(scores, requires_grad=True)
 
         char_based_tags = []
-        alignments = []
+        temp = []
         lengths = []
         for example in examples:
             tags = tuple(
@@ -124,14 +125,13 @@
 
             alignment = example.x.user_data["alignment"]
             lengths.append(alignment.num_tokens)
-            alignments.append(alignment)
+            temp.append(alignment)
 
-        tags_batch = TagsBatch(
-            tags_batch=tuple(char_based_tags),
-            alignments=alignments,
+        alignments = Alignments(tuple(temp))
+        tag_bitmap = torch.tensor(
+            alignments.get_tag_bitmap(char_based_tags, self.label_set),
+            device=scores_pt.device,
         )
-        tags_batch.to(scores_pt.device)
-        tag_bitmap = tags_batch.get_tag_bitmap(self.label_set)
 
         max_length = max(lengths)
         mask = torch.tensor(
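
In short, the per-document Alignment.create_char_based_tags call is replaced by batched calls on an Alignments wrapper. A condensed sketch of the two new code paths, assuming docs, tag_indices, label_set, padding_index, and char_based_tags are in scope as in the pipe above:

import torch
from partial_tagger.data import Alignments

alignments = Alignments(tuple(doc.user_data["alignment"] for doc in docs))

# Char-level tags per doc, consumed by set_annotations.
tags_batch = alignments.create_char_based_tags(
    tag_indices.tolist(), label_set=label_set, padding_index=padding_index
)

# Tag bitmap for the partially supervised loss, consumed by get_loss.
tag_bitmap = torch.tensor(alignments.get_tag_bitmap(char_based_tags, label_set))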

spacy_partial_tagger/tagger.py

Lines changed: 8 additions & 10 deletions
@@ -2,15 +2,15 @@
 from typing import Any, Callable, List, Optional, Tuple, cast
 
 from partial_tagger.data import LabelSet
-from partial_tagger.data.batch.text import BaseTokenizer
 from spacy.tokens import Doc
 from spacy.util import registry
 from thinc.api import Model, get_torch_default_device, torch2xp, xp2torch
 from thinc.shims import PyTorchGradScaler, PyTorchShim
 from thinc.types import ArgsKwargs, Floats4d, Ints2d
 from thinc.util import convert_recursive, is_torch_array, is_xp_array
 
-from .tokenizer import get_tokenizer
+from spacy_partial_tagger.collator import get_collator
+
 from .util import create_tagger
 
 
@@ -42,19 +42,17 @@ def forward(
     X: List[Doc],
     is_train: bool,
 ) -> Tuple[Tuple[Floats4d, Ints2d], Callable]:
-    tokenizer: BaseTokenizer = model.attrs["tokenizer"]
-
-    text_batch = tokenizer(tuple(doc.text for doc in X))
+    collator = model.attrs["collator"]
+    batch, alignments = collator(tuple(doc.text for doc in X))
 
-    for doc, alignment in zip(X, text_batch.alignments):
+    for doc, alignment in zip(X, alignments.alignments):
         doc.user_data["alignment"] = alignment
 
     device = get_torch_default_device()
-    text_batch.to(device)
+    batch = batch.to(device)
 
     (log_potentials, tag_indices), backward = model.layers[0](
-        [text_batch.tagger_inputs, text_batch.mask],
-        is_train,
+        [batch.tagger_inputs, batch.mask], is_train
     )
 
     return (log_potentials, tag_indices), backward
@@ -74,7 +72,7 @@ def init(
     mixed_precision = model.attrs["mixed_precision"]
     grad_scaler = model.attrs["grad_scaler"]
 
-    model.attrs["tokenizer"] = get_tokenizer(transformer_model_name, tokenizer_args)
+    model.attrs["collator"] = get_collator(transformer_model_name, tokenizer_args)
 
     tagger = create_tagger(transformer_model_name, Y, padding_index)
     PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
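
A condensed sketch of the updated forward contract, with model, X, device, and is_train assumed to be in scope as in the function above:

collator = model.attrs["collator"]  # was model.attrs["tokenizer"]
batch, alignments = collator(tuple(doc.text for doc in X))
batch = batch.to(device)  # the moved Batch is reassigned, unlike the old in-place text_batch.to(device)
(log_potentials, tag_indices), backward = model.layers[0](
    [batch.tagger_inputs, batch.mask], is_train
)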
