diff --git a/requirements.txt b/requirements.txt index 1ab3894..22409fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ python-qbittorrent==0.4.3 +SQLAlchemy==2.0.38 diff --git a/tarc/main.py b/tarc/main.py index 1432e5e..0e017b1 100644 --- a/tarc/main.py +++ b/tarc/main.py @@ -13,145 +13,81 @@ import sys import re import uuid import argparse -import sqlite3 - from datetime import datetime, timezone import qbittorrent +from sqlalchemy import create_engine, inspect +from sqlalchemy.orm import Session +from sqlalchemy.exc import DatabaseError + +from .models import Base, SchemaVersion, Client # SCHEMA format is YYYYMMDDX -SCHEMA = 202410060 +SCHEMA = 202503100 -def init_db(conn): +def init_db(engine): """ Initialize database """ + Base.metadata.create_all(engine) - c = conn.cursor() - c.executescript( - f""" - PRAGMA user_version = {SCHEMA}; - - CREATE TABLE IF NOT EXISTS clients ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL UNIQUE, - uuid TEXT NOT NULL UNIQUE, - endpoint TEXT NOT NULL, - last_seen DATETIME NOT NULL - ); - - CREATE TABLE IF NOT EXISTS torrents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - info_hash_v1 TEXT NOT NULL UNIQUE, - info_hash_v2 TEXT UNIQUE, - file_count INTEGER NOT NULL, - completed_on DATETIME NOT NULL - ); - - CREATE TABLE IF NOT EXISTS torrent_clients ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - torrent_id INTEGER NOT NULL, - client_id INTEGER NOT NULL, - name TEXT NOT NULL, - content_path TEXT NOT NULL, - last_seen DATETIME NOT NULL, - FOREIGN KEY (torrent_id) REFERENCES torrents(id), - FOREIGN KEY (client_id) REFERENCES clients(id), - UNIQUE (torrent_id, client_id) - ); - - CREATE TABLE IF NOT EXISTS trackers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - url TEXT NOT NULL UNIQUE, - last_seen DATETIME NOT NULL - ); - - CREATE TABLE IF NOT EXISTS torrent_trackers ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - client_id INTEGER NOT NULL, - torrent_id INTEGER NOT NULL, - tracker_id INTEGER NOT NULL, - last_seen DATETIME NOT NULL, - FOREIGN KEY (client_id) REFERENCES clients(id), - FOREIGN KEY (torrent_id) REFERENCES torrents(id), - FOREIGN KEY (tracker_id) REFERENCES trackers(id), - UNIQUE (client_id, torrent_id, tracker_id) - ); - - CREATE TABLE IF NOT EXISTS files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - size INTEGER NOT NULL, - oshash TEXT NOT NULL UNIQUE, - hash TEXT UNIQUE - ); - - CREATE TABLE IF NOT EXISTS torrent_files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - torrent_id INTEGER NOT NULL, - client_id INTEGER NOT NULL, - file_index INTEGER NOT NULL, - file_path TEXT NOT NULL, - is_downloaded BOOLEAN NOT NULL, - last_checked DATETIME NOT NULL, - FOREIGN KEY (file_id) REFERENCES files(id), - FOREIGN KEY (torrent_id) REFERENCES torrents(id), - FOREIGN KEY (client_id) REFERENCES clients(id), - UNIQUE (file_id, torrent_id, client_id, file_index) - ); - """ - ) - conn.commit() - c.close() + with Session(engine) as session: + if not session.query(SchemaVersion).first(): + now = datetime.now(timezone.utc) + version = SchemaVersion(version=SCHEMA, applied_at=now) + session.add(version) + session.commit() -def list_tables(conn): +def get_schema_version(engine): + """ + Get current schema version from database + """ + with Session(engine) as session: + version = session.query(SchemaVersion).order_by(SchemaVersion.id.desc()).first() + return version.version if version else None + + +def list_tables(engine): """ List all tables in database """ - c = conn.cursor() - c.execute("SELECT name FROM sqlite_master WHERE type='table';") - table_list = c.fetchall() - c.close() - return [table[0] for table in table_list] + inspector = inspect(engine) + return inspector.get_table_names() -def add_client(conn, name, endpoint, last_seen): +def add_client(engine, name, endpoint, last_seen): """ Add a new client endpoint to database """ - c = conn.cursor() - c.execute( - f""" - INSERT INTO clients (uuid, name, endpoint, last_seen) - VALUES ("{uuid.uuid4()}", "{name}", "{endpoint}", "{last_seen}"); - """ - ) - conn.commit() - c.close() + with Session(engine) as session: + client = Client( + uuid=str(uuid.uuid4()), name=name, endpoint=endpoint, last_seen=last_seen + ) + session.add(client) + session.commit() -def find_client(conn, endpoint): +def find_client(engine, endpoint): """ Find existing client """ - c = conn.cursor() - c.execute(f'SELECT id, name, uuid FROM clients WHERE endpoint="{endpoint}";') - response = c.fetchall() - c.close() - return response + with Session(engine) as session: + clients = ( + session.query(Client.id, Client.name, Client.uuid) + .filter_by(endpoint=endpoint) + .all() + ) + return clients -def list_clients(conn): +def list_clients(engine): """ List all stored clients """ - c = conn.cursor() - c.execute("SELECT * FROM clients;") - rows = c.fetchall() - c.close() - return rows + with Session(engine) as session: + return session.query(Client).all() def main(): @@ -181,29 +117,33 @@ def main(): if args.command == "scan": if args.storage is None: - STORAGE = os.path.expanduser("~/.tarch.db") + storage_path = os.path.expanduser("~/.tarc.db") else: - STORAGE = args.storage + storage_path = args.storage + try: - sqlitedb = sqlite3.connect(STORAGE) - tables = list_tables(sqlitedb) - except sqlite3.DatabaseError as e: - print(f'[ERROR]: Database Error "{STORAGE}" ({str(e)})') + engine = create_engine(f"sqlite:///{storage_path}") + tables = list_tables(engine) + except DatabaseError as e: + print(f'[ERROR]: Database Error "{storage_path}" ({str(e)})') sys.exit(1) - if len(tables) == 0: - print(f"[INFO]: Initializing database at {STORAGE}") - init_db(sqlitedb) - cursor = sqlitedb.cursor() - cursor.execute("PRAGMA user_version;") - SCHEMA_FOUND = cursor.fetchone()[0] - cursor.close() - if not SCHEMA == SCHEMA_FOUND: - print(f"[ERROR]: SCHEMA {SCHEMA_FOUND}, expected {SCHEMA}") + + if not tables: + print(f"[INFO]: Initializing database at {storage_path}") + init_db(engine) + + schema_found = get_schema_version(engine) + if schema_found is None: + print("[ERROR]: Could not determine schema version") sys.exit(1) - if not args.directory is None: + if not SCHEMA == schema_found: + print(f"[ERROR]: SCHEMA {schema_found}, expected {SCHEMA}") + sys.exit(1) + + if args.directory is not None: print("[INFO]: --directory is not implemented") sys.exit(0) - elif not args.endpoint is None: + elif args.endpoint is not None: qb = qbittorrent.Client(args.endpoint) if qb.qbittorrent_version is None: print(f'[ERROR]: Couldn\'t find client version at "{args.endpoint}"') @@ -217,14 +157,13 @@ def main(): print( f'[INFO]: Found qbittorrent {qb.qbittorrent_version} at "{args.endpoint}"' ) - clients = find_client(sqlitedb, args.endpoint) + + clients = find_client(engine, args.endpoint) if args.confirm_add: if len(clients) == 0: - if not args.name is None: - now = datetime.now(timezone.utc).isoformat( - sep=" ", timespec="seconds" - ) - add_client(sqlitedb, args.name, args.endpoint, now) + if args.name is not None: + now = datetime.now(timezone.utc) + add_client(engine, args.name, args.endpoint, now) print(f"[INFO]: Added client {args.name} ({args.endpoint})") else: print("[ERROR]: Must specify --name for a new client") diff --git a/tarc/models.py b/tarc/models.py new file mode 100644 index 0000000..4eddacc --- /dev/null +++ b/tarc/models.py @@ -0,0 +1,106 @@ +"""SQLAlchemy models for the tarc database.""" + +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + Boolean, + ForeignKey, + UniqueConstraint, +) +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class SchemaVersion(Base): # pylint: disable=too-few-public-methods + """Database schema version tracking.""" + + __tablename__ = "schema_version" + id = Column(Integer, primary_key=True) + version = Column(Integer, nullable=False) + applied_at = Column(DateTime, nullable=False) + + +class Client(Base): # pylint: disable=too-few-public-methods + """BitTorrent client instance.""" + + __tablename__ = "clients" + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False, unique=True) + uuid = Column(String, nullable=False, unique=True) + endpoint = Column(String, nullable=False) + last_seen = Column(DateTime, nullable=False) + + +class Torrent(Base): # pylint: disable=too-few-public-methods + """BitTorrent metadata.""" + + __tablename__ = "torrents" + id = Column(Integer, primary_key=True) + info_hash_v1 = Column(String, nullable=False, unique=True) + info_hash_v2 = Column(String, unique=True) + file_count = Column(Integer, nullable=False) + completed_on = Column(DateTime, nullable=False) + + +class TorrentClient(Base): # pylint: disable=too-few-public-methods + """Association between torrents and clients.""" + + __tablename__ = "torrent_clients" + id = Column(Integer, primary_key=True) + torrent_id = Column(Integer, ForeignKey("torrents.id"), nullable=False) + client_id = Column(Integer, ForeignKey("clients.id"), nullable=False) + name = Column(String, nullable=False) + content_path = Column(String, nullable=False) + last_seen = Column(DateTime, nullable=False) + __table_args__ = (UniqueConstraint("torrent_id", "client_id"),) + + +class Tracker(Base): # pylint: disable=too-few-public-methods + """BitTorrent tracker information.""" + + __tablename__ = "trackers" + id = Column(Integer, primary_key=True) + url = Column(String, nullable=False, unique=True) + last_seen = Column(DateTime, nullable=False) + + +class TorrentTracker(Base): # pylint: disable=too-few-public-methods + """Association between torrents and trackers.""" + + __tablename__ = "torrent_trackers" + id = Column(Integer, primary_key=True) + client_id = Column(Integer, ForeignKey("clients.id"), nullable=False) + torrent_id = Column(Integer, ForeignKey("torrents.id"), nullable=False) + tracker_id = Column(Integer, ForeignKey("trackers.id"), nullable=False) + last_seen = Column(DateTime, nullable=False) + __table_args__ = (UniqueConstraint("client_id", "torrent_id", "tracker_id"),) + + +class File(Base): # pylint: disable=too-few-public-methods + """File metadata and hashes.""" + + __tablename__ = "files" + id = Column(Integer, primary_key=True) + size = Column(Integer, nullable=False) + oshash = Column(String, nullable=False, unique=True) + hash = Column(String, unique=True) + + +class TorrentFile(Base): # pylint: disable=too-few-public-methods + """Association between torrents and files.""" + + __tablename__ = "torrent_files" + id = Column(Integer, primary_key=True) + file_id = Column(Integer, ForeignKey("files.id"), nullable=False) + torrent_id = Column(Integer, ForeignKey("torrents.id"), nullable=False) + client_id = Column(Integer, ForeignKey("clients.id"), nullable=False) + file_index = Column(Integer, nullable=False) + file_path = Column(String, nullable=False) + is_downloaded = Column(Boolean, nullable=False) + last_checked = Column(DateTime, nullable=False) + __table_args__ = ( + UniqueConstraint("file_id", "torrent_id", "client_id", "file_index"), + )