Source code for duetector.db

from contextlib import contextmanager
from datetime import datetime
from threading import Lock
from typing import Any, Dict, Generator, List, Optional

import sqlalchemy  # type: ignore
from sqlalchemy.orm import (  # type: ignore
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    sessionmaker,
)
from sqlalchemy.types import JSON  # type: ignore

from duetector.analyzer.models import Tracking as AT
from duetector.collectors.models import Tracking as CT
from duetector.config import Configuable


class TrackingMixin:
    """
    A mixin for sqlalchemy model to track process.

    Declares the columns shared by every tracking table: an auto-increment
    primary key, a set of nullable scalar columns and a JSON ``extended``
    column for additional key/value data.
    """

    # Surrogate primary key, generated by the database.
    id: Mapped[int] = mapped_column(
        primary_key=True,
        autoincrement=True,
    )

    # ``Optional`` annotations make these columns nullable.
    pid: Mapped[Optional[int]]
    uid: Mapped[Optional[int]]
    gid: Mapped[Optional[int]]
    dt: Mapped[Optional[datetime]]
    comm: Mapped[Optional[str]]
    cwd: Mapped[Optional[str]]
    fname: Mapped[Optional[str]]

    # Use the ``dict`` callable rather than a ``{}`` literal as the default:
    # a literal is a single object shared by every row that falls back to the
    # default, so mutating it through one instance would leak into others.
    extended: Mapped[Dict[str, Any]] = mapped_column(type_=JSON, default=dict)

    def __repr__(self):
        """Return a short description with pid, command name and timestamp."""
        return f"<Tracking [{self.pid} {self.comm}] {self.dt}>"
class TrackingInterface:
    """
    An interface for tracking.

    Concrete tracking models are expected to override every method below;
    the implementations here only raise ``NotImplementedError``.
    """

    def to_collector_tracking(self) -> CT:
        """
        Convert to collector's tracking model.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def to_analyzer_tracking(self) -> AT:
        """
        Convert to analyzer's tracking model.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @classmethod
    def inspect_fields(
        cls,
        value_as_type: bool = False,
    ) -> Dict[str, Any]:
        """
        Describe the fields of the tracking model.

        Args:
            value_as_type (bool): flag forwarded to the concrete
                implementation to select the form of the returned values.

        Returns:
            Dict[str, Any]: a mapping keyed by field name.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
class SessionManager(Configuable):
    """
    A wrapper for sqlalchemy session

    Special config:
        table_prefix (str): prefix for all table names
        engine (Dict[str, Any]): config for sqlalchemy.create_engine

    Example:

    .. code-block:: python

        from duetector.db import SessionManager
        from duetector.collectors.models import Tracking

        sessionmanager = SessionManager()
        t = Tracking(
            tracer="t",
        )
        m = sessionmanager.get_tracking_model(t.tracer, "id")

        with sessionmanager.begin() as session:
            session.add(m(**t.model_dump(exclude=["tracer"])))
            session.commit()

        assert sessionmanager.inspect_all_tables() == [
            sessionmanager.get_table_names("t", "id")
        ]
        assert sessionmanager.inspect_all_tables("not-exist") == []
    """

    config_scope = "db"

    default_config = {
        "table_prefix": "duetector_tracking",
        "engine": {
            "url": "sqlite:///:memory:",
        },
    }

    def __repr__(self):
        """Return a description of the manager with URL credentials masked."""
        url = self.config.engine.url or ""
        if "@" in url:
            # Everything between the scheme and the last "@" may hold
            # credentials, so it is replaced with a fixed mask.
            database_type = url.split(":")[0]
            host_part = url.split("@")[-1]
            safe_url = f"{database_type}://********@{host_part}"
        else:
            safe_url = url
        return f"<[SessionManager {safe_url}]{self.table_prefix}*>"

    def __init__(self, config: Optional[Dict[str, Any]] = None, *args, **kwargs):
        """
        Args:
            config (Optional[Dict[str, Any]]): configuration passed through
                to ``Configuable``.
        """
        super().__init__(config, *args, **kwargs)
        # Engine and sessionmaker are created lazily on first access.
        self._engine: Optional[sqlalchemy.engine.Engine] = None
        self._sessionmaker: Optional[sessionmaker] = None
        # Cache of generated model classes, keyed by tracer name.
        self._tracking_models: Dict[str, type] = {}
        # Guards creation of new tracking models.
        self.mutex = Lock()

    @property
    def debug(self):
        """
        Truthy when either ``debug`` or ``echo`` is set in the config.
        """
        return self.config.debug or self.config.echo

    @property
    def table_prefix(self):
        """
        Prefix for all table names
        """
        return self.config.table_prefix

    @property
    def engine_config(self) -> Dict[str, Any]:
        """
        Config for sqlalchemy.create_engine
        """
        # NOTE(review): the dict returned by ``_config_dict`` is updated in
        # place below; whether that is the live config or a copy is not
        # visible from here - confirm against ``Configuable``.
        config = self.config.engine._config_dict
        if self.debug:
            config["echo"] = True

        db_url = config.get("url", "")
        if ":memory:" in db_url:
            # Keep a single shared connection for in-memory databases unless
            # the user configured a pool class explicitly.
            config.setdefault("poolclass", sqlalchemy.StaticPool)
        if "sqlite" in db_url:
            # Allow the SQLite connection to be used from other threads
            # unless the user supplied their own ``connect_args``.
            config.setdefault("connect_args", {"check_same_thread": False})

        return config

    @property
    def engine(self) -> sqlalchemy.engine.Engine:
        """
        A sqlalchemy engine
        """
        if not self._engine:
            self._engine = sqlalchemy.create_engine(**self.engine_config)
        return self._engine

    @property
    def sessionmaker(self):
        """
        A sessionmaker for sqlalchemy session
        """
        if not self._sessionmaker:
            self._sessionmaker = sessionmaker(bind=self.engine)
        return self._sessionmaker

    @contextmanager
    def begin(self) -> Generator[Session, None, None]:
        """
        Get a sqlalchemy session.

        Example:

        .. code-block:: python

            with session_manager.begin() as session:
                session.add(...)

        """
        with self.sessionmaker.begin() as session:
            yield session

    def get_table_names(self, tracer: str = "unknown", collector_id: str = "") -> str:
        """
        Build the table name for a tracer / collector pair.

        Args:
            tracer (str): name of tracer
            collector_id (str): id of collector

        Returns:
            str: ``"{table_prefix}:{tracer}@{collector_id}"``
        """
        return f"{self.table_prefix}:{tracer}@{collector_id}"

    def table_name_to_tracer(self, table_name: str) -> str:
        """
        Extract the tracer name from a table name built by ``get_table_names``.

        Assumes neither the table prefix nor the tracer contains ``":"`` and
        that the tracer contains no ``"@"``.

        Raises:
            IndexError: if ``table_name`` contains no ``":"``.
        """
        # Split only on the first delimiter so that a collector id which
        # itself contains ":" or "@" does not disturb the parse.
        return table_name.split(":", 1)[1].split("@", 1)[0]

    def table_name_to_collector_id(self, table_name: str) -> str:
        """
        Extract the collector id from a table name built by ``get_table_names``.

        Assumes neither the table prefix nor the tracer contains ``":"`` and
        that the tracer contains no ``"@"``.

        Raises:
            IndexError: if ``table_name`` lacks the ``":"`` or ``"@"`` separator.
        """
        # ``maxsplit=1`` keeps any further ":" or "@" inside the collector id
        # instead of silently truncating it.
        return table_name.split(":", 1)[1].split("@", 1)[1]

    def get_tracking_model(self, tracer: str = "unknown", collector_id: str = "") -> type:
        """
        Get a sqlalchemy model for tracking, each tracer will create a table in database.

        Args:
            tracer (str): name of tracer
            collector_id (str): id of collector

        Returns:
            type: a sqlalchemy model for tracking

        Note:
            Models are cached by ``tracer`` only. A later call with the same
            tracer but a different ``collector_id`` returns the model that
            was created for the first ``collector_id``. The cache key is left
            unchanged because ``get_all_models`` exposes this mapping.
        """
        # Fast path without taking the lock.
        if tracer in self._tracking_models:
            return self._tracking_models[tracer]

        with self.mutex:
            # Re-check under the lock: another thread may have created the
            # model while this one was waiting.
            if tracer in self._tracking_models:
                return self._tracking_models[tracer]

            class Base(DeclarativeBase):
                pass

            class TrackingModel(Base, TrackingMixin, TrackingInterface):
                __tablename__ = self.get_table_names(tracer, collector_id)

                def to_collector_tracking(self) -> CT:
                    return CT(
                        tracer=tracer,
                        pid=self.pid,
                        uid=self.uid,
                        gid=self.gid,
                        dt=self.dt,
                        comm=self.comm,
                        cwd=self.cwd,
                        fname=self.fname,
                        extended=self.extended,
                    )

                def to_analyzer_tracking(self) -> AT:
                    return AT(
                        tracer=tracer,
                        pid=self.pid,
                        uid=self.uid,
                        gid=self.gid,
                        dt=self.dt,
                        comm=self.comm,
                        cwd=self.cwd,
                        fname=self.fname,
                        extended=self.extended,
                    )

                @classmethod
                def inspect_fields(
                    cls,
                    value_as_type: bool = False,
                ) -> Dict[str, Any]:
                    # Map every column except the surrogate "id" to its
                    # python type, or to that type's name.
                    return {
                        c.name: c.type.python_type
                        if value_as_type
                        else c.type.python_type.__name__
                        for c in cls.__table__.columns
                        if c.name != "id"
                    }

            # The cache is only written after the table is ensured, so a
            # failure here propagates without registering the model.
            # FIXME: unregister TrackingModel
            self._tracking_models[tracer] = self._init_tracking_model(TrackingModel)
            return self._tracking_models[tracer]

    def get_all_models(self) -> Dict[str, type]:
        """
        Return the cache of tracking models created so far.

        Returns:
            Dict[str, type]: mapping of tracer name to model class. This is
            the internal dict itself, not a copy.
        """
        return self._tracking_models

    def inspect_all_tables(
        self, tracer: Optional[str] = None, collector_id: Optional[str] = None
    ) -> List[str]:
        """
        List tracking tables that exist in the database.

        Only tables whose name starts with ``table_prefix`` are considered.

        Args:
            tracer (Optional[str]): if given, keep only tables of this tracer
            collector_id (Optional[str]): if given, keep only tables of this
                collector

        Returns:
            List[str]: names of the matching tables
        """

        def _filter(t):
            if tracer and self.table_name_to_tracer(t) != tracer:
                return False
            if collector_id and self.table_name_to_collector_id(t) != collector_id:
                return False
            return True

        return [
            t
            for t in sqlalchemy.inspect(self.engine).get_table_names()
            if t.startswith(self.table_prefix) and _filter(t)
        ]

    def inspect_all_tracers(self) -> List[str]:
        """
        List the distinct tracer names found in the tracking tables.

        Returns:
            List[str]: tracer names, de-duplicated, in no particular order
        """
        return list({self.table_name_to_tracer(t) for t in self.inspect_all_tables()})

    def inspect_all_collector_ids(self) -> List[str]:
        """
        List the distinct collector ids found in the tracking tables.

        Returns:
            List[str]: collector ids, de-duplicated, in no particular order
        """
        return list({self.table_name_to_collector_id(t) for t in self.inspect_all_tables()})

    def _init_tracking_model(self, tracking_model: type) -> type:
        """
        Create the table for ``tracking_model`` if it does not exist yet.

        Args:
            tracking_model (type): a declarative model class

        Returns:
            type: the same ``tracking_model`` that was passed in
        """
        if not sqlalchemy.inspect(self.engine).has_table(tracking_model.__tablename__):
            tracking_model.__table__.create(self.engine)
        return tracking_model
if __name__ == "__main__":
    # Smoke test: store one tracking record in the default in-memory
    # database and check that its table can be discovered afterwards.
    from duetector.collectors.models import Tracking

    manager = SessionManager()

    tracking = Tracking(
        tracer="t",
    )
    model = manager.get_tracking_model(tracking.tracer, "id")

    # "tracer" is excluded from the dump; it is encoded in the table name
    # rather than stored as a column.
    record = model(**tracking.model_dump(exclude=["tracer"]))
    with manager.begin() as session:
        session.add(record)
        session.commit()

    expected_tables = [manager.get_table_names("t", "id")]
    assert manager.inspect_all_tables() == expected_tables
    assert manager.inspect_all_tables("not-exist") == []