ruff check and format
This commit is contained in:
@@ -21,30 +21,6 @@ target-version = "py314"
|
|||||||
line-length = 100
|
line-length = 100
|
||||||
indent-width = 4
|
indent-width = 4
|
||||||
|
|
||||||
[tool.ruff.lint]
|
|
||||||
# Enable latest PEP rules
|
|
||||||
select = [
|
|
||||||
"E", # pycodestyle errors
|
|
||||||
"W", # pycodestyle warnings
|
|
||||||
"F", # Pyflakes
|
|
||||||
"UP", # pyupgrade (PEP 585, 604, etc.)
|
|
||||||
"B", # flake8-bugbear
|
|
||||||
"SIM", # flake8-simplify
|
|
||||||
"I", # isort
|
|
||||||
"N", # pep8-naming
|
|
||||||
"D", # pydocstring
|
|
||||||
"C90", # mccabe complexity
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
"D100", # Missing docstring in public module
|
|
||||||
"D104", # Missing docstring in public package
|
|
||||||
"D203", # 1 blank line required before class docstring
|
|
||||||
"D213", # Multi-line docstring summary should start at the second line
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.lint.per-file-ignores]
|
|
||||||
"__init__.py" = ["F401"] # Allow unused imports in __init__.py files
|
|
||||||
|
|
||||||
[tool.ruff.lint.mccabe]
|
[tool.ruff.lint.mccabe]
|
||||||
max-complexity = 10
|
max-complexity = 10
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def load_config(config_path="src/config.yaml"):
|
def load_config(config_path="src/config.yaml"):
|
||||||
with open(config_path, "r") as f:
|
with open(config_path) as f:
|
||||||
return yaml.safe_load(f)
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
# Usage example:
|
# Usage example:
|
||||||
# CFG = load_config()
|
# CFG = load_config()
|
||||||
# print(CFG['api']['base_url'])
|
# print(CFG['api']['base_url'])
|
||||||
+17
-13
@@ -1,10 +1,10 @@
|
|||||||
import dspy
|
|
||||||
from langchain_community.vectorstores import FAISS
|
|
||||||
from embedding import LocalLMEmbeddings
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from config_loader import load_config
|
import dspy
|
||||||
|
from langchain_community.vectorstores import FAISS
|
||||||
|
|
||||||
|
from config_loader import load_config
|
||||||
|
from embedding import LocalLMEmbeddings
|
||||||
|
|
||||||
CFG = load_config()
|
CFG = load_config()
|
||||||
|
|
||||||
@@ -16,20 +16,22 @@ API_BASE = CFG["api"]["base_url"]
|
|||||||
# --- DSPy Signature ---
|
# --- DSPy Signature ---
|
||||||
class DnDContextQA(dspy.Signature):
|
class DnDContextQA(dspy.Signature):
|
||||||
"""Answer DnD campaign questions using provided snippets and full file context.
|
"""Answer DnD campaign questions using provided snippets and full file context.
|
||||||
/no_think"""
|
/no_think
|
||||||
context = dspy.InputField(desc="Relevant chunks and full file contents from the campaign notes.")
|
"""
|
||||||
|
|
||||||
|
context = dspy.InputField(
|
||||||
|
desc="Relevant chunks and full file contents from the campaign notes."
|
||||||
|
)
|
||||||
question = dspy.InputField()
|
question = dspy.InputField()
|
||||||
answer = dspy.OutputField(desc="A detailed answer based on the notes, citing the source file.")
|
answer = dspy.OutputField(desc="A detailed answer based on the notes, citing the source file.")
|
||||||
|
|
||||||
|
|
||||||
# --- DSPy Module ---
|
# --- DSPy Module ---
|
||||||
class DnDRAG(dspy.Module):
|
class DnDRAG(dspy.Module):
|
||||||
def __init__(self, db_path=DATABASE_PATH, k=3):
|
def __init__(self, db_path=DATABASE_PATH, k=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 1. Setup Embeddings & Load FAISS
|
# 1. Setup Embeddings & Load FAISS
|
||||||
self.embeddings = LocalLMEmbeddings(
|
self.embeddings = LocalLMEmbeddings(model=EMBEDDING_MODEL, base_url=API_BASE)
|
||||||
model=EMBEDDING_MODEL,
|
|
||||||
base_url=API_BASE
|
|
||||||
)
|
|
||||||
self.vectorstore = FAISS.load_local(
|
self.vectorstore = FAISS.load_local(
|
||||||
db_path, self.embeddings, allow_dangerous_deserialization=True
|
db_path, self.embeddings, allow_dangerous_deserialization=True
|
||||||
)
|
)
|
||||||
@@ -41,7 +43,7 @@ class DnDRAG(dspy.Module):
|
|||||||
def get_full_file_content(self, file_path):
|
def get_full_file_content(self, file_path):
|
||||||
"""Helper to read the full source file if it exists."""
|
"""Helper to read the full source file if it exists."""
|
||||||
try:
|
try:
|
||||||
return Path(file_path).read_text(encoding='utf-8')
|
return Path(file_path).read_text(encoding="utf-8")
|
||||||
except Exception:
|
except Exception:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@@ -56,13 +58,15 @@ class DnDRAG(dspy.Module):
|
|||||||
context_parts = []
|
context_parts = []
|
||||||
for i, doc in enumerate(results):
|
for i, doc in enumerate(results):
|
||||||
source = doc.metadata.get("source", "Unknown")
|
source = doc.metadata.get("source", "Unknown")
|
||||||
context_parts.append(f"--- Chunk {i+1} from {source} ---\n{doc.page_content}")
|
context_parts.append(f"--- Chunk {i + 1} from {source} ---\n{doc.page_content}")
|
||||||
|
|
||||||
# 3. Add the Full Content of the top match (optional, but requested!)
|
# 3. Add the Full Content of the top match (optional, but requested!)
|
||||||
# We'll just take the top 1 file to avoid context window explosion
|
# We'll just take the top 1 file to avoid context window explosion
|
||||||
if unique_paths:
|
if unique_paths:
|
||||||
top_file_content = self.get_full_file_content(unique_paths[0])
|
top_file_content = self.get_full_file_content(unique_paths[0])
|
||||||
context_parts.append(f"\n=== FULL SOURCE FILE: {Path(unique_paths[0]).name} ===\n{top_file_content[:10000]}")
|
context_parts.append(
|
||||||
|
f"\n=== FULL SOURCE FILE: {Path(unique_paths[0]).name} ===\n{top_file_content[:10000]}"
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Join everything into one context string
|
# 4. Join everything into one context string
|
||||||
context_str = "\n\n".join(context_parts)
|
context_str = "\n\n".join(context_parts)
|
||||||
|
|||||||
@@ -1,22 +1,24 @@
|
|||||||
import dspy
|
import dspy
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# 1. Define the structure of your metadata
|
# 1. Define the structure of your metadata
|
||||||
class DocMetadata(BaseModel):
|
class DocMetadata(BaseModel):
|
||||||
synopsis: str = Field(description="A one-sentence summary of the document.")
|
synopsis: str = Field(description="A one-sentence summary of the document.")
|
||||||
tags: List[str] = Field(description="Relevant tags (NPCs, Locations, Items, Plot Points).")
|
tags: list[str] = Field(description="Relevant tags (NPCs, Locations, Items, Plot Points).")
|
||||||
entities: List[str] = Field(description="Key names of people, places, or factions.")
|
entities: list[str] = Field(description="Key names of people, places, or factions.")
|
||||||
|
|
||||||
|
|
||||||
class IngestionSignature(dspy.Signature):
|
class IngestionSignature(dspy.Signature):
|
||||||
"""
|
"""You are an expert Dungeon Master's assistant.
|
||||||
You are an expert Dungeon Master's assistant.
|
|
||||||
Analyze the provided notes and extract a concise synopsis and relevant metadata.
|
Analyze the provided notes and extract a concise synopsis and relevant metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
note: str = dspy.InputField(desc="The DM notes or session recap content.")
|
note: str = dspy.InputField(desc="The DM notes or session recap content.")
|
||||||
# By using the Pydantic model as the type, DSPy handles the JSON formatting for you
|
# By using the Pydantic model as the type, DSPy handles the JSON formatting for you
|
||||||
answer: DocMetadata = dspy.OutputField()
|
answer: DocMetadata = dspy.OutputField()
|
||||||
|
|
||||||
|
|
||||||
class IngestionAgent(dspy.Module):
|
class IngestionAgent(dspy.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
+17
-14
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -8,10 +7,9 @@ from langchain_community.vectorstores import FAISS
|
|||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from config_loader import load_config
|
||||||
from embedding import LocalLMEmbeddings
|
from embedding import LocalLMEmbeddings
|
||||||
from experts.ingestion_agent import IngestionAgent
|
from experts.ingestion_agent import IngestionAgent
|
||||||
from config_loader import load_config
|
|
||||||
|
|
||||||
|
|
||||||
CFG = load_config()
|
CFG = load_config()
|
||||||
DATA_DIR = CFG["ingestion"]["data_dir"]
|
DATA_DIR = CFG["ingestion"]["data_dir"]
|
||||||
@@ -20,10 +18,11 @@ MODEL_BASE = CFG["models"]["enrich"]
|
|||||||
EMBEDDING_MODEL = CFG["models"]["embedding"]
|
EMBEDDING_MODEL = CFG["models"]["embedding"]
|
||||||
API_BASE = CFG["api"]["base_url"]
|
API_BASE = CFG["api"]["base_url"]
|
||||||
API_VERSION = CFG["api"]["api_version"]
|
API_VERSION = CFG["api"]["api_version"]
|
||||||
MAX_WORKERS=CFG["ingestion"]["max_workers"]
|
MAX_WORKERS = CFG["ingestion"]["max_workers"]
|
||||||
CHUNK_SIZE=CFG["ingestion"]["chunk_size"],
|
CHUNK_SIZE = (CFG["ingestion"]["chunk_size"],)
|
||||||
CHUNK_OVERLAP=CFG["ingestion"]["chunk_overlap"]
|
CHUNK_OVERLAP = CFG["ingestion"]["chunk_overlap"]
|
||||||
EMBEDDING_BATCH_SIZE=CFG["ingestion"]["embedding_batch_size"]
|
EMBEDDING_BATCH_SIZE = CFG["ingestion"]["embedding_batch_size"]
|
||||||
|
|
||||||
|
|
||||||
def load_documents():
|
def load_documents():
|
||||||
docs = []
|
docs = []
|
||||||
@@ -50,23 +49,24 @@ def load_documents():
|
|||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
def chunk_documents(docs):
|
def chunk_documents(docs):
|
||||||
# LangChain preserves metadata during splitting automatically
|
# LangChain preserves metadata during splitting automatically
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=CHUNK_SIZE,
|
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, separators=["\n\n", "\n", ". ", " ", ""]
|
||||||
chunk_overlap=CHUNK_OVERLAP,
|
|
||||||
separators=["\n\n", "\n", ". ", " ", ""]
|
|
||||||
)
|
)
|
||||||
return text_splitter.split_documents(docs)
|
return text_splitter.split_documents(docs)
|
||||||
|
|
||||||
def enrich_chunks(chunks: list) -> list:
|
|
||||||
|
|
||||||
|
def enrich_chunks(chunks: list) -> list:
|
||||||
def process_single_chunk(indexed_chunk):
|
def process_single_chunk(indexed_chunk):
|
||||||
idx, chunk = indexed_chunk
|
idx, chunk = indexed_chunk
|
||||||
lm_index = idx % 8
|
lm_index = idx % 8
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with dspy.context(lm=dspy.LM(f"{MODEL_BASE}:{lm_index}", api_base=API_BASE+API_VERSION)):
|
with dspy.context(
|
||||||
|
lm=dspy.LM(f"{MODEL_BASE}:{lm_index}", api_base=API_BASE + API_VERSION)
|
||||||
|
):
|
||||||
response = IngestionAgent().forward(note=chunk.page_content)
|
response = IngestionAgent().forward(note=chunk.page_content)
|
||||||
|
|
||||||
# This is now an object, not a string!
|
# This is now an object, not a string!
|
||||||
@@ -79,7 +79,6 @@ def enrich_chunks(chunks: list) -> list:
|
|||||||
chunk.metadata.update(metadata)
|
chunk.metadata.update(metadata)
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
||||||
# Wrap chunks in enumerate to keep track of order
|
# Wrap chunks in enumerate to keep track of order
|
||||||
@@ -92,6 +91,7 @@ def enrich_chunks(chunks: list) -> list:
|
|||||||
enriched_results.sort(key=lambda x: x[0])
|
enriched_results.sort(key=lambda x: x[0])
|
||||||
return [item[1] for item in enriched_results]
|
return [item[1] for item in enriched_results]
|
||||||
|
|
||||||
|
|
||||||
def store_chunks_locally(chunks, db_path=DATABASE_PATH):
|
def store_chunks_locally(chunks, db_path=DATABASE_PATH):
|
||||||
embeddings_model = LocalLMEmbeddings(
|
embeddings_model = LocalLMEmbeddings(
|
||||||
model=EMBEDDING_MODEL,
|
model=EMBEDDING_MODEL,
|
||||||
@@ -106,14 +106,17 @@ def store_chunks_locally(chunks, db_path=DATABASE_PATH):
|
|||||||
print(f"✅ Successfully stored in FAISS at '{db_path}'")
|
print(f"✅ Successfully stored in FAISS at '{db_path}'")
|
||||||
return vectorstore
|
return vectorstore
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
docs = load_documents()
|
docs = load_documents()
|
||||||
if not docs: return
|
if not docs:
|
||||||
|
return
|
||||||
|
|
||||||
chunks = chunk_documents(docs)
|
chunks = chunk_documents(docs)
|
||||||
enriched_chunks = enrich_chunks(chunks)
|
enriched_chunks = enrich_chunks(chunks)
|
||||||
store_chunks_locally(enriched_chunks)
|
store_chunks_locally(enriched_chunks)
|
||||||
print("🎉 Ingestion complete!")
|
print("🎉 Ingestion complete!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
+6
-4
@@ -1,8 +1,9 @@
|
|||||||
import sys
|
import sys
|
||||||
import dspy
|
|
||||||
from experts.dnd_agent import DnDRAG
|
|
||||||
from config_loader import load_config
|
|
||||||
|
|
||||||
|
import dspy
|
||||||
|
|
||||||
|
from config_loader import load_config
|
||||||
|
from experts.dnd_agent import DnDRAG
|
||||||
|
|
||||||
CFG = load_config()
|
CFG = load_config()
|
||||||
RETRIEVE_MODEL = CFG["models"]["retrieval"]
|
RETRIEVE_MODEL = CFG["models"]["retrieval"]
|
||||||
@@ -13,7 +14,7 @@ API_VERSION = CFG["api"]["api_version"]
|
|||||||
def main():
|
def main():
|
||||||
# 1. Setup the LLM
|
# 1. Setup the LLM
|
||||||
print("🚀 Initializing Qwen-8B via LM Studio...")
|
print("🚀 Initializing Qwen-8B via LM Studio...")
|
||||||
lm = dspy.LM(RETRIEVE_MODEL, api_base=API_BASE+API_VERSION)
|
lm = dspy.LM(RETRIEVE_MODEL, api_base=API_BASE + API_VERSION)
|
||||||
dspy.configure(lm=lm)
|
dspy.configure(lm=lm)
|
||||||
|
|
||||||
# 2. Load the RAG System (only happens once!)
|
# 2. Load the RAG System (only happens once!)
|
||||||
@@ -52,5 +53,6 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n⚠️ An error occurred: {e}")
|
print(f"\n⚠️ An error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user