diff --git a/main.py b/main.py index 69e6138..f2e8e15 100644 --- a/main.py +++ b/main.py @@ -1,15 +1,34 @@ #!/usr/bin/env python3 -from dotenv import dotenv_values +import asyncio +import threading +import logging + +from src.mempool import WebSocketThread, QueueProcessor +from src.db import Handler -def main(cfg): - pass +def main(): + # FIFO queue for cross-thread communications + q = asyncio.Queue() + shutdown_event = threading.Event() + handler = Handler() + ws_thread = WebSocketThread(q, shutdown_event) + qp_thread = QueueProcessor(q, shutdown_event, handler) -def dotconfig(path=".env"): - return dotenv_values(path) + ws_thread.start() + qp_thread.start() + + try: + ws_thread.join() + qp_thread.join() + except KeyboardInterrupt: + logging.info("Keyboard interrupt received, shutting down threads.") + shutdown_event.set() + ws_thread.join() + qp_thread.join() if __name__ == "__main__": - main(dotconfig()) + main() diff --git a/src/db.py b/src/db.py index ba91067..971d9a8 100644 --- a/src/db.py +++ b/src/db.py @@ -2,7 +2,9 @@ import sqlite3 class Handler: - def __init__(self, database="chainmapper.sql"): + """Handles all SQLite connections required to create, update, and export the stored addresses.""" + + def __init__(self, database="chainmapper.sqlite3"): self.database = database # Notably `connect` automatically creates the database if it doesn't already exist self.con = sqlite3.connect(self.database) @@ -37,6 +39,6 @@ class Handler: ) self.con.commit() - def get_ordered(self): - # TODO: return addresses in descending order (based on `total_tx_count`) + def export(self): + # TODO: handle exporting pass diff --git a/src/mempool.py b/src/mempool.py index cadaeb4..e854dfb 100644 --- a/src/mempool.py +++ b/src/mempool.py @@ -1,74 +1,84 @@ import asyncio import json import threading -import websocket +import logging +import websockets from const import WS_ADDR -# FIFO queue for cross-thread communications -tx_queue = asyncio.Queue() -tx_count = 0 + +class WebSocketThread(threading.Thread): + def __init__(self, q, shutdown_event): + super().__init__() + self.q = q + self.shutdown_event = shutdown_event + self.tx_count = 0 + + async def connect(self): + async with websockets.connect(WS_ADDR) as ws: + while not self.shutdown_event.is_set(): + try: + msg = await ws.recv() + data = self.handle_msg(msg) + + if data is None: + continue + + self.q.put(data) + except websockets.exceptions.ConnectionClosed: + logging.info("WebSocket connection closed") + self.shutdown_event.set() + break + # pylint: disable=broad-exception-caught + except Exception as e: + logging.error("WebSocket error: %s", e) + self.shutdown_event.set() + break + + def handle_msg(self, msg): + msg_json = json.loads(msg) + + try: + tx_sender = msg_json["transaction"]["from"] + except KeyError as e: + logging.error("Error parsing a WebSocket message: %s", e) + return None + + self.tx_count += 1 + + if self.tx_count % 1000 == 0: + logging.info("Currently at %d received transactions", self.tx_count) + + return tx_sender + + def run(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete(self.connect()) + # pylint: disable=broad-exception-caught + except Exception as e: + logging.error("WebSocket thread crashed: %s", e) + self.shutdown_event.set() + finally: + loop.close() -async def process_queue(): - """Handles emptying the transaction queue and calling the database module with the received data.""" - while True: - # TODO: handle graceful shutdown - tx_sender = tx_queue.get() - # TODO: send `tx_sender` to the db module - tx_count += 1 - tx_queue.task_done() +class QueueProcessor(threading.Thread): + def __init__(self, q, shutdown_event, handler): + super().__init__() + self.q = q + self.shutdown_event = shutdown_event + self.handler = handler - -def on_message(_, msg, loop): - msg_json = json.loads(msg) - - try: - tx_sender = msg_json["transaction"]["from"] - except KeyError as e: - # TODO: log the seen KeyError `e` & handle what happens next (i.e. proper error handling)? - return - - future = asyncio.run_coroutine_threadsafe(tx_queue.put(tx_sender), loop) - future.result() # Won't timeout - - -def on_error(_, err): - # TODO: error handling - exit(1) - - -def on_close(_, status_code, msg): - # TODO: log `status_code` & `msg` - pass - - -def on_open(ws): - # TODO: log "Connection opened" - - # Subscribed entity could also be `pending_transactions` to receive the transactions directly - # from the mempool. - ws.send(json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})) - - # TODO: log "Subscription message sent" - - -async def start_monitor(): - """Connects to the WebSocket feed of mined Ethereum transactions""" - queue_processor = asyncio.create_task(process_queue()) - loop = asyncio.get_event_loop() - ws = websocket.WebSocketApp( - WS_ADDR, - on_open=on_open, - on_message=lambda ws, msg: on_message(ws, msg, loop), - on_error=on_error, - on_close=on_close, - ) - - # Run the WebSocket client in a separate thread - # TODO: replace `run_forever` with something that can be signaled to shutdown gracefully - ws_thread = threading.Thread(target=ws.run_forever) - ws_thread.start() - - # Wait for the processor to finish cleaning up the queue before shutting down - await queue_processor() + def run(self): + while not self.shutdown_event.is_set(): + try: + tx_sender = self.q.get() # Waits here until new msg is available + self.handler.store(tx_sender) + # pylint: disable=broad-exception-caught + except Exception as e: + logging.error("QueueProcessor thread crashed: %s", e) + self.shutdown_event.set() + break