DRIFT Search
In [1]:
Copied!
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
In [2]:
Copied!
# NOTE(review): in the recorded run this cell stopped at the `ModelType` import
# below with "ImportError: cannot import name 'ModelType' from
# 'graphrag.config.enums'". Nothing after that import executed, so every later
# cell failed with NameError (LanguageModelConfig, DRIFTSearchConfig, search,
# resp). The model-configuration imports (ModelType, LanguageModelConfig,
# ModelManager) appear to target an older graphrag API — TODO update them to
# match the installed graphrag version.
import os
import pandas as pd
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag_vectors.lancedb import LanceDBVectorStore
# Location of the GraphRAG index output: parquet tables plus a LanceDB directory.
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

# Parquet table names (file stems under INPUT_DIR).
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"  # defined but not read anywhere in this notebook
TEXT_UNIT_TABLE = "text_units"

# Community level handed to the entity and report adapters below.
COMMUNITY_LEVEL = 2

# Entities: the communities table is passed alongside the entity table,
# presumably so the adapter can relate entities to communities at
# COMMUNITY_LEVEL — confirm against read_indexer_entities.
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
print(f"Entity df columns: {entity_df.columns}")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# Open the LanceDB vector indexes stored under LANCEDB_URI, which is a local
# relative directory here (not an in-memory store). For a remote database,
# check LanceDBVectorStore's constructor for the connection options it accepts.
description_embedding_store = LanceDBVectorStore(
    db_uri=LANCEDB_URI,
    index_name="entity_description",
)
description_embedding_store.connect()

full_content_embedding_store = LanceDBVectorStore(
    db_uri=LANCEDB_URI,
    index_name="community_full_content",
)
full_content_embedding_store.connect()

# Only a cell's final expression is rendered automatically, so the table
# previews below go through display() (injected as a builtin by IPython);
# bare `df.head()` lines in the middle of a cell would produce no output.
print(f"Entity count: {len(entity_df)}")
display(entity_df.head())

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print(f"Relationship count: {len(relationship_df)}")
display(relationship_df.head())

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {len(text_unit_df)}")
display(text_unit_df.head())

# Community reports; the adapter call afterwards is given the reports and the
# full-content vector store (presumably to attach report embeddings — confirm).
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)
read_indexer_report_embeddings(reports, full_content_embedding_store)
# NOTE(review): in the recorded run this cell stopped at the `ModelType` import
# below with "ImportError: cannot import name 'ModelType' from
# 'graphrag.config.enums'". Nothing after that import executed, so every later
# cell failed with NameError (LanguageModelConfig, DRIFTSearchConfig, search,
# resp). The model-configuration imports (ModelType, LanguageModelConfig,
# ModelManager) appear to target an older graphrag API — TODO update them to
# match the installed graphrag version.
import os
import pandas as pd
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.tokenizer.get_tokenizer import get_tokenizer
from graphrag_vectors.lancedb import LanceDBVectorStore
# Location of the GraphRAG index output: parquet tables plus a LanceDB directory.
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

# Parquet table names (file stems under INPUT_DIR).
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"  # defined but not read anywhere in this notebook
TEXT_UNIT_TABLE = "text_units"

# Community level handed to the entity and report adapters below.
COMMUNITY_LEVEL = 2

# Entities: the communities table is passed alongside the entity table,
# presumably so the adapter can relate entities to communities at
# COMMUNITY_LEVEL — confirm against read_indexer_entities.
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
print(f"Entity df columns: {entity_df.columns}")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)

# Open the LanceDB vector indexes stored under LANCEDB_URI, which is a local
# relative directory here (not an in-memory store). For a remote database,
# check LanceDBVectorStore's constructor for the connection options it accepts.
description_embedding_store = LanceDBVectorStore(
    db_uri=LANCEDB_URI,
    index_name="entity_description",
)
description_embedding_store.connect()

full_content_embedding_store = LanceDBVectorStore(
    db_uri=LANCEDB_URI,
    index_name="community_full_content",
)
full_content_embedding_store.connect()

# Only a cell's final expression is rendered automatically, so the table
# previews below go through display() (injected as a builtin by IPython);
# bare `df.head()` lines in the middle of a cell would produce no output.
print(f"Entity count: {len(entity_df)}")
display(entity_df.head())

relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print(f"Relationship count: {len(relationship_df)}")
display(relationship_df.head())

text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {len(text_unit_df)}")
display(text_unit_df.head())

# Community reports; the adapter call afterwards is given the reports and the
# full-content vector store (presumably to attach report embeddings — confirm).
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)
read_indexer_report_embeddings(reports, full_content_embedding_store)
--------------------------------------------------------------------------- ImportError Traceback (most recent call last) Cell In[2], line 4 1 import os 3 import pandas as pd ----> 4 from graphrag.config.enums import ModelType 5 from graphrag.config.models.drift_search_config import DRIFTSearchConfig 6 from graphrag.config.models.language_model_config import LanguageModelConfig ImportError: cannot import name 'ModelType' from 'graphrag.config.enums' (/home/runner/work/graphrag/graphrag/packages/graphrag/graphrag/config/enums.py)
In [3]:
Copied!
# The key comes from the environment so no credential is stored in the notebook.
api_key = os.environ["GRAPHRAG_API_KEY"]

# ---- Chat model (passed to the DRIFT context builder and search below) ----
chat_config = LanguageModelConfig(
    type=ModelType.Chat,
    model_provider="openai",
    model="gpt-4.1",
    api_key=api_key,
    max_retries=20,
)
# NOTE(review): registered under the name "local_search" even though this is
# the DRIFT notebook; consider a distinct registry name.
chat_model = ModelManager().get_or_create_chat_model(
    config=chat_config,
    model_type=ModelType.Chat,
    name="local_search",
)

# The tokenizer is derived from the chat model's configuration.
tokenizer = get_tokenizer(chat_config)

# ---- Embedding model (used as the context builder's text embedder) ----
embedding_config = LanguageModelConfig(
    type=ModelType.Embedding,
    model_provider="openai",
    model="text-embedding-3-large",
    api_key=api_key,
    max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
    config=embedding_config,
    model_type=ModelType.Embedding,
    name="local_search_embedding",
)
# The key comes from the environment so no credential is stored in the notebook.
api_key = os.environ["GRAPHRAG_API_KEY"]

# ---- Chat model (passed to the DRIFT context builder and search below) ----
chat_config = LanguageModelConfig(
    type=ModelType.Chat,
    model_provider="openai",
    model="gpt-4.1",
    api_key=api_key,
    max_retries=20,
)
# NOTE(review): registered under the name "local_search" even though this is
# the DRIFT notebook; consider a distinct registry name.
chat_model = ModelManager().get_or_create_chat_model(
    config=chat_config,
    model_type=ModelType.Chat,
    name="local_search",
)

# The tokenizer is derived from the chat model's configuration.
tokenizer = get_tokenizer(chat_config)

# ---- Embedding model (used as the context builder's text embedder) ----
embedding_config = LanguageModelConfig(
    type=ModelType.Embedding,
    model_provider="openai",
    model="text-embedding-3-large",
    api_key=api_key,
    max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
    config=embedding_config,
    model_type=ModelType.Embedding,
    name="local_search_embedding",
)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[3], line 3 1 api_key = os.environ["GRAPHRAG_API_KEY"] ----> 3 chat_config = LanguageModelConfig( 4 api_key=api_key, 5 type=ModelType.Chat, 6 model_provider="openai", 7 model="gpt-4.1", 8 max_retries=20, 9 ) 10 chat_model = ModelManager().get_or_create_chat_model( 11 name="local_search", 12 model_type=ModelType.Chat, 13 config=chat_config, 14 ) 16 tokenizer = get_tokenizer(chat_config) NameError: name 'LanguageModelConfig' is not defined
In [4]:
Copied!
# DRIFT settings overridden for this demo; see DRIFTSearchConfig for the
# meaning and defaults of each field.
drift_params = DRIFTSearchConfig(
    n_depth=3,
    drift_k_followups=3,
    primer_folds=1,
)

# Wire the loaded index data (entities, relationships, community reports,
# text units), the entity-description vector store and both models into the
# DRIFT context builder.
context_builder = DRIFTSearchContextBuilder(
    config=drift_params,
    model=chat_model,
    text_embedder=text_embedder,
    tokenizer=tokenizer,
    entities=entities,
    relationships=relationships,
    reports=reports,
    text_units=text_units,
    entity_text_embeddings=description_embedding_store,
)

# The search engine shares the chat model and tokenizer with the builder.
search = DRIFTSearch(
    model=chat_model,
    context_builder=context_builder,
    tokenizer=tokenizer,
)
# DRIFT settings overridden for this demo; see DRIFTSearchConfig for the
# meaning and defaults of each field.
drift_params = DRIFTSearchConfig(
    n_depth=3,
    drift_k_followups=3,
    primer_folds=1,
)

# Wire the loaded index data (entities, relationships, community reports,
# text units), the entity-description vector store and both models into the
# DRIFT context builder.
context_builder = DRIFTSearchContextBuilder(
    config=drift_params,
    model=chat_model,
    text_embedder=text_embedder,
    tokenizer=tokenizer,
    entities=entities,
    relationships=relationships,
    reports=reports,
    text_units=text_units,
    entity_text_embeddings=description_embedding_store,
)

# The search engine shares the chat model and tokenizer with the builder.
search = DRIFTSearch(
    model=chat_model,
    context_builder=context_builder,
    tokenizer=tokenizer,
)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[4], line 1 ----> 1 drift_params = DRIFTSearchConfig( 2 primer_folds=1, 3 drift_k_followups=3, 4 n_depth=3, 5 ) 7 context_builder = DRIFTSearchContextBuilder( 8 model=chat_model, 9 text_embedder=text_embedder, (...) 16 config=drift_params, 17 ) 19 search = DRIFTSearch( 20 model=chat_model, context_builder=context_builder, tokenizer=tokenizer 21 ) NameError: name 'DRIFTSearchConfig' is not defined
In [5]:
Copied!
# Run a DRIFT query. The bare top-level `await` relies on IPython's autoawait
# support and is not valid in a plain .py script.
# NOTE(review): this line appears twice (copy-button text plus the rendered
# cell); if executed as written, the model-backed query would run twice.
resp = await search.search("Who is agent Mercer?")
resp = await search.search("Who is agent Mercer?")
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[5], line 1 ----> 1 resp = await search.search("Who is agent Mercer?") NameError: name 'search' is not defined
In [6]:
Copied!
# Render the `response` attribute of the search result (presumably the final
# answer text — confirm against DRIFTSearch's result type).
resp.response
resp.response
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[6], line 1 ----> 1 resp.response NameError: name 'resp' is not defined
In [7]:
Copied!
# Print the context data attached to the search result as plain text.
# NOTE(review): if context_data holds DataFrames, display() would render
# them more readably than print().
print(resp.context_data)
print(resp.context_data)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[7], line 1 ----> 1 print(resp.context_data) NameError: name 'resp' is not defined