fix: bugs caused by missing awaits & misused async features

This commit is contained in:
17ms 2024-07-04 21:12:43 +03:00
parent f13d581c66
commit e262f40290
Signed by untrusted user who does not match committer: ae
GPG Key ID: 995EFD5C1B532B3E
5 changed files with 126 additions and 22 deletions

52
main.py Normal file → Executable file
View File

@ -3,17 +3,55 @@
import asyncio
import threading
import logging
import aioprocessing
from dotenv import dotenv_values
from src.mempool import WebSocketThread, QueueProcessor
from src.db import Handler
from src.db import Handler, periodic_export
from src.const import EXPORT_INTERVAL
def main():
# FIFO queue for cross-thread communications
q = asyncio.Queue()
shutdown_event = threading.Event()
async def shutdown(loop, signal=None):
"""Cleanup tasks tied to the service's shutdown."""
if signal:
logging.info("Received exit signal %s", signal.name)
logging.info("Napping for a bit before shutting down...")
await asyncio.sleep(2)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for t in tasks:
t.cancel()
logging.info("Cancelling %d outstanding tasks", len(tasks))
await asyncio.gather(*tasks, return_exceptions=True)
logging.info("Flushing metrics")
loop.stop()
def main(env_path=".env"):
cfg = dotenv_values(env_path)
mode = cfg["MODE"]
if mode is None or mode.lower() == "production":
log_level = logging.INFO
else:
log_level = logging.DEBUG
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=log_level)
# FIFO queue for crosst-thread communications
q = aioprocessing.AioQueue()
handler = Handler()
loop = asyncio.new_event_loop()
# TODO: handle scheduling of the export task
# loop.create_task(periodic_export(handler, EXPORT_INTERVAL))
# export_task_fut = asyncio.run_coroutine_threadsafe(periodic_export, loop)
shutdown_event = threading.Event()
ws_thread = WebSocketThread(q, shutdown_event)
qp_thread = QueueProcessor(q, shutdown_event, handler)
@ -26,8 +64,12 @@ def main():
except KeyboardInterrupt:
logging.info("Keyboard interrupt received, shutting down threads.")
shutdown_event.set()
loop.run_until_complete(shutdown(loop))
ws_thread.join()
qp_thread.join()
finally:
loop.stop()
loop.close()
if __name__ == "__main__":

View File

@ -1 +1,2 @@
aioprocessing==2.0.1
websockets==12.0

View File

@ -2,3 +2,6 @@ import json
WS_ADDR = "wss://ws.blockchain.info/coins"
SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
# EXPORT_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
EXPORT_INTERVAL = 30

View File

@ -1,13 +1,16 @@
import sqlite3
import json
import logging
import asyncio
class Handler:
"""Handles all SQLite connections required to create, update, and export the stored addresses."""
"""Handle all SQLite connections required to create, update, and export the stored addresses."""
def __init__(self, database="chainmapper.sqlite3"):
self.database = database
# Notably `connect` automatically creates the database if it doesn't already exist
self.con = sqlite3.connect(self.database)
self.con = sqlite3.connect(self.database, check_same_thread=False)
self.cursor = self.con.cursor()
# Initialize the table if necessary
@ -24,10 +27,14 @@ class Handler:
self.con.commit()
def store(self, address):
async def store(self, address):
"""Store a new address into the SQLite database, or increments the counter by one if the given address already exists in the database."""
await asyncio.to_thread(self._store, address)
def _store(self, address):
self.cursor.execute(
"""
INSERT INTO AddressTracking (address)
INSERT INTO AddressMapping (address)
VALUES
(?) ON CONFLICT(address) DO
UPDATE
@ -35,10 +42,37 @@ class Handler:
total_tx_count = total_tx_count + 1,
last_updated = CURRENT_TIMESTAMP;
""",
address,
(address,),
)
self.con.commit()
def export(self):
# TODO: handle exporting
pass
async def export(self, filename="export.json"):
"""Export the addresses from the SQLite database in descending order based on the transaction counts."""
await asyncio.to_thread(self._export, filename)
def _export(self, filename="export.json"):
self.cursor.execute(
"""
SELECT address, total_tx_count
FROM AddressMapping
ORDER BY total_tx_count DESC;
"""
)
records = self.cursor.fetchall()
data = [{"address": record[0], "tx_count": record[1]} for record in records]
data_json = json.dumps(data, indent=4)
logging.info("Exporting the database's current state to '%s' (overwriting if an old copy exists)...", filename)
with open(filename, "w", encoding="utf-8") as f:
f.write(data_json)
logging.info("Data exported to '%s'", filename)
async def periodic_export(handler, interval):
logging.info("Scheduled export task created")
while True:
await asyncio.sleep(interval)
await handler.export()

View File

@ -4,12 +4,15 @@ import threading
import logging
import websockets
from const import WS_ADDR, SUB_MSG
from src.const import WS_ADDR, SUB_MSG
class WebSocketThread(threading.Thread):
"""Handle connection, subscription, and message parsing for the Blockchain.com WebSocket."""
def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
super().__init__()
self.name = "WebSocketThread"
self.q = q
self.shutdown_event = shutdown_event
self.sub_msg = sub_msg
@ -21,6 +24,9 @@ class WebSocketThread(threading.Thread):
await ws.send(self.sub_msg)
logging.info("Subscription message sent")
# Ignores the confirmation message, as it can't be parsed with the same template
_ = await ws.recv()
while not self.shutdown_event.is_set():
try:
msg = await ws.recv()
@ -29,14 +35,14 @@ class WebSocketThread(threading.Thread):
if data is None:
continue
self.q.put(data)
await self.q.coro_put(data)
except websockets.exceptions.ConnectionClosed:
logging.info("WebSocket connection closed")
self.shutdown_event.set()
break
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("WebSocket error: %s", e)
logging.error("WebSocket error: %s", str(e))
self.shutdown_event.set()
break
@ -46,7 +52,7 @@ class WebSocketThread(threading.Thread):
try:
tx_sender = msg_json["transaction"]["from"]
except KeyError as e:
logging.error("Error parsing a WebSocket message: %s", e)
logging.error("Error parsing a WebSocket message: %s", str(e))
return None
self.tx_count += 1
@ -57,6 +63,7 @@ class WebSocketThread(threading.Thread):
return tx_sender
def run(self):
"""Start the WebSocket thread that'll run until it receives a shutdown message or crashes."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -64,26 +71,43 @@ class WebSocketThread(threading.Thread):
loop.run_until_complete(self.connect())
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("WebSocket thread crashed: %s", e)
logging.error("WebSocket thread crashed: %s", str(e))
self.shutdown_event.set()
finally:
loop.close()
class QueueProcessor(threading.Thread):
"""Handle processing of items from the cross-thread queue where the WebSocket thread feeds data into."""
def __init__(self, q, shutdown_event, handler):
super().__init__()
self.name = "QueueProcessor"
self.q = q
self.shutdown_event = shutdown_event
self.handler = handler
def run(self):
async def process_queue(self):
while not self.shutdown_event.is_set():
try:
tx_sender = self.q.get() # Waits here until new msg is available
self.handler.store(tx_sender)
tx_sender = await self.q.coro_get() # Waits here until new msg is available
await self.handler.store(tx_sender)
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("QueueProcessor thread crashed: %s", e)
logging.error("QueueProcessor thread crashed: %s", str(e))
self.shutdown_event.set()
break
def run(self):
"""Start the queue processing thread that'll run until it receives a shutdown message or crashes."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.process_queue())
# pylint: disable=broad-exception-caught
except Exception as e:
logging.error("QueueProcessor thread crashed: %s", str(e))
self.shutdown_event.set()
finally:
loop.close()