Merge pull request #1 from Jake-Pullen/ai_in_the_middle

feat: AI-powered enhanced queries for better results
This commit is contained in:
Jake-Pullen
2026-03-07 11:22:11 +00:00
committed by GitHub
6 changed files with 73 additions and 37 deletions
+2
@@ -17,3 +17,5 @@
* entity chunking & re-ranking
* Logging in Ingestion
* database retrieve for tag or entity
+*
+11 -2
@@ -8,11 +8,13 @@ models:
enrich: "lm_studio/qwen-" # will have an identifier, based on amount of active LLMs see ./load_ingestion_llms.sh
embedding: "text-embedding-qwen3-embedding-8b"
retrieval: "lm_studio/qwen/qwen3-30b-a3b-2507"
+  expansion: "lm_studio/qwen/qwen3-30b-a3b-2507"
# --- Ingestion Settings ---
ingestion:
-  data_dir: "/home/cosmic/DnD"
-  db_path: "./data/dmv.db"
+  data_dir: "/home/jake/DnD"
+  db_path: "./data/"
+  db_name: "dmv.db"
active_llms: 2
parallel_requests_per_llm: 2
chunk_size: 800
@@ -36,3 +38,10 @@ retrieval_agent:
    Given the context and the question, answer the question.
    Do not make things up; base all of your answers on the context.
    Always cite the file location of your source of information.
+expansion_agent:
+  expansion_signature: |
+    You are a query expansion expert, specialised in Dungeons and Dragons.
+    Given a user's question, generate 3-5 similar but enhanced search queries that would help find more relevant information.
+    Each expanded query should be distinct and add a different perspective to the original question.
+    Return only the queries as a JSON list with key "queries".
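
Editor's note on the prompt above: downstream, dspy.Predict with a typed queries:list[str] output normally handles the JSON parsing itself, but if this prompt is ever sent to the endpoint directly, a small defensive parser helps. A minimal sketch (parse_expanded_queries is a hypothetical helper, not part of this commit):

import json

def parse_expanded_queries(raw: str, max_n: int = 5) -> list[str]:
    # Expecting {"queries": ["...", ...]} per the expansion_signature prompt
    try:
        queries = json.loads(raw).get("queries", [])
    except (json.JSONDecodeError, AttributeError):
        return []  # caller should fall back to the original question only
    seen, out = set(), []
    for q in queries:
        if not isinstance(q, str):
            continue
        q = q.strip()
        if q and q not in seen:  # keep queries distinct, as the prompt requires
            seen.add(q)
            out.append(q)
    return out[:max_n]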
+2 -1
@@ -1,5 +1,6 @@
import requests
from langchain_core.embeddings import Embeddings
from config_loader import load_config
CFG = load_config()
@@ -37,7 +38,7 @@ class LocalLMEmbeddings(Embeddings):
for i in range(0, len(texts), self.batch_size):
batch = texts[i : i + self.batch_size]
print(f"🚀 Processing batch {(i // self.batch_size) + 1} (Size: {len(batch)})...")
# print(f"🚀 Processing batch {(i // self.batch_size) + 1} (Size: {len(batch)})...")
batch_vectors = self._post_request(batch)
all_embeddings.extend(batch_vectors)
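
Editor's note: the batch loop above sends each slice of texts through _post_request; with LM Studio this is presumably the OpenAI-compatible /v1/embeddings endpoint. A minimal sketch of that request shape (the exact path and payload are assumptions based on the OpenAI convention, not shown in this diff):

import requests

def embed_batch(texts: list[str], base_url: str, model: str) -> list[list[float]]:
    # One POST per batch; vectors come back in input order under data[i].embedding
    resp = requests.post(
        f"{base_url}/v1/embeddings",
        json={"model": model, "input": texts},
        timeout=120,
    )
    resp.raise_for_status()
    return [item["embedding"] for item in resp.json()["data"]]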
+50 -25
@@ -1,7 +1,7 @@
import os
-import turso
-import dspy
+import dspy
+import turso
from config_loader import load_config
from embedding import LocalLMEmbeddings
@@ -9,27 +9,28 @@ from embedding import LocalLMEmbeddings
CFG = load_config()
DATABASE_PATH = CFG["ingestion"]["db_path"]
+DATABASE_NAME = CFG["ingestion"]["db_name"]
EMBEDDING_MODEL = CFG["models"]["embedding"]
API_BASE = CFG["api"]["base_url"]
RETRIEVAL_CONFIG = CFG["retrieval_agent"]
+EXPANSION_CONFIG = CFG["expansion_agent"]
def retrieve_from_turso(embedded_question, k=5):
query = f"""
SELECT file_path, synopsis, tags, entities, chunk_data,
-vector_distance_cos(embedding, vector32('{embedded_question[0]}')) AS distance
+vector_distance_cos(embedding, vector32('{embedded_question}')) AS distance
FROM notes
ORDER BY distance ASC
LIMIT {k};
"""
-con = turso.connect(DATABASE_PATH)
+con = turso.connect(DATABASE_PATH + DATABASE_NAME)
cur = con.cursor()
cur.execute(query)
rows = cur.fetchall()
return rows
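
Editor's note: retrieve_from_turso interpolates the embedding and k straight into the SQL string. If the turso Python client supports sqlite3-style parameter binding (an assumption to verify; it is not shown in this diff), the same query can bind values instead. A sketch, with retrieve_from_turso_safe as a hypothetical name:

def retrieve_from_turso_safe(embedded_question, k: int = 5):
    # Same query as above, with placeholders instead of f-string interpolation
    query = """
        SELECT file_path, synopsis, tags, entities, chunk_data,
               vector_distance_cos(embedding, vector32(?)) AS distance
        FROM notes
        ORDER BY distance ASC
        LIMIT ?;
    """
    con = turso.connect(DATABASE_PATH + DATABASE_NAME)
    cur = con.cursor()
    # Bind the vector text and the limit rather than interpolating them
    cur.execute(query, (str(embedded_question), int(k)))
    return cur.fetchall()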
# --- DSPy Signature ---
class DnDContextQA(dspy.Signature):
f"{RETRIEVAL_CONFIG['retrieval_signature']}"
@@ -38,47 +39,71 @@ class DnDContextQA(dspy.Signature):
answer = dspy.OutputField(desc="A detailed answer based on the notes, citing the source file.")
+class ExpansionSignature(dspy.Signature):
+    f"{EXPANSION_CONFIG['expansion_signature']}"
+    question = dspy.InputField()
+    answer = dspy.OutputField(
+        desc="A list of questions that will be used to vector search the database."
+    )
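
Editor's note on the signature classes: a bare f-string expression in a class body is never assigned to __doc__, so DSPy will not pick these strings up as instructions. One way to attach the YAML prompt dynamically is DSPy's two-argument Signature form (QueryExpansion is a hypothetical name; verify the API against the pinned DSPy version):

# Build a signature whose instructions come from the config file
QueryExpansion = dspy.Signature(
    "question -> queries: list[str]",
    EXPANSION_CONFIG["expansion_signature"],
)
expander = dspy.Predict(QueryExpansion)

The same caveat applies to DnDContextQA and its retrieval_signature above.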
class DnDRAG(dspy.Module):
def __init__(self):
super().__init__()
self.embeddings_model = LocalLMEmbeddings(
model=EMBEDDING_MODEL,
base_url=API_BASE,
-batch_size=1, # we only send 1 question at a time.
+# batch_size=1,
)
# Tools exposed to the ReAct loop
self.retrieval_lm = dspy.LM(
model=CFG["models"]["retrieval"], api_base=API_BASE + CFG["api"]["api_version"]
)
+with dspy.context(lm=self.retrieval_lm, signature=ExpansionSignature):
+    self.query_expander = dspy.Predict("question -> queries:list[str]")
self.tools = [self.load_file]
self.generate_answer = dspy.ReAct(signature=DnDContextQA, tools=self.tools)
def forward(self, question):
-# TODO: Add step here to LLM Expand
-# given the current question, generate 3-5 distinct search queries.
-# embed all the questions
-embedded_question = self.embeddings_model._post_request(question)
-# store the 5 from all 3-5 questions (15 - 25 results)
-results = retrieve_from_turso(embedded_question, k=5) # k is limit to return
+print("Enhancing Question")
+with dspy.context(lm=self.retrieval_lm):
+    expanded_queries = self.query_expander(question=question).queries
+print("Enhanced Queries:")
+for q in expanded_queries:
+    print(" ", q)
+all_embeddings = self.embeddings_model.embed_documents([question] + expanded_queries)
+# print(all_embeddings)
+all_results = []
+for embedded_question in all_embeddings:
+    results = retrieve_from_turso(embedded_question, k=5)
+    all_results.extend(results)
+seen = set()
+unique_results = []
+for row in all_results:
+    key = (row[0], row[4])
+    if key not in seen:
+        seen.add(key)
+        unique_results.append(row)
+# Format context as before
context_parts = []
-for i, row in enumerate(results):
-    source = row[0] # file_path
-    synopsis = row[1] # synopsis
-    tags = row[2] # tags
-    entities = row[3] # entities
-    content = row[4] # chunk_data
+for i, row in enumerate(unique_results):
+    source = row[0]
+    synopsis = row[1]
+    tags = row[2]
+    entities = row[3]
+    content = row[4]
+    closeness = row[5]
context_parts.append(f"""
--- Chunk {i + 1} from {source} ---
synopsis: {synopsis},
tags: {tags},
-entities: {entities}
+entities: {entities},
+closeness: {closeness},
{content}
""")
# print('Closest embedding hits')
# for part in context_parts:
# print(part)
context = "\n\n".join(context_parts)
prediction = self.generate_answer(context=context, question=question)
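
Editor's note: for context, a minimal usage sketch of the module above, assuming forward returns the prediction (its closing lines are cut off in this view) and reusing CFG and API_BASE from this file:

import dspy

# Illustrative wiring; model names come from the config shown earlier
lm = dspy.LM(
    model=CFG["models"]["retrieval"],
    api_base=API_BASE + CFG["api"]["api_version"],
)
dspy.configure(lm=lm)

rag = DnDRAG()
prediction = rag(question="Who rules the city-state in my campaign notes?")
print(prediction.answer)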
+6 -7
@@ -16,6 +16,7 @@ from experts.ingestion_agent import IngestionAgent
CFG = load_config()
DATA_DIR = CFG["ingestion"]["data_dir"]
DATABASE_PATH = CFG["ingestion"]["db_path"]
+DATABASE_NAME = CFG["ingestion"]["db_name"]
MODEL_BASE = CFG["models"]["enrich"]
EMBEDDING_MODEL = CFG["models"]["embedding"]
API_BASE = CFG["api"]["base_url"]
@@ -139,13 +140,10 @@ def embed_chunks(chunks: List[Any], batch_size: int = EMBEDDING_BATCH_SIZE) -> L
# Process chunks in batches
for i in tqdm(range(0, total_chunks, batch_size), desc="Embedding batches"):
batch = chunks[i : i + batch_size]
print(f"🚀 Processing batch {(i // batch_size) + 1} (Size: {len(batch)})...")
batch_content = [chunk.page_content for chunk in batch]
try:
# Use model's batched embedding method
-# batch_embeddings = embeddings_model.embed_query(batch_content)
batch_embeddings = embeddings_model.embed_documents(batch_content)
# Process each chunk in the batch
for j, (chunk, embedding) in enumerate(zip(batch, batch_embeddings)):
# Extract metadata
@@ -228,7 +226,7 @@ def save_to_db(chunk_dicts):
Each dict maps to a row in the 'notes' table.
"""
print("connecting to db")
-con = turso.connect(DATABASE_PATH)
+con = turso.connect(DATABASE_PATH + DATABASE_NAME)
print("opening cursor")
cur = con.cursor()
@@ -267,7 +265,8 @@ def save_to_db(chunk_dicts):
def create_db():
-con = turso.connect(DATABASE_PATH)
+Path(DATABASE_PATH).mkdir(exist_ok=True)
+con = turso.connect(DATABASE_PATH + DATABASE_NAME)
cur = con.cursor()
cur.execute("""
@@ -334,7 +333,7 @@ def delete_from_db(embedded_chunks):
print(f"Deleting existing rows for {len(file_paths)} file(s)")
-con = turso.connect(DATABASE_PATH)
+con = turso.connect(DATABASE_PATH + DATABASE_NAME)
cur = con.cursor()
# Use a single DELETE statement with IN clause for efficiency
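
Editor's note: the hunk cuts off at the batched delete. The usual placeholder-expansion pattern for an IN clause looks like this (a sketch assuming sqlite3-style parameter binding in the turso client):

# One "?" per file path, expanded into a single DELETE statement
placeholders = ", ".join("?" for _ in file_paths)
cur.execute(
    f"DELETE FROM notes WHERE file_path IN ({placeholders});",
    list(file_paths),
)
con.commit()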