@@ -252,11 +252,13 @@ def __init__(
        entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]],
        entity_pair_labels: Optional[set[tuple[str, str]]] = None,
        entity_threshold: Optional[float] = None,
+        max_allowed_tokens_between_entities: Optional[int] = 20,
+        max_surrounding_context_length: Optional[int] = 10,
        cross_augmentation: bool = True,
        encoding_strategy: EncodingStrategy = TypedEntityMarker(),
        zero_tag_value: str = "O",
        allow_unk_tag: bool = True,
-        **classifierargs,
+        **classifierargs: Any,
    ) -> None:
        """Initializes a `RelationClassifier`.
@@ -267,6 +269,8 @@ def __init__(
            entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'.
            entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'LOC')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient).
            entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model.
+            max_allowed_tokens_between_entities: The maximum allowed number of tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter is disabled.
+            max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context in between the entity pair will always be included. If `None`, the filter is disabled.
            cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentence`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence.
            encoding_strategy: An instance of a class conforming to the :class:`EncodingStrategy` protocol.
            zero_tag_value: The label to use for out-of-class relations.
@@ -302,6 +306,8 @@ def __init__(
        self.entity_pair_labels = entity_pair_labels

        self.entity_threshold = entity_threshold
+        self.max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
+        self.max_surrounding_context_length = max_surrounding_context_length
        self.cross_augmentation = cross_augmentation
        self.encoding_strategy = encoding_strategy
@@ -393,12 +399,41 @@ def _entity_pair_permutations(

            yield head, tail, gold_label

+    @staticmethod
+    def _truncate_context_around_entities(
+        encoded_sentence_tokens: list[str],
+        head_idx: int,
+        tail_idx: int,
+        context_length: int,
+    ) -> list[str]:
+        """Truncates the encoded sentence to include the head and tail entity and their surrounding context.
+
+        The context in between the entity pair will always be included.
+
+        Args:
+            encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence.
+            head_idx: The index of the head entity in the token list.
+            tail_idx: The index of the tail entity in the token list.
+            context_length: The maximum number of tokens to include as surrounding context around the head and tail entities.
+
+        Returns:
+            The tokens of the truncated sentence.
+        """
+        begin_slice: int = min(head_idx, tail_idx)
+        end_slice: int = max(head_idx, tail_idx)
+
+        # Preserve context around the entities. Always include their in-between context.
+        begin_slice = max(begin_slice - context_length, 0)
+        end_slice = min(end_slice + context_length + 1, len(encoded_sentence_tokens))
+
+        return encoded_sentence_tokens[begin_slice:end_slice]
+
    def _encode_sentence(
        self,
        head: _Entity,
        tail: _Entity,
        gold_label: Optional[str] = None,
-    ) -> EncodedSentence:
+    ) -> Optional[EncodedSentence]:
        """Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.

        If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
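The slicing always keeps the span between the two entities and at most `context_length` tokens on either side, clamped to the sentence boundaries. A standalone replica of the arithmetic, on toy tokens, makes this concrete:

```python
def truncate(tokens: list[str], head_idx: int, tail_idx: int, context_length: int) -> list[str]:
    # Same arithmetic as _truncate_context_around_entities, without the class.
    begin_slice = max(min(head_idx, tail_idx) - context_length, 0)
    end_slice = min(max(head_idx, tail_idx) + context_length + 1, len(tokens))
    return tokens[begin_slice:end_slice]

tokens = [f"t{i}" for i in range(20)]
# Entities at indices 8 and 11 with 2 context tokens keep the slice [6:14]:
assert truncate(tokens, head_idx=8, tail_idx=11, context_length=2) == tokens[6:14]
# Clamping at the left sentence boundary:
assert truncate(tokens, head_idx=1, tail_idx=3, context_length=5) == tokens[0:9]
```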
@@ -414,6 +449,12 @@ def _encode_sentence(
        original_sentence: Sentence = head.span.sentence
        assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."

+        # Sanity check: Do not create a labeled span if one entity contains the other
+        if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
+            return None
+        if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
+            return None
+
        # Pre-compute non-leading head and tail tokens for entity masking
        non_leading_head_tokens: list[Token] = head.span.tokens[1:]
        non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
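The two `if` branches together reject exactly the pairs where one span encloses the other, identical spans included. The same test on bare `(first, last)` token-index pairs, with hypothetical names, for illustration:

```python
def one_encloses_other(head: tuple[int, int], tail: tuple[int, int]) -> bool:
    # Mirrors the sanity check above on plain index pairs.
    if head[0] <= tail[0] and head[1] >= tail[1]:  # head encloses tail
        return True
    if head[0] >= tail[0] and head[1] <= tail[1]:  # tail encloses head
        return True
    return False

assert one_encloses_other((3, 7), (4, 5))      # nested: pair is filtered out
assert one_encloses_other((3, 5), (3, 5))      # identical spans: filtered out
assert not one_encloses_other((1, 2), (5, 8))  # disjoint: pair is kept
```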
@@ -422,11 +463,15 @@
        # since there may be multiple occurrences of the same entity mentioned in the sentence.
        # Therefore, we use the span's position in the sentence.
        encoded_sentence_tokens: list[str] = []
+        head_idx: Optional[int] = None
+        tail_idx: Optional[int] = None
        for token in original_sentence:
            if token is head.span[0]:
+                head_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

            elif token is tail.span[0]:
+                tail_idx = len(encoded_sentence_tokens)
                encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

            elif all(
@@ -435,6 +480,27 @@
            ):
                encoded_sentence_tokens.append(token.text)

+        msg: str
+        if head_idx is None:
+            msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})."
+            raise AssertionError(msg)
+        if tail_idx is None:
+            msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})."
+            raise AssertionError(msg)
+
+        # Filter cases in which the distance between the two entities is too large
+        if (
+            self.max_allowed_tokens_between_entities is not None
+            and abs(head_idx - tail_idx) > self.max_allowed_tokens_between_entities
+        ):
+            return None
+
+        # Remove excess tokens left and right of the entity pair to make the encoded sentence shorter
+        if self.max_surrounding_context_length is not None:
+            encoded_sentence_tokens = self._truncate_context_around_entities(
+                encoded_sentence_tokens, head_idx, tail_idx, self.max_surrounding_context_length
+            )
+
        # Create masked sentence
        encoded_sentence: EncodedSentence = EncodedSentence(
            " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -445,6 +511,7 @@
            # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since,
            # during prediction, the forward pass does not need any knowledge about the entities in the sentence.
            encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0)
+
        encoded_sentence.copy_context_from_sentence(original_sentence)
        return encoded_sentence

@@ -469,13 +536,15 @@ def _encode_sentence_for_inference(
        Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
        """
        for head, tail, gold_label in self._entity_pair_permutations(sentence):
-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
                head=head,
                tail=tail,
                gold_label=gold_label if gold_label is not None else self.zero_tag_value,
            )
            original_relation: Relation = Relation(first=head.span, second=tail.span)
-            yield masked_sentence, original_relation
+
+            if masked_sentence is not None:
+                yield masked_sentence, original_relation

    def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
        """Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +561,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
            else:
                continue  # Skip generated data points that do not express an originally annotated relation

-            masked_sentence: EncodedSentence = self._encode_sentence(
+            masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
                head=head,
                tail=tail,
                gold_label=gold_label,
            )

-            yield masked_sentence
+            if masked_sentence is not None:
+                yield masked_sentence

    def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
        """Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -702,6 +772,8 @@ def _get_state_dict(self) -> dict[str, Any]:
            "entity_label_types": self.entity_label_types,
            "entity_pair_labels": self.entity_pair_labels,
            "entity_threshold": self.entity_threshold,
+            "max_allowed_tokens_between_entities": self.max_allowed_tokens_between_entities,
+            "max_surrounding_context_length": self.max_surrounding_context_length,
            "cross_augmentation": self.cross_augmentation,
            "encoding_strategy": self.encoding_strategy,
            "zero_tag_value": self.zero_tag_value,
@@ -719,6 +791,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
            entity_label_types=state["entity_label_types"],
            entity_pair_labels=state["entity_pair_labels"],
            entity_threshold=state["entity_threshold"],
+            max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"),
+            max_surrounding_context_length=state.get("max_surrounding_context_length"),
            cross_augmentation=state["cross_augmentation"],
            encoding_strategy=state["encoding_strategy"],
            zero_tag_value=state["zero_tag_value"],
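Using `state.get(...)` rather than indexing keeps older checkpoints loadable: models saved before this change carry neither key, so both parameters come back as `None` and the filters stay disabled, preserving the old behavior (note this differs from the constructor defaults of 20 and 10 for newly created models). A quick sanity check of that fallback:

```python
# A pre-change state dict has no truncation keys; .get(...) falls back to None.
old_state = {
    "entity_label_types": "ner",
    "entity_pair_labels": None,
    "entity_threshold": None,
}
assert old_state.get("max_allowed_tokens_between_entities") is None
assert old_state.get("max_surrounding_context_length") is None
```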