Commit 8acd698

Merge pull request #3593 from flairNLP/filter_relations
Optimize RelationClassifier by adding the option to filter long sentences and truncate context
2 parents ae592bf + 863d903 commit 8acd698

7 files changed: 194 additions & 37 deletions


flair/models/regexp_tagger.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 import re
-import typing
 from dataclasses import dataclass, field
 from typing import Union

flair/models/relation_classifier_model.py

Lines changed: 80 additions & 6 deletions

@@ -252,11 +252,13 @@ def __init__(
         entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]],
         entity_pair_labels: Optional[set[tuple[str, str]]] = None,
         entity_threshold: Optional[float] = None,
+        max_allowed_tokens_between_entities: Optional[int] = 20,
+        max_surrounding_context_length: Optional[int] = 10,
         cross_augmentation: bool = True,
         encoding_strategy: EncodingStrategy = TypedEntityMarker(),
         zero_tag_value: str = "O",
         allow_unk_tag: bool = True,
-        **classifierargs,
+        **classifierargs: Any,
     ) -> None:
         """Initializes a `RelationClassifier`.
@@ -267,6 +269,8 @@ def __init__(
             entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'.
             entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'LOC')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient).
             entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model.
+            max_allowed_tokens_between_entities: The maximum allowed number of tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter is disabled.
+            max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context in between the entity pair is always included. If `None`, the filter is disabled.
             cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentence`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence.
             encoding_strategy: An instance of a class conforming to the :class:`EncodingStrategy` protocol.
             zero_tag_value: The label to use for out-of-class relations.
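
For orientation, the two new parameters are plain constructor arguments. A minimal sketch of passing them (the corpus, embeddings, and label setup below are illustrative assumptions, not part of this diff):

from flair.datasets import RE_ENGLISH_CONLL04
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import RelationClassifier

# Illustrative setup; any corpus with NER and relation annotations works similarly.
corpus = RE_ENGLISH_CONLL04()
classifier = RelationClassifier(
    embeddings=TransformerDocumentEmbeddings(model="distilbert-base-uncased"),
    label_dictionary=corpus.make_label_dictionary("relation"),
    label_type="relation",
    entity_label_types="ner",
    max_allowed_tokens_between_entities=20,  # drop pairs whose markers are over 20 tokens apart
    max_surrounding_context_length=10,       # keep at most 10 context tokens on each side
)

Setting either argument to None restores the previous unfiltered behavior.
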
@@ -302,6 +306,8 @@ def __init__(
         self.entity_pair_labels = entity_pair_labels
 
         self.entity_threshold = entity_threshold
+        self.max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
+        self.max_surrounding_context_length = max_surrounding_context_length
         self.cross_augmentation = cross_augmentation
         self.encoding_strategy = encoding_strategy
@@ -393,12 +399,41 @@ def _entity_pair_permutations(
 
         yield head, tail, gold_label
 
+    @staticmethod
+    def _truncate_context_around_entities(
+        encoded_sentence_tokens: list[str],
+        head_idx: int,
+        tail_idx: int,
+        context_length: int,
+    ) -> list[str]:
+        """Truncates the encoded sentence to include the head and tail entity and their surrounding context.
+
+        The context in between the entity pair is always included.
+
+        Args:
+            encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence.
+            head_idx: The index of the head entity in the token list.
+            tail_idx: The index of the tail entity in the token list.
+            context_length: The maximum number of tokens to include as surrounding context around the head and tail entities.
+
+        Returns:
+            The tokens of the truncated sentence.
+        """
+        begin_slice: int = min(head_idx, tail_idx)
+        end_slice: int = max(head_idx, tail_idx)
+
+        # Preserve context around the entities. Always include their in-between context.
+        begin_slice = max(begin_slice - context_length, 0)
+        end_slice = min(end_slice + context_length + 1, len(encoded_sentence_tokens))
+
+        return encoded_sentence_tokens[begin_slice:end_slice]
+
     def _encode_sentence(
         self,
         head: _Entity,
         tail: _Entity,
         gold_label: Optional[str] = None,
-    ) -> EncodedSentence:
+    ) -> Optional[EncodedSentence]:
         """Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.
 
         If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
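
To make the slicing arithmetic concrete, here is a standalone sketch (plain Python mirroring the method above, with made-up marker tokens):

# context_length=2: everything between the entities survives,
# plus at most two tokens on each side.
tokens = ["t0", "t1", "t2", "[HEAD]", "t4", "t5", "[TAIL]", "t7", "t8", "t9"]
head_idx, tail_idx, context_length = 3, 6, 2

begin_slice = max(min(head_idx, tail_idx) - context_length, 0)               # 1
end_slice = min(max(head_idx, tail_idx) + context_length + 1, len(tokens))   # 9

print(tokens[begin_slice:end_slice])
# ['t1', 't2', '[HEAD]', 't4', 't5', '[TAIL]', 't7', 't8']

The max/min clamping keeps the slice inside the sentence when an entity sits near either end.
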
@@ -414,6 +449,12 @@ def _encode_sentence(
         original_sentence: Sentence = head.span.sentence
         assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."
 
+        # Sanity check: Do not create a labeled span if one entity contains the other
+        if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
+            return None
+        if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
+            return None
+
         # Pre-compute non-leading head and tail tokens for entity masking
         non_leading_head_tokens: list[Token] = head.span.tokens[1:]
         non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
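
The two symmetric conditions drop nested mentions, e.g. a tail span "New York" lying inside a head span "New York City", which the marker-based encoding below presumably cannot represent cleanly. A standalone check of the boundary logic (plain ints standing in for the token.idx values):

def is_nested(head_start, head_end, tail_start, tail_end):
    # True iff one span's boundaries enclose the other's.
    return (head_start <= tail_start and head_end >= tail_end) or (
        head_start >= tail_start and head_end <= tail_end
    )

print(is_nested(5, 7, 5, 6))  # True: "New York" (5..6) inside "New York City" (5..7) -> skipped
print(is_nested(1, 2, 5, 6))  # False: disjoint spans remain a valid candidate pair
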
@@ -422,11 +463,15 @@ def _encode_sentence(
         # since there may be multiple occurrences of the same entity mentioned in the sentence.
         # Therefore, we use the span's position in the sentence.
         encoded_sentence_tokens: list[str] = []
+        head_idx: Optional[int] = None
+        tail_idx: Optional[int] = None
         for token in original_sentence:
             if token is head.span[0]:
+                head_idx = len(encoded_sentence_tokens)
                 encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))
 
             elif token is tail.span[0]:
+                tail_idx = len(encoded_sentence_tokens)
                 encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))
 
             elif all(
@@ -435,6 +480,27 @@ def _encode_sentence(
             ):
                 encoded_sentence_tokens.append(token.text)
 
+        msg: str
+        if head_idx is None:
+            msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})."
+            raise AssertionError(msg)
+        if tail_idx is None:
+            msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})."
+            raise AssertionError(msg)
+
+        # Filter cases in which the distance between the two entities is too large
+        if (
+            self.max_allowed_tokens_between_entities is not None
+            and abs(head_idx - tail_idx) > self.max_allowed_tokens_between_entities
+        ):
+            return None
+
+        # Remove excess tokens left and right of the entity pair to make the encoded sentence shorter
+        if self.max_surrounding_context_length is not None:
+            encoded_sentence_tokens = self._truncate_context_around_entities(
+                encoded_sentence_tokens, head_idx, tail_idx, self.max_surrounding_context_length
+            )
+
         # Create masked sentence
         encoded_sentence: EncodedSentence = EncodedSentence(
             " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -445,6 +511,7 @@ def _encode_sentence(
         # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since,
         # during prediction, the forward pass does not need any knowledge about the entities in the sentence.
         encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0)
+
         encoded_sentence.copy_context_from_sentence(original_sentence)
         return encoded_sentence
@@ -469,13 +536,15 @@ def _encode_sentence_for_inference(
         Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
         """
         for head, tail, gold_label in self._entity_pair_permutations(sentence):
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
                 head=head,
                 tail=tail,
                 gold_label=gold_label if gold_label is not None else self.zero_tag_value,
             )
             original_relation: Relation = Relation(first=head.span, second=tail.span)
-            yield masked_sentence, original_relation
+
+            if masked_sentence is not None:
+                yield masked_sentence, original_relation
 
     def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
         """Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +561,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
             else:
                 continue  # Skip generated data points that do not express an originally annotated relation
 
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
                 head=head,
                 tail=tail,
                 gold_label=gold_label,
            )
 
-            yield masked_sentence
+            if masked_sentence is not None:
+                yield masked_sentence
 
     def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
         """Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -702,6 +772,8 @@ def _get_state_dict(self) -> dict[str, Any]:
             "entity_label_types": self.entity_label_types,
             "entity_pair_labels": self.entity_pair_labels,
             "entity_threshold": self.entity_threshold,
+            "max_allowed_tokens_between_entities": self.max_allowed_tokens_between_entities,
+            "max_surrounding_context_length": self.max_surrounding_context_length,
             "cross_augmentation": self.cross_augmentation,
             "encoding_strategy": self.encoding_strategy,
             "zero_tag_value": self.zero_tag_value,
@@ -719,6 +791,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
             entity_label_types=state["entity_label_types"],
             entity_pair_labels=state["entity_pair_labels"],
             entity_threshold=state["entity_threshold"],
+            max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"),
+            max_surrounding_context_length=state.get("max_surrounding_context_length"),
             cross_augmentation=state["cross_augmentation"],
             encoding_strategy=state["encoding_strategy"],
             zero_tag_value=state["zero_tag_value"],
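
The state.get(...) calls for the two new keys are what keep checkpoints saved before this change loadable. A minimal illustration of the dict behavior this relies on:

# A pre-change checkpoint lacks the new keys (illustrative state dict).
old_state = {"entity_threshold": None}

# dict.get returns None for missing keys instead of raising KeyError,
# and None disables both filters, preserving the old model's behavior.
print(old_state.get("max_allowed_tokens_between_entities"))  # None
print(old_state.get("max_surrounding_context_length"))       # None
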

flair/trainers/trainer.py

Lines changed: 12 additions & 8 deletions

@@ -966,14 +966,18 @@ def _initialize_model_card(self, **training_parameters):
         except ImportError:
             pass
 
-        # remember all parameters used in train() call
-        model_card["training_parameters"] = {
-            k: str(v) if isinstance(v, Path) else v for k, v in training_parameters.items()
-        }
-
-        model_card["training_parameters"] = {
-            k: f"{v.__module__}.{v.__name__}" if inspect.isclass(v) else v for k, v in training_parameters.items()
-        }
+        # remember the training parameters
+        model_card["training_parameters"] = {}
+        for k, v in training_parameters.items():
+
+            # special rule for Path variables to make sure models can be deserialized on other OS
+            if isinstance(v, Path):
+                v = str(v)
+            # classes are only serialized as names
+            if inspect.isclass(v):
+                v = f"{v.__module__}.{v.__name__}"
+
+            model_card["training_parameters"][k] = v
 
         plugins = [plugin.get_state() for plugin in model_card["training_parameters"]["plugins"]]
         model_card["training_parameters"]["plugins"] = plugins
