ruff check and format
This commit is contained in:
@@ -21,30 +21,6 @@ target-version = "py314"
|
||||
line-length = 100
|
||||
indent-width = 4
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Enable latest PEP rules
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"UP", # pyupgrade (PEP 585, 604, etc.)
|
||||
"B", # flake8-bugbear
|
||||
"SIM", # flake8-simplify
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"D", # pydocstring
|
||||
"C90", # mccabe complexity
|
||||
]
|
||||
ignore = [
|
||||
"D100", # Missing docstring in public module
|
||||
"D104", # Missing docstring in public package
|
||||
"D203", # 1 blank line required before class docstring
|
||||
"D213", # Multi-line docstring summary should start at the second line
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["F401"] # Allow unused imports in __init__.py files
|
||||
|
||||
[tool.ruff.lint.mccabe]
|
||||
max-complexity = 10
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_config(config_path="src/config.yaml"):
|
||||
with open(config_path, "r") as f:
|
||||
with open(config_path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
# Usage example:
|
||||
# CFG = load_config()
|
||||
# print(CFG['api']['base_url'])
|
||||
+16
-12
@@ -1,10 +1,10 @@
|
||||
import dspy
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from embedding import LocalLMEmbeddings
|
||||
from pathlib import Path
|
||||
|
||||
from config_loader import load_config
|
||||
import dspy
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
from config_loader import load_config
|
||||
from embedding import LocalLMEmbeddings
|
||||
|
||||
CFG = load_config()
|
||||
|
||||
@@ -16,20 +16,22 @@ API_BASE = CFG["api"]["base_url"]
|
||||
# --- DSPy Signature ---
|
||||
class DnDContextQA(dspy.Signature):
|
||||
"""Answer DnD campaign questions using provided snippets and full file context.
|
||||
/no_think"""
|
||||
context = dspy.InputField(desc="Relevant chunks and full file contents from the campaign notes.")
|
||||
/no_think
|
||||
"""
|
||||
|
||||
context = dspy.InputField(
|
||||
desc="Relevant chunks and full file contents from the campaign notes."
|
||||
)
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="A detailed answer based on the notes, citing the source file.")
|
||||
|
||||
|
||||
# --- DSPy Module ---
|
||||
class DnDRAG(dspy.Module):
|
||||
def __init__(self, db_path=DATABASE_PATH, k=3):
|
||||
super().__init__()
|
||||
# 1. Setup Embeddings & Load FAISS
|
||||
self.embeddings = LocalLMEmbeddings(
|
||||
model=EMBEDDING_MODEL,
|
||||
base_url=API_BASE
|
||||
)
|
||||
self.embeddings = LocalLMEmbeddings(model=EMBEDDING_MODEL, base_url=API_BASE)
|
||||
self.vectorstore = FAISS.load_local(
|
||||
db_path, self.embeddings, allow_dangerous_deserialization=True
|
||||
)
|
||||
@@ -41,7 +43,7 @@ class DnDRAG(dspy.Module):
|
||||
def get_full_file_content(self, file_path):
|
||||
"""Helper to read the full source file if it exists."""
|
||||
try:
|
||||
return Path(file_path).read_text(encoding='utf-8')
|
||||
return Path(file_path).read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
@@ -62,7 +64,9 @@ class DnDRAG(dspy.Module):
|
||||
# We'll just take the top 1 file to avoid context window explosion
|
||||
if unique_paths:
|
||||
top_file_content = self.get_full_file_content(unique_paths[0])
|
||||
context_parts.append(f"\n=== FULL SOURCE FILE: {Path(unique_paths[0]).name} ===\n{top_file_content[:10000]}")
|
||||
context_parts.append(
|
||||
f"\n=== FULL SOURCE FILE: {Path(unique_paths[0]).name} ===\n{top_file_content[:10000]}"
|
||||
)
|
||||
|
||||
# 4. Join everything into one context string
|
||||
context_str = "\n\n".join(context_parts)
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
import dspy
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
|
||||
# 1. Define the structure of your metadata
|
||||
class DocMetadata(BaseModel):
|
||||
synopsis: str = Field(description="A one-sentence summary of the document.")
|
||||
tags: List[str] = Field(description="Relevant tags (NPCs, Locations, Items, Plot Points).")
|
||||
entities: List[str] = Field(description="Key names of people, places, or factions.")
|
||||
tags: list[str] = Field(description="Relevant tags (NPCs, Locations, Items, Plot Points).")
|
||||
entities: list[str] = Field(description="Key names of people, places, or factions.")
|
||||
|
||||
|
||||
class IngestionSignature(dspy.Signature):
|
||||
"""
|
||||
You are an expert Dungeon Master's assistant.
|
||||
"""You are an expert Dungeon Master's assistant.
|
||||
Analyze the provided notes and extract a concise synopsis and relevant metadata.
|
||||
"""
|
||||
|
||||
note: str = dspy.InputField(desc="The DM notes or session recap content.")
|
||||
# By using the Pydantic model as the type, DSPy handles the JSON formatting for you
|
||||
answer: DocMetadata = dspy.OutputField()
|
||||
|
||||
|
||||
class IngestionAgent(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
+14
-11
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
@@ -8,10 +7,9 @@ from langchain_community.vectorstores import FAISS
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from config_loader import load_config
|
||||
from embedding import LocalLMEmbeddings
|
||||
from experts.ingestion_agent import IngestionAgent
|
||||
from config_loader import load_config
|
||||
|
||||
|
||||
CFG = load_config()
|
||||
DATA_DIR = CFG["ingestion"]["data_dir"]
|
||||
@@ -21,10 +19,11 @@ EMBEDDING_MODEL = CFG["models"]["embedding"]
|
||||
API_BASE = CFG["api"]["base_url"]
|
||||
API_VERSION = CFG["api"]["api_version"]
|
||||
MAX_WORKERS = CFG["ingestion"]["max_workers"]
|
||||
CHUNK_SIZE=CFG["ingestion"]["chunk_size"],
|
||||
CHUNK_SIZE = (CFG["ingestion"]["chunk_size"],)
|
||||
CHUNK_OVERLAP = CFG["ingestion"]["chunk_overlap"]
|
||||
EMBEDDING_BATCH_SIZE = CFG["ingestion"]["embedding_batch_size"]
|
||||
|
||||
|
||||
def load_documents():
|
||||
docs = []
|
||||
data_path = Path(DATA_DIR)
|
||||
@@ -50,23 +49,24 @@ def load_documents():
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def chunk_documents(docs):
|
||||
# LangChain preserves metadata during splitting automatically
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=CHUNK_SIZE,
|
||||
chunk_overlap=CHUNK_OVERLAP,
|
||||
separators=["\n\n", "\n", ". ", " ", ""]
|
||||
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, separators=["\n\n", "\n", ". ", " ", ""]
|
||||
)
|
||||
return text_splitter.split_documents(docs)
|
||||
|
||||
def enrich_chunks(chunks: list) -> list:
|
||||
|
||||
def enrich_chunks(chunks: list) -> list:
|
||||
def process_single_chunk(indexed_chunk):
|
||||
idx, chunk = indexed_chunk
|
||||
lm_index = idx % 8
|
||||
|
||||
try:
|
||||
with dspy.context(lm=dspy.LM(f"{MODEL_BASE}:{lm_index}", api_base=API_BASE+API_VERSION)):
|
||||
with dspy.context(
|
||||
lm=dspy.LM(f"{MODEL_BASE}:{lm_index}", api_base=API_BASE + API_VERSION)
|
||||
):
|
||||
response = IngestionAgent().forward(note=chunk.page_content)
|
||||
|
||||
# This is now an object, not a string!
|
||||
@@ -79,7 +79,6 @@ def enrich_chunks(chunks: list) -> list:
|
||||
chunk.metadata.update(metadata)
|
||||
return chunk
|
||||
|
||||
|
||||
enriched_results = []
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
||||
# Wrap chunks in enumerate to keep track of order
|
||||
@@ -92,6 +91,7 @@ def enrich_chunks(chunks: list) -> list:
|
||||
enriched_results.sort(key=lambda x: x[0])
|
||||
return [item[1] for item in enriched_results]
|
||||
|
||||
|
||||
def store_chunks_locally(chunks, db_path=DATABASE_PATH):
|
||||
embeddings_model = LocalLMEmbeddings(
|
||||
model=EMBEDDING_MODEL,
|
||||
@@ -106,14 +106,17 @@ def store_chunks_locally(chunks, db_path=DATABASE_PATH):
|
||||
print(f"✅ Successfully stored in FAISS at '{db_path}'")
|
||||
return vectorstore
|
||||
|
||||
|
||||
def main():
|
||||
docs = load_documents()
|
||||
if not docs: return
|
||||
if not docs:
|
||||
return
|
||||
|
||||
chunks = chunk_documents(docs)
|
||||
enriched_chunks = enrich_chunks(chunks)
|
||||
store_chunks_locally(enriched_chunks)
|
||||
print("🎉 Ingestion complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+5
-3
@@ -1,8 +1,9 @@
|
||||
import sys
|
||||
import dspy
|
||||
from experts.dnd_agent import DnDRAG
|
||||
from config_loader import load_config
|
||||
|
||||
import dspy
|
||||
|
||||
from config_loader import load_config
|
||||
from experts.dnd_agent import DnDRAG
|
||||
|
||||
CFG = load_config()
|
||||
RETRIEVE_MODEL = CFG["models"]["retrieval"]
|
||||
@@ -52,5 +53,6 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ An error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user