DRIFT Search
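DRIFT search (Dynamic Reasoning and Inference with Flexible Traversal) combines characteristics of global and local search: a primer phase compares the query against community reports to produce an initial answer plus follow-up questions, which are then refined through successive rounds of local search. This notebook loads a pre-built index for the "Operation Dulce" sample dataset and runs a DRIFT query over it.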
In [1]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
In [2]:
import os
from pathlib import Path

import pandas as pd
import tiktoken

from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_report_embeddings,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
    DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2

# read the entities table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")

print(f"Entity df columns: {entity_df.columns}")

entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)

full_content_embedding_store = LanceDBVectorStore(
    collection_name="default-community-full_content",
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

print(f"Entity count: {len(entity_df)}")
entity_df.head()

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()
Entity df columns: Index(['id', 'human_readable_id', 'title', 'type', 'description', 'text_unit_ids', 'frequency', 'degree', 'x', 'y'], dtype='object')
Entity count: 18
Relationship count: 54
Text unit records: 5
Out[2]:
|   | id | human_readable_id | text | n_tokens | document_ids | entity_ids | relationship_ids | covariate_ids |
|---|----|-------------------|------|----------|--------------|------------|------------------|---------------|
| 0 | 8e938693af886bfd081acbbe8384c3671446bff84a134a... | 1 | # Operation: Dulce\n\n## Chapter 1\n\nThe thru... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [745d28dd-be20-411b-85ff-1c69ca70e7b3, 9cba185... |
| 1 | fd1f46d32e1df6cd429542aeda3d64ddf3745ccb80f443... | 2 | , the hollow echo of the bay a stark reminder ... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [4f9b461f-5e8f-465d-9586-e2fc81787062, 0f74618... |
| 2 | 7296d9a1f046854d59079dc183de8a054c27c4843d2979... | 3 | differently than praise from others. This was... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [3ef1be9c-4080-4fac-99bd-c4a636248904, 8730b20... |
| 3 | ac72722a02ac71242a2a91fca323198d04197daf60515d... | 4 | contrast to the rigid silence enveloping the ... | 1200 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [425a7862-0aef-4f69-a4c8-8bd42151c9d4, bcdbf1f... | [2bfad9f4-5abd-48d0-8db3-a9cad9120413, 6cbb838... | [2c292047-b79a-4958-ab57-7bf7d7a22c92, 3cbd18a... |
| 4 | 4c277337d461a16aaf8f9760ddb8b44ef220e948a2341d... | 5 | a mask of duty.\n\nIn the midst of the descen... | 35 | [6e81f882f89dd5596e1925dd3ae8a4f0a0edcb55b35a8... | [d084d615-3584-4ec8-9931-90aa6075c764, 4b84859... | [6efdc42e-69a2-47c0-97ec-4b296cd16d5e] | [db8da02f-f889-4bb5-8e81-ab2a72e380bb] |
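Before wiring these tables into DRIFT, a quick sanity check can confirm the graph loaded as expected. The snippet below is an optional sketch (not part of the original walkthrough) that lists the highest-degree entities, i.e. the hubs DRIFT's local phase is most likely to traverse:

```python
# Optional sanity check (a sketch): `degree` counts relationships per entity,
# so the largest values mark the hubs of the entity graph.
top_entities = entity_df.nlargest(5, "degree")[["title", "type", "degree"]]
print(top_entities)
```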
In [3]:
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]

chat_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIChat,
    model=llm_model,
    max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
    name="local_search",
    model_type=ModelType.OpenAIChat,
    config=chat_config,
)

token_encoder = tiktoken.encoding_for_model(llm_model)

embedding_config = LanguageModelConfig(
    api_key=api_key,
    type=ModelType.OpenAIEmbedding,
    model=embedding_model,
    max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
    name="local_search_embedding",
    model_type=ModelType.OpenAIEmbedding,
    config=embedding_config,
)
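Since a full DRIFT run issues many model calls, it can help to smoke-test the credentials first. This sketch (not in the original notebook) uses the same entry points DRIFT invokes internally, `achat` for the chat model and `embed` for the embedder; the probe prompt is illustrative:

```python
# Optional smoke test (a sketch): one chat call and one embedding call.
# If these succeed, the DRIFT search below should not fail on authentication.
probe = await chat_model.achat("Reply with the single word: ok")
print(probe.output.content)

embedding = text_embedder.embed("Operation Dulce")
print(f"Embedding dimensions: {len(embedding)}")
```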
In [4]:
def read_community_reports(
    input_dir: str,
    community_report_table: str = COMMUNITY_REPORT_TABLE,
):
    """Read the community reports parquet file from the given input directory."""
    input_path = Path(input_dir) / f"{community_report_table}.parquet"
    return pd.read_parquet(input_path)


report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
    report_df,
    community_df,
    COMMUNITY_LEVEL,
    content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)
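DRIFT's primer samples one of these reports as a template for query expansion (a HyDE-style hypothetical answer) and matches the query embedding against the report embeddings loaded above, so both must be present. A quick check, as a sketch:

```python
# Quick check (a sketch): the primer requires non-empty reports with full content.
print(f"Report count: {len(reports)}")
print(reports[0].full_content[:200])
```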
In [5]:
drift_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=12_000,
    primer_folds=1,
    drift_k_followups=3,
    n_depth=3,
    n=1,
)

context_builder = DRIFTSearchContextBuilder(
    model=chat_model,
    text_embedder=text_embedder,
    entities=entities,
    relationships=relationships,
    reports=reports,
    entity_text_embeddings=description_embedding_store,
    text_units=text_units,
    token_encoder=token_encoder,
    config=drift_params,
)

search = DRIFTSearch(
    model=chat_model, context_builder=context_builder, token_encoder=token_encoder
)
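The config fields above bound DRIFT's branching: roughly, `primer_folds` controls how many primer searches run over the community reports, `drift_k_followups` caps the top results carried forward at each step, `n_depth` sets the number of drill-down rounds, and `n` the number of completions generated per call (exact semantics are defined on `DRIFTSearchConfig`). A lighter variant, sketched here with illustrative values, trades answer completeness for latency and token cost:

```python
# A lighter configuration (a sketch): shallower recursion and fewer follow-ups
# reduce token usage. Field names match DRIFTSearchConfig; values are illustrative.
fast_params = DRIFTSearchConfig(
    temperature=0,
    max_tokens=8_000,
    primer_folds=1,
    drift_k_followups=1,
    n_depth=1,
    n=1,
)
```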
In [6]:
resp = await search.search("Who is agent Mercer?")
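`search.search` is a coroutine, so it must be awaited; top-level `await` works inside a notebook, while a plain script would wrap the call with `asyncio.run`. The same `DRIFTSearch` instance can be reused for further questions, as in this sketch (the query text is illustrative):

```python
# Reusing the engine (a sketch): each call runs the full primer + drill-down
# loop for the new query.
follow_up = await search.search("What is the mission at Dulce Base?")
print(follow_up.response)
```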
In [7]:
resp.response
In [8]:
print(resp.context_data)