diff --git a/main.py b/main.py
old mode 100644
new mode 100755
index f2e8e15..65bafae
--- a/main.py
+++ b/main.py
@@ -3,17 +3,55 @@
 import asyncio
 import threading
 import logging
 
+import aioprocessing
+from dotenv import dotenv_values
+
 from src.mempool import WebSocketThread, QueueProcessor
-from src.db import Handler
+from src.db import Handler, periodic_export
+from src.const import EXPORT_INTERVAL
 
 
-def main():
-    # FIFO queue for cross-thread communications
-    q = asyncio.Queue()
-    shutdown_event = threading.Event()
+async def shutdown(loop, signal=None):
+    """Cleanup tasks tied to the service's shutdown."""
+    if signal:
+        logging.info("Received exit signal %s", signal.name)
+
+    logging.info("Napping for a bit before shutting down...")
+    await asyncio.sleep(2)
+
+    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
+
+    for t in tasks:
+        t.cancel()
+
+    logging.info("Cancelling %d outstanding tasks", len(tasks))
+    await asyncio.gather(*tasks, return_exceptions=True)
+
+    logging.info("Flushing metrics")
+    loop.stop()
+
+
+def main(env_path=".env"):
+    cfg = dotenv_values(env_path)
+    mode = cfg.get("MODE")
+
+    if mode is None or mode.lower() == "production":
+        log_level = logging.INFO
+    else:
+        log_level = logging.DEBUG
+
+    logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=log_level)
+
+    # FIFO queue for cross-thread communications
+    q = aioprocessing.AioQueue()
 
     handler = Handler()
+    loop = asyncio.new_event_loop()
+    # TODO: handle scheduling of the export task
+    # loop.create_task(periodic_export(handler, EXPORT_INTERVAL))
+    # export_task_fut = asyncio.run_coroutine_threadsafe(periodic_export, loop)
+    shutdown_event = threading.Event()
 
     ws_thread = WebSocketThread(q, shutdown_event)
     qp_thread = QueueProcessor(q, shutdown_event, handler)
 
@@ -26,8 +64,12 @@ def main():
     except KeyboardInterrupt:
         logging.info("Keyboard interrupt received, shutting down threads.")
         shutdown_event.set()
+        loop.run_until_complete(shutdown(loop))
         ws_thread.join()
         qp_thread.join()
+    finally:
+        loop.stop()
+        loop.close()
 
 
 if __name__ == "__main__":
diff --git a/requirements.txt b/requirements.txt
index 02e1832..cc2010a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
+aioprocessing==2.0.1
 websockets==12.0
diff --git a/src/const.py b/src/const.py
index f02e346..2115fc7 100644
--- a/src/const.py
+++ b/src/const.py
@@ -2,3 +2,6 @@ import json
 
 WS_ADDR = "wss://ws.blockchain.info/coins"
 SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
+
+# EXPORT_INTERVAL = 24 * 60 * 60  # 24 hours in seconds
+EXPORT_INTERVAL = 30
diff --git a/src/db.py b/src/db.py
index 971d9a8..b784023 100644
--- a/src/db.py
+++ b/src/db.py
@@ -1,13 +1,16 @@
 import sqlite3
+import json
+import logging
+import asyncio
 
 
 class Handler:
-    """Handles all SQLite connections required to create, update, and export the stored addresses."""
+    """Handle all SQLite connections required to create, update, and export the stored addresses."""
 
     def __init__(self, database="chainmapper.sqlite3"):
         self.database = database
         # Notably `connect` automatically creates the database if it doesn't already exist
-        self.con = sqlite3.connect(self.database)
+        self.con = sqlite3.connect(self.database, check_same_thread=False)
         self.cursor = self.con.cursor()
 
         # Initialize the table if necessary
@@ -24,10 +27,14 @@ class Handler:
 
         self.con.commit()
 
-    def store(self, address):
+    async def store(self, address):
+        """Store a new address into the SQLite database, or increment the counter by one if the given address already exists in the database."""
+        await asyncio.to_thread(self._store, address)
+
+    def _store(self, address):
         self.cursor.execute(
             """
-            INSERT INTO AddressTracking (address)
+            INSERT INTO AddressMapping (address)
             VALUES (?)
             ON CONFLICT(address)
             DO UPDATE
@@ -35,10 +42,37 @@
                 total_tx_count = total_tx_count + 1,
                 last_updated = CURRENT_TIMESTAMP;
             """,
-            address,
+            (address,),
         )
         self.con.commit()
 
-    def export(self):
-        # TODO: handle exporting
-        pass
+    async def export(self, filename="export.json"):
+        """Export the addresses from the SQLite database in descending order based on the transaction counts."""
+        await asyncio.to_thread(self._export, filename)
+
+    def _export(self, filename="export.json"):
+        self.cursor.execute(
+            """
+            SELECT address, total_tx_count
+            FROM AddressMapping
+            ORDER BY total_tx_count DESC;
+            """
+        )
+        records = self.cursor.fetchall()
+        data = [{"address": record[0], "tx_count": record[1]} for record in records]
+        data_json = json.dumps(data, indent=4)
+
+        logging.info("Exporting the database's current state to '%s' (overwriting if an old copy exists)...", filename)
+
+        with open(filename, "w", encoding="utf-8") as f:
+            f.write(data_json)
+
+        logging.info("Data exported to '%s'", filename)
+
+
+async def periodic_export(handler, interval):
+    logging.info("Scheduled export task created")
+
+    while True:
+        await asyncio.sleep(interval)
+        await handler.export()
diff --git a/src/mempool.py b/src/mempool.py
index f2e8eef..f6eb8ac 100644
--- a/src/mempool.py
+++ b/src/mempool.py
@@ -4,12 +4,15 @@
 import threading
 import logging
 import websockets
 
-from const import WS_ADDR, SUB_MSG
+from src.const import WS_ADDR, SUB_MSG
 
 class WebSocketThread(threading.Thread):
+    """Handle connection, subscription, and message parsing for the Blockchain.com WebSocket."""
+
     def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
         super().__init__()
+        self.name = "WebSocketThread"
         self.q = q
         self.shutdown_event = shutdown_event
         self.sub_msg = sub_msg
@@ -21,6 +24,9 @@
            await ws.send(self.sub_msg)
            logging.info("Subscription message sent")
 
+            # Ignores the confirmation message, as it can't be parsed with the same template
+            _ = await ws.recv()
+
            while not self.shutdown_event.is_set():
                try:
                    msg = await ws.recv()
@@ -29,14 +35,14 @@
                    if data is None:
                        continue
 
-                    self.q.put(data)
+                    await self.q.coro_put(data)
                except websockets.exceptions.ConnectionClosed:
                    logging.info("WebSocket connection closed")
                    self.shutdown_event.set()
                    break
                # pylint: disable=broad-exception-caught
                except Exception as e:
-                    logging.error("WebSocket error: %s", e)
+                    logging.error("WebSocket error: %s", str(e))
                    self.shutdown_event.set()
                    break
 
@@ -46,7 +52,7 @@
         try:
             tx_sender = msg_json["transaction"]["from"]
         except KeyError as e:
-            logging.error("Error parsing a WebSocket message: %s", e)
+            logging.error("Error parsing a WebSocket message: %s", str(e))
             return None
 
         self.tx_count += 1
@@ -57,6 +63,7 @@
         return tx_sender
 
     def run(self):
+        """Start the WebSocket thread that'll run until it receives a shutdown message or crashes."""
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
 
@@ -64,26 +71,43 @@
         try:
             loop.run_until_complete(self.connect())
         # pylint: disable=broad-exception-caught
         except Exception as e:
-            logging.error("WebSocket thread crashed: %s", e)
+            logging.error("WebSocket thread crashed: %s", str(e))
logging.error("WebSocket thread crashed: %s", str(e)) self.shutdown_event.set() finally: loop.close() class QueueProcessor(threading.Thread): + """Handle processing of items from the cross-thread queue where the WebSocket thread feeds data into.""" + def __init__(self, q, shutdown_event, handler): super().__init__() + self.name = "QueueProcessor" self.q = q self.shutdown_event = shutdown_event self.handler = handler - def run(self): + async def process_queue(self): while not self.shutdown_event.is_set(): try: - tx_sender = self.q.get() # Waits here until new msg is available - self.handler.store(tx_sender) + tx_sender = await self.q.coro_get() # Waits here until new msg is available + await self.handler.store(tx_sender) # pylint: disable=broad-exception-caught except Exception as e: - logging.error("QueueProcessor thread crashed: %s", e) + logging.error("QueueProcessor thread crashed: %s", str(e)) self.shutdown_event.set() break + + def run(self): + """Start the queue processing thread that'll run until it receives a shutdown message or crashes.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete(self.process_queue()) + # pylint: disable=broad-exception-caught + except Exception as e: + logging.error("QueueProcessor thread crashed: %s", str(e)) + self.shutdown_event.set() + finally: + loop.close()