Source code for flagevalmm.models.model_cache

import sqlite3
import threading
from typing import Optional
from contextlib import contextmanager

import json
import os
import hashlib
from typing import Any
from flagevalmm.common.logger import get_logger

logger = get_logger(__name__)


[docs] def calculate_hash(data: Any) -> str: string_representation = json.dumps(data, sort_keys=True) sha256 = hashlib.sha256() sha256.update(string_representation.encode("utf-8")) return sha256.hexdigest()
[docs] class ModelCache:
[docs] def __init__(self, db_name="model_cache", cache_dir="./.cache"): # Create cache directory if it doesn't exist os.makedirs(cache_dir, exist_ok=True) db_name = db_name.replace("/", "_") self.db_path = os.path.join(cache_dir, f"{db_name}.sqlite") self._local = threading.local() # Thread-local storage self._lock = threading.Lock() # Lock for initialization # Ensure database and table are created with self._get_conn() as conn: conn.execute( """CREATE TABLE IF NOT EXISTS cache (question TEXT PRIMARY KEY, answer TEXT)""" ) conn.commit()
@contextmanager def _get_conn(self): """Get a thread-safe database connection""" if not hasattr(self._local, "conn"): with self._lock: self._local.conn = sqlite3.connect(self.db_path) # Enable foreign key constraints self._local.conn.execute("PRAGMA foreign_keys = ON") try: yield self._local.conn except sqlite3.Error as e: # If a database error occurs, close and clean up the connection if hasattr(self._local, "conn"): self._local.conn.close() delattr(self._local, "conn") raise RuntimeError(f"Database error: {str(e)}")
[docs] def insert(self, question: str, answer: str) -> None: """Thread-safe insert operation""" question_hash = calculate_hash(question) try: with self._get_conn() as conn: conn.execute( "INSERT OR REPLACE INTO cache (question, answer) VALUES (?, ?)", (question_hash, answer), ) conn.commit() except Exception as e: raise RuntimeError(f"Failed to insert into cache: {str(e)}")
[docs] def get(self, question: str) -> Optional[str]: """Thread-safe query operation""" question_hash = calculate_hash(question) try: with self._get_conn() as conn: cursor = conn.execute( "SELECT answer FROM cache WHERE question = ?", (question_hash,) ) result = cursor.fetchone() if result is None: return None answer: str = result[0] return answer except Exception as e: raise RuntimeError(f"Failed to get from cache: {str(e)}")
[docs] def clear(self) -> None: """Clear the cache""" try: with self._get_conn() as conn: conn.execute("DELETE FROM cache") conn.commit() except Exception as e: raise RuntimeError(f"Failed to clear cache: {str(e)}")
[docs] def close(self) -> None: """Close the database connection for the current thread""" if hasattr(self._local, "conn"): try: self._local.conn.close() delattr(self._local, "conn") except Exception as e: raise RuntimeError(f"Failed to close connection: {str(e)}")
[docs] def exists(self, question: str) -> bool: """Check if a question exists in the cache""" question_hash = calculate_hash(question) try: with self._get_conn() as conn: cursor = conn.execute( "SELECT 1 FROM cache WHERE question = ? LIMIT 1", (question_hash,) ) return cursor.fetchone() is not None except Exception as e: logger.error(f"Failed to check cache existence: {str(e)}") return False
[docs] def delete(self, question: str) -> None: question_hash = calculate_hash(question) try: with self._get_conn() as conn: conn.execute("DELETE FROM cache WHERE question = ?", (question_hash,)) conn.commit() except Exception as e: raise RuntimeError(f"Failed to delete from cache: {str(e)}")
def __del__(self): try: self.close() except BaseException: pass # Ignore cleanup errors
# Example usage if __name__ == "__main__": cache = ModelCache() question = "What is the capital of France?" answer = "Paris" # Insert into cache cache.insert(question, answer) print(cache.exists(question)) # Retrieve from cache cached_answer = cache.get(question) print(f"Cached answer: {cached_answer}") # non-existent question non_existent_question = "What is the capital of Mars?" print(cache.exists(non_existent_question))