Files
dungeon_masters_vault/src/embedding.py
T
2026-03-05 20:07:35 +00:00

51 lines
1.9 KiB
Python

import requests
from langchain_core.embeddings import Embeddings
from config_loader import load_config
# Module-level configuration, loaded once at import time.
# NOTE(review): load_config() is project-local; assumed to return a
# mapping with an "api" section — confirm against config_loader.
CFG = load_config()
API_BASE = CFG["api"]["base_url"]  # root URL of the local LM server
# NOTE(review): API_VERSION is interpolated directly before "embeddings"
# with no "/" in between (see LocalLMEmbeddings.__init__), so it
# presumably ends with a trailing slash (e.g. "v1/") — verify in config.
API_VERSION = CFG["api"]["api_version"]
class LocalLMEmbeddings(Embeddings):
    """LangChain-compatible Embeddings backend for a local LM server.

    Sends texts to an OpenAI-style ``/embeddings`` HTTP endpoint in
    batches and returns one embedding vector per input text.
    """

    def __init__(self, model: str, base_url: str = API_BASE, batch_size: int = 32):
        """Configure the endpoint URL, model name, and batch size.

        Args:
            model: Name of the embedding model to request from the server.
            base_url: Server root; defaults to the configured API_BASE.
            batch_size: Number of texts sent per HTTP request.
        """
        # NOTE(review): no "/" between API_VERSION and "embeddings" —
        # assumes API_VERSION carries a trailing slash; confirm in config.
        self.url = f"{base_url}/{API_VERSION}embeddings"
        self.model = model
        self.batch_size = batch_size

    def _post_request(self, input_texts: list[str]) -> list[list[float]]:
        """Handles the actual HTTP POST to the local server.

        Returns one vector per input text. On failure, logs the error and
        returns one empty list per input so callers keep index alignment
        (deliberate best-effort policy — raise here instead if a hard
        failure should stop the pipeline).
        """
        payload = {"model": self.model, "input": input_texts}
        try:
            response = requests.post(
                self.url, json=payload, timeout=120
            )  # Longer timeout for batches
            response.raise_for_status()
            data = response.json()
            return [item["embedding"] for item in data["data"]]
        # Narrowed from a bare `except Exception`: RequestException covers
        # connection/timeout/HTTP errors, ValueError covers JSON decoding,
        # KeyError/TypeError cover an unexpected response shape. Genuine
        # programming errors now propagate instead of being swallowed.
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            print(f"❌ Batch request failed: {e}")
            # Returning empty lists to maintain index integrity if needed,
            # or you could raise the error to stop the pipeline.
            return [[] for _ in input_texts]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Splits *texts* into batches of ``self.batch_size`` and embeds them."""
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]
            print(f"🚀 Processing batch {(i // self.batch_size) + 1} (Size: {len(batch)})...")
            batch_vectors = self._post_request(batch)
            all_embeddings.extend(batch_vectors)
        return all_embeddings

    def embed_query(self, text: str) -> list[float]:
        """Embeds the single search query."""
        result = self._post_request([text])
        # Failed request yields [[]]; guard keeps the [] fallback contract.
        return result[0] if result else []