Skip to content

Commit 632d7f2

Browse files
committed
Deal with missing relationships, add tests
1 parent 1462116 commit 632d7f2

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def _enforce_nodes(
403403
schema_entity = schema.entities.get(node.label)
404404
if not schema_entity:
405405
continue
406-
allowed_props = schema_entity.get("properties", [])
406+
allowed_props = schema_entity.get("properties")
407407
if allowed_props:
408408
filtered_props = self._enforce_properties(
409409
node.properties, allowed_props
@@ -439,16 +439,17 @@ def _enforce_relationships(
439439
if self.enforce_schema != SchemaEnforcementMode.STRICT:
440440
return extracted_relationships
441441

442+
if schema.relations is None:
443+
return extracted_relationships
444+
442445
valid_rels = []
443446

444447
valid_nodes = {node.id: node.label for node in filtered_nodes}
445448

446449
potential_schema = schema.potential_schema
447450

448451
for rel in extracted_relationships:
449-
schema_relation = (
450-
schema.relations.get(rel.type) if schema.relations else None
451-
)
452+
schema_relation = schema.relations.get(rel.type)
452453
if not schema_relation:
453454
continue
454455

@@ -473,7 +474,7 @@ def _enforce_relationships(
473474
if not tuple_valid and not reverse_tuple_valid:
474475
continue
475476

476-
allowed_props = schema_relation.get("properties", [])
477+
allowed_props = schema_relation.get("properties")
477478
if allowed_props:
478479
filtered_props = self._enforce_properties(rel.properties, allowed_props)
479480
else:

src/neo4j_graphrag/experimental/components/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ class SchemaConfig(DataModel):
9999
@model_validator(mode="before")
100100
def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
101101
entities = data.get("entities", {}).keys()
102-
relations = data.get("relations", {}).keys()
102+
relations = (data.get("relations") or {}).keys()
103103
potential_schema = data.get("potential_schema", [])
104104

105105
if potential_schema:

tests/unit/experimental/components/test_entity_relation_extractor.py

+68
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,74 @@ async def test_extractor_schema_enforcement_inverted_relation_direction() -> Non
564564
assert result.relationships[0].end_node_id.split(":")[1] == "2"
565565

566566

567+
@pytest.mark.asyncio
568+
async def test_extractor_schema_enforcement_none_relationships_in_schema() -> None:
569+
llm = MagicMock(spec=LLMInterface)
570+
llm.ainvoke.return_value = LLMResponse(
571+
content='{"nodes":[{"id":"1","label":"Person","properties":'
572+
'{"name":"Alice"}},{"id":"2","label":"Person","properties":'
573+
'{"name":"Bob"}}],'
574+
'"relationships":[{"start_node_id":"1","end_node_id":"2",'
575+
'"type":"FRIENDS_WITH","properties":{}}]}'
576+
)
577+
578+
extractor = LLMEntityRelationExtractor(
579+
llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT
580+
)
581+
582+
schema = SchemaConfig(
583+
entities={
584+
"Person": {
585+
"label": "Person",
586+
"properties": [{"name": "name", "type": "STRING"}],
587+
}
588+
},
589+
relations=None,
590+
potential_schema=None,
591+
)
592+
593+
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
594+
595+
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
596+
597+
assert len(result.nodes) == 2
598+
assert len(result.relationships) == 1
599+
assert result.relationships[0].type == "FRIENDS_WITH"
600+
601+
602+
@pytest.mark.asyncio
603+
async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> None:
604+
llm = MagicMock(spec=LLMInterface)
605+
llm.ainvoke.return_value = LLMResponse(
606+
content='{"nodes":[{"id":"1","label":"Person","properties":'
607+
'{"name":"Alice"}},{"id":"2","label":"Person","properties":'
608+
'{"name":"Bob"}}],'
609+
'"relationships":[{"start_node_id":"1","end_node_id":"2",'
610+
'"type":"FRIENDS_WITH","properties":{}}]}'
611+
)
612+
613+
extractor = LLMEntityRelationExtractor(
614+
llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT
615+
)
616+
617+
schema = SchemaConfig(
618+
entities={
619+
"Person": {
620+
"label": "Person",
621+
"properties": [{"name": "name", "type": "STRING"}],
622+
}
623+
},
624+
relations={},
625+
potential_schema=None,
626+
)
627+
628+
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
629+
630+
result: Neo4jGraph = await extractor.run(chunks, schema=schema)
631+
632+
assert len(result.relationships) == 0
633+
634+
567635
def test_fix_invalid_json_empty_result() -> None:
568636
json_string = "invalid json"
569637

0 commit comments

Comments
 (0)