|
| 1 | +from typing import List, Tuple |
| 2 | +from pydantic import BaseModel |
| 3 | + |
| 4 | +from cognee.infrastructure.llm.get_llm_client import get_llm_client |
| 5 | +from cognee.infrastructure.llm.prompts import render_prompt, read_query_prompt |
| 6 | +from cognee.root_dir import get_absolute_path |
| 7 | + |
| 8 | + |
class PotentialNodesAndRelationshipNames(BaseModel):
    """Structured LLM response: candidate node names plus candidate relationship names.

    Used as the ``response_model`` when requesting structured output, so the
    field names below are part of the expected response schema.
    """

    # Candidate node names proposed for the analysed content.
    nodes: List[str]
    # Candidate relationship names proposed for the analysed content.
    relationship_names: List[str]
| 14 | + |
| 15 | + |
def _append_unseen(candidates: List[str], collected: List[str], seen_lowercase: set) -> None:
    """Append each candidate whose lowercase form is not already in ``seen_lowercase``.

    Both ``collected`` and ``seen_lowercase`` are mutated in place. The first
    spelling encountered for a given lowercase key is the one that is kept,
    and insertion order is preserved.
    """
    for candidate in candidates:
        key = candidate.lower()
        if key not in seen_lowercase:
            collected.append(candidate)
            seen_lowercase.add(key)


async def extract_content_nodes_and_relationship_names(
    content: str, existing_nodes: List[str], n_rounds: int = 2
) -> Tuple[List[str], List[str]]:
    """Extract node names and relationship names from content over several LLM rounds.

    Each round renders the user prompt with the original ``existing_nodes``,
    everything accumulated so far, and the round counters, then asks the LLM
    for a ``PotentialNodesAndRelationshipNames`` response. New names are merged
    into the running lists with case-insensitive de-duplication.

    Args:
        content: Text to analyse; passed to the prompt as ``text``.
        existing_nodes: Node names known up front. The list is copied, never
            mutated, and its entries come first in the returned node list
            (duplicates already present in it are kept as-is).
        n_rounds: Number of LLM rounds. A value of zero or less performs no
            LLM calls and loads no prompts.

    Returns:
        A tuple ``(all_nodes, all_relationship_names)``: the copied
        ``existing_nodes`` followed by newly discovered node names, and the
        discovered relationship names, each in first-seen order.
    """
    llm_client = get_llm_client()
    all_nodes: List[str] = existing_nodes.copy()
    all_relationship_names: List[str] = []
    existing_node_set = {node.lower() for node in all_nodes}
    existing_relationship_names = set()

    # Nothing to do without at least one round; return before touching prompt files,
    # matching the original behaviour where an empty range skipped all prompt loading.
    if n_rounds <= 0:
        return all_nodes, all_relationship_names

    # Loop-invariant: the prompt directory and system prompt do not depend on round
    # state, so resolve/read them once instead of on every round.
    # NOTE(review): assumes the system prompt file does not change mid-call — confirm.
    base_directory = get_absolute_path("./tasks/graph/cascade_extract/prompts")
    system_prompt = read_query_prompt(
        "extract_graph_relationship_names_prompt_system.txt", base_directory=base_directory
    )

    for round_num in range(n_rounds):
        context = {
            "text": content,
            "potential_nodes": existing_nodes,
            "previous_nodes": all_nodes,
            "previous_relationship_names": all_relationship_names,
            "round_number": round_num + 1,  # 1-based for the prompt
            "total_rounds": n_rounds,
        }

        text_input = render_prompt(
            "extract_graph_relationship_names_prompt_input.txt",
            context,
            base_directory=base_directory,
        )
        response = await llm_client.acreate_structured_output(
            text_input=text_input,
            system_prompt=system_prompt,
            response_model=PotentialNodesAndRelationshipNames,
        )

        _append_unseen(response.nodes, all_nodes, existing_node_set)
        _append_unseen(
            response.relationship_names, all_relationship_names, existing_relationship_names
        )

    return all_nodes, all_relationship_names
0 commit comments