fix: bugs caused by missing awaits & misused async features

This commit is contained in:
17ms 2024-07-04 21:12:43 +03:00
parent f13d581c66
commit e262f40290
Signed by untrusted user who does not match committer: ae
GPG Key ID: 995EFD5C1B532B3E
5 changed files with 126 additions and 22 deletions

52
main.py Normal file → Executable file
View File

@ -3,17 +3,55 @@
import asyncio import asyncio
import threading import threading
import logging import logging
import aioprocessing
from dotenv import dotenv_values
from src.mempool import WebSocketThread, QueueProcessor from src.mempool import WebSocketThread, QueueProcessor
from src.db import Handler from src.db import Handler, periodic_export
from src.const import EXPORT_INTERVAL
def main(): async def shutdown(loop, signal=None):
# FIFO queue for cross-thread communications """Cleanup tasks tied to the service's shutdown."""
q = asyncio.Queue() if signal:
shutdown_event = threading.Event() logging.info("Received exit signal %s", signal.name)
logging.info("Napping for a bit before shutting down...")
await asyncio.sleep(2)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for t in tasks:
t.cancel()
logging.info("Cancelling %d outstanding tasks", len(tasks))
await asyncio.gather(*tasks, return_exceptions=True)
logging.info("Flushing metrics")
loop.stop()
def main(env_path=".env"):
cfg = dotenv_values(env_path)
mode = cfg["MODE"]
if mode is None or mode.lower() == "production":
log_level = logging.INFO
else:
log_level = logging.DEBUG
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=log_level)
# FIFO queue for crosst-thread communications
q = aioprocessing.AioQueue()
handler = Handler() handler = Handler()
loop = asyncio.new_event_loop()
# TODO: handle scheduling of the export task
# loop.create_task(periodic_export(handler, EXPORT_INTERVAL))
# export_task_fut = asyncio.run_coroutine_threadsafe(periodic_export, loop)
shutdown_event = threading.Event()
ws_thread = WebSocketThread(q, shutdown_event) ws_thread = WebSocketThread(q, shutdown_event)
qp_thread = QueueProcessor(q, shutdown_event, handler) qp_thread = QueueProcessor(q, shutdown_event, handler)
@ -26,8 +64,12 @@ def main():
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("Keyboard interrupt received, shutting down threads.") logging.info("Keyboard interrupt received, shutting down threads.")
shutdown_event.set() shutdown_event.set()
loop.run_until_complete(shutdown(loop))
ws_thread.join() ws_thread.join()
qp_thread.join() qp_thread.join()
finally:
loop.stop()
loop.close()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1 +1,2 @@
aioprocessing==2.0.1
websockets==12.0 websockets==12.0

View File

@ -2,3 +2,6 @@ import json
WS_ADDR = "wss://ws.blockchain.info/coins" WS_ADDR = "wss://ws.blockchain.info/coins"
SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"}) SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
# EXPORT_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
EXPORT_INTERVAL = 30

View File

@ -1,13 +1,16 @@
import sqlite3 import sqlite3
import json
import logging
import asyncio
class Handler: class Handler:
"""Handles all SQLite connections required to create, update, and export the stored addresses.""" """Handle all SQLite connections required to create, update, and export the stored addresses."""
def __init__(self, database="chainmapper.sqlite3"): def __init__(self, database="chainmapper.sqlite3"):
self.database = database self.database = database
# Notably `connect` automatically creates the database if it doesn't already exist # Notably `connect` automatically creates the database if it doesn't already exist
self.con = sqlite3.connect(self.database) self.con = sqlite3.connect(self.database, check_same_thread=False)
self.cursor = self.con.cursor() self.cursor = self.con.cursor()
# Initialize the table if necessary # Initialize the table if necessary
@ -24,10 +27,14 @@ class Handler:
self.con.commit() self.con.commit()
def store(self, address): async def store(self, address):
"""Store a new address into the SQLite database, or increments the counter by one if the given address already exists in the database."""
await asyncio.to_thread(self._store, address)
def _store(self, address):
self.cursor.execute( self.cursor.execute(
""" """
INSERT INTO AddressTracking (address) INSERT INTO AddressMapping (address)
VALUES VALUES
(?) ON CONFLICT(address) DO (?) ON CONFLICT(address) DO
UPDATE UPDATE
@ -35,10 +42,37 @@ class Handler:
total_tx_count = total_tx_count + 1, total_tx_count = total_tx_count + 1,
last_updated = CURRENT_TIMESTAMP; last_updated = CURRENT_TIMESTAMP;
""", """,
address, (address,),
) )
self.con.commit() self.con.commit()
def export(self): async def export(self, filename="export.json"):
# TODO: handle exporting """Export the addresses from the SQLite database in descending order based on the transaction counts."""
pass await asyncio.to_thread(self._export, filename)
def _export(self, filename="export.json"):
self.cursor.execute(
"""
SELECT address, total_tx_count
FROM AddressMapping
ORDER BY total_tx_count DESC;
"""
)
records = self.cursor.fetchall()
data = [{"address": record[0], "tx_count": record[1]} for record in records]
data_json = json.dumps(data, indent=4)
logging.info("Exporting the database's current state to '%s' (overwriting if an old copy exists)...", filename)
with open(filename, "w", encoding="utf-8") as f:
f.write(data_json)
logging.info("Data exported to '%s'", filename)
async def periodic_export(handler, interval):
logging.info("Scheduled export task created")
while True:
await asyncio.sleep(interval)
await handler.export()

View File

@ -4,12 +4,15 @@ import threading
import logging import logging
import websockets import websockets
from const import WS_ADDR, SUB_MSG from src.const import WS_ADDR, SUB_MSG
class WebSocketThread(threading.Thread): class WebSocketThread(threading.Thread):
"""Handle connection, subscription, and message parsing for the Blockchain.com WebSocket."""
def __init__(self, q, shutdown_event, sub_msg=SUB_MSG): def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
super().__init__() super().__init__()
self.name = "WebSocketThread"
self.q = q self.q = q
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.sub_msg = sub_msg self.sub_msg = sub_msg
@ -21,6 +24,9 @@ class WebSocketThread(threading.Thread):
await ws.send(self.sub_msg) await ws.send(self.sub_msg)
logging.info("Subscription message sent") logging.info("Subscription message sent")
# Ignores the confirmation message, as it can't be parsed with the same template
_ = await ws.recv()
while not self.shutdown_event.is_set(): while not self.shutdown_event.is_set():
try: try:
msg = await ws.recv() msg = await ws.recv()
@ -29,14 +35,14 @@ class WebSocketThread(threading.Thread):
if data is None: if data is None:
continue continue
self.q.put(data) await self.q.coro_put(data)
except websockets.exceptions.ConnectionClosed: except websockets.exceptions.ConnectionClosed:
logging.info("WebSocket connection closed") logging.info("WebSocket connection closed")
self.shutdown_event.set() self.shutdown_event.set()
break break
# pylint: disable=broad-exception-caught # pylint: disable=broad-exception-caught
except Exception as e: except Exception as e:
logging.error("WebSocket error: %s", e) logging.error("WebSocket error: %s", str(e))
self.shutdown_event.set() self.shutdown_event.set()
break break
@ -46,7 +52,7 @@ class WebSocketThread(threading.Thread):
try: try:
tx_sender = msg_json["transaction"]["from"] tx_sender = msg_json["transaction"]["from"]
except KeyError as e: except KeyError as e:
logging.error("Error parsing a WebSocket message: %s", e) logging.error("Error parsing a WebSocket message: %s", str(e))
return None return None
self.tx_count += 1 self.tx_count += 1
@ -57,6 +63,7 @@ class WebSocketThread(threading.Thread):
return tx_sender return tx_sender
def run(self): def run(self):
"""Start the WebSocket thread that'll run until it receives a shutdown message or crashes."""
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
@ -64,26 +71,43 @@ class WebSocketThread(threading.Thread):
loop.run_until_complete(self.connect()) loop.run_until_complete(self.connect())
# pylint: disable=broad-exception-caught # pylint: disable=broad-exception-caught
except Exception as e: except Exception as e:
logging.error("WebSocket thread crashed: %s", e) logging.error("WebSocket thread crashed: %s", str(e))
self.shutdown_event.set() self.shutdown_event.set()
finally: finally:
loop.close() loop.close()
class QueueProcessor(threading.Thread): class QueueProcessor(threading.Thread):
"""Handle processing of items from the cross-thread queue where the WebSocket thread feeds data into."""
def __init__(self, q, shutdown_event, handler): def __init__(self, q, shutdown_event, handler):
super().__init__() super().__init__()
self.name = "QueueProcessor"
self.q = q self.q = q
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.handler = handler self.handler = handler
def run(self): async def process_queue(self):
while not self.shutdown_event.is_set(): while not self.shutdown_event.is_set():
try: try:
tx_sender = self.q.get() # Waits here until new msg is available tx_sender = await self.q.coro_get() # Waits here until new msg is available
self.handler.store(tx_sender) await self.handler.store(tx_sender)
# pylint: disable=broad-exception-caught # pylint: disable=broad-exception-caught
except Exception as e: except Exception as e:
logging.error("QueueProcessor thread crashed: %s", e) logging.error("QueueProcessor thread crashed: %s", str(e))
self.shutdown_event.set() self.shutdown_event.set()
break break
def run(self):
"""Start the queue processing thread that'll run until it receives a shutdown message or crashes."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.process_queue())
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("QueueProcessor thread crashed: %s", str(e))
self.shutdown_event.set()
finally:
loop.close()