DRIFT Search
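DRIFT search (Dynamic Reasoning and Inference with Flexible Traversal) is a GraphRAG query mode that combines characteristics of local and global search: community information primes the query, and the primed query is then refined through local follow-ups. This notebook loads the tables produced by the GraphRAG indexing pipeline and runs a DRIFT query against them.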
In [1]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
In [2]:
import os
from pathlib import Path

import pandas as pd
import tiktoken

from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2

# read the nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")

entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

# load description embeddings into an in-memory lancedb vector store;
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
entity_df.head()

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()
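The cell above connects to the on-disk LanceDB store written by the indexer. If your embeddings live somewhere else, the same store object accepts any LanceDB-compatible URI. A minimal sketch; the path below is a hypothetical placeholder, not part of the original notebook:

remote_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
# hypothetical location; point this at wherever your lancedb tables actually live
remote_store.connect(db_uri="/mnt/shared/graphrag/lancedb")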
In [3]:
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]

chat_llm = ChatOpenAI(
    api_key=api_key,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
    max_retries=20,
)

token_encoder = tiktoken.get_encoding("cl100k_base")

text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=None,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    deployment_name=embedding_model,
    max_retries=20,
)
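The configuration above targets the public OpenAI API. For Azure OpenAI, the same constructors accept Azure-specific settings. A hedged sketch; the endpoint, API version, and deployment name are placeholders, and the exact parameter set should be checked against your graphrag version:

chat_llm = ChatOpenAI(
    api_key=api_key,
    model=llm_model,
    deployment_name=llm_model,  # placeholder: assumes the deployment is named after the model
    api_base="https://<your-resource>.openai.azure.com",  # placeholder endpoint
    api_version="2024-02-15-preview",  # placeholder API version
    api_type=OpenaiApiType.AzureOpenAI,
    max_retries=20,
)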
In [4]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Read the community reports table from the parquet output."""
    input_path = Path(input_dir) / f"{community_report_table}.parquet"
    return pd.read_parquet(input_path)


report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
    report_df,
    entity_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
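read_indexer_reports expects each report row to carry a full-content embedding in the column named by content_embedding_col. If your index was built without that embedding, one way to backfill it is with the text_embedder configured above. A sketch, assuming the standard full_content column produced by the indexer:

if "full_content_embeddings" not in report_df.columns:
    # embed each report's full text; embed() returns one vector (list of floats) per report
    report_df["full_content_embeddings"] = report_df["full_content"].apply(
        text_embedder.embed
    )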
In [5]:
context_builder = DRIFTSearchContextBuilder(
    chat_llm=chat_llm,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
)

search = DRIFTSearch(
    llm=chat_llm, context_builder=context_builder, token_encoder=token_encoder
)
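At query time, DRIFT first runs a primer step that compares the question against the most relevant community reports, producing an initial answer plus follow-up questions; it then resolves those follow-ups with local search over the entities, relationships, and text units loaded above before assembling the final response.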
In [6]:
resp = await search.asearch("Who is agent Mercer?")
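asearch is a coroutine; the bare await above works because notebooks already run inside an event loop. In a plain Python script you would drive it yourself:

import asyncio

resp = asyncio.run(search.asearch("Who is agent Mercer?"))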
In [7]:
resp.response
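For DRIFT, response is a nested structure of intermediate query nodes rather than a single string; the next cell pulls the answer out of the first node.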
In [8]:
resp.response["nodes"][0]["answer"]
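To skim everything the search produced, you can iterate over all of the nodes. A sketch, assuming the nodes/answer layout used in the cell above:

for node in resp.response["nodes"]:
    print(node["answer"])
    print("-" * 80)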