refactor: use wrapped threads for websocket & queue processor

This commit is contained in:
17ms 2024-07-01 20:36:50 +03:00
parent d10cc13fbf
commit 91d04f63e9
Signed by untrusted user who does not match committer: ae
GPG Key ID: 995EFD5C1B532B3E
3 changed files with 105 additions and 74 deletions

31
main.py
View File

@ -1,15 +1,34 @@
#!/usr/bin/env python3
from dotenv import dotenv_values
import asyncio
import threading
import logging
from src.mempool import WebSocketThread, QueueProcessor
from src.db import Handler
def main(cfg):
pass
def main():
# FIFO queue for cross-thread communications
q = asyncio.Queue()
shutdown_event = threading.Event()
handler = Handler()
ws_thread = WebSocketThread(q, shutdown_event)
qp_thread = QueueProcessor(q, shutdown_event, handler)
def dotconfig(path=".env"):
return dotenv_values(path)
ws_thread.start()
qp_thread.start()
try:
ws_thread.join()
qp_thread.join()
except KeyboardInterrupt:
logging.info("Keyboard interrupt received, shutting down threads.")
shutdown_event.set()
ws_thread.join()
qp_thread.join()
if __name__ == "__main__":
main(dotconfig())
main()

View File

@ -2,7 +2,9 @@ import sqlite3
class Handler:
def __init__(self, database="chainmapper.sql"):
"""Handles all SQLite connections required to create, update, and export the stored addresses."""
def __init__(self, database="chainmapper.sqlite3"):
self.database = database
# Notably `connect` automatically creates the database if it doesn't already exist
self.con = sqlite3.connect(self.database)
@ -37,6 +39,6 @@ class Handler:
)
self.con.commit()
def get_ordered(self):
# TODO: return addresses in descending order (based on `total_tx_count`)
def export(self):
# TODO: handle exporting
pass

View File

@ -1,74 +1,84 @@
import asyncio
import json
import threading
import websocket
import logging
import websockets
from const import WS_ADDR
# FIFO queue for cross-thread communications
tx_queue = asyncio.Queue()
tx_count = 0
class WebSocketThread(threading.Thread):
def __init__(self, q, shutdown_event):
super().__init__()
self.q = q
self.shutdown_event = shutdown_event
self.tx_count = 0
async def connect(self):
async with websockets.connect(WS_ADDR) as ws:
while not self.shutdown_event.is_set():
try:
msg = await ws.recv()
data = self.handle_msg(msg)
if data is None:
continue
self.q.put(data)
except websockets.exceptions.ConnectionClosed:
logging.info("WebSocket connection closed")
self.shutdown_event.set()
break
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("WebSocket error: %s", e)
self.shutdown_event.set()
break
def handle_msg(self, msg):
msg_json = json.loads(msg)
try:
tx_sender = msg_json["transaction"]["from"]
except KeyError as e:
logging.error("Error parsing a WebSocket message: %s", e)
return None
self.tx_count += 1
if self.tx_count % 1000 == 0:
logging.info("Currently at %d received transactions", self.tx_count)
return tx_sender
def run(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.connect())
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("WebSocket thread crashed: %s", e)
self.shutdown_event.set()
finally:
loop.close()
async def process_queue():
"""Handles emptying the transaction queue and calling the database module with the received data."""
while True:
# TODO: handle graceful shutdown
tx_sender = tx_queue.get()
# TODO: send `tx_sender` to the db module
tx_count += 1
tx_queue.task_done()
class QueueProcessor(threading.Thread):
def __init__(self, q, shutdown_event, handler):
super().__init__()
self.q = q
self.shutdown_event = shutdown_event
self.handler = handler
def on_message(_, msg, loop):
msg_json = json.loads(msg)
try:
tx_sender = msg_json["transaction"]["from"]
except KeyError as e:
# TODO: log the seen KeyError `e` & handle what happens next (i.e. proper error handling)?
return
future = asyncio.run_coroutine_threadsafe(tx_queue.put(tx_sender), loop)
future.result() # Won't timeout
def on_error(_, err):
# TODO: error handling
exit(1)
def on_close(_, status_code, msg):
# TODO: log `status_code` & `msg`
pass
def on_open(ws):
# TODO: log "Connection opened"
# Subscribed entity could also be `pending_transactions` to receive the transactions directly
# from the mempool.
ws.send(json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"}))
# TODO: log "Subscription message sent"
async def start_monitor():
"""Connects to the WebSocket feed of mined Ethereum transactions"""
queue_processor = asyncio.create_task(process_queue())
loop = asyncio.get_event_loop()
ws = websocket.WebSocketApp(
WS_ADDR,
on_open=on_open,
on_message=lambda ws, msg: on_message(ws, msg, loop),
on_error=on_error,
on_close=on_close,
)
# Run the WebSocket client in a separate thread
# TODO: replace `run_forever` with something that can be signaled to shutdown gracefully
ws_thread = threading.Thread(target=ws.run_forever)
ws_thread.start()
# Wait for the processor to finish cleaning up the queue before shutting down
await queue_processor()
def run(self):
while not self.shutdown_event.is_set():
try:
tx_sender = self.q.get() # Waits here until new msg is available
self.handler.store(tx_sender)
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("QueueProcessor thread crashed: %s", e)
self.shutdown_event.set()
break