fix: bugs caused by missing awaits & misused async features
This commit is contained in:
parent
f13d581c66
commit
e262f40290
52
main.py
Normal file → Executable file
52
main.py
Normal file → Executable file
@ -3,17 +3,55 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import logging
|
||||
import aioprocessing
|
||||
from dotenv import dotenv_values
|
||||
|
||||
from src.mempool import WebSocketThread, QueueProcessor
|
||||
from src.db import Handler
|
||||
from src.db import Handler, periodic_export
|
||||
from src.const import EXPORT_INTERVAL
|
||||
|
||||
|
||||
def main():
|
||||
# FIFO queue for cross-thread communications
|
||||
q = asyncio.Queue()
|
||||
shutdown_event = threading.Event()
|
||||
async def shutdown(loop, signal=None):
|
||||
"""Cleanup tasks tied to the service's shutdown."""
|
||||
if signal:
|
||||
logging.info("Received exit signal %s", signal.name)
|
||||
|
||||
logging.info("Napping for a bit before shutting down...")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
|
||||
logging.info("Cancelling %d outstanding tasks", len(tasks))
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
logging.info("Flushing metrics")
|
||||
loop.stop()
|
||||
|
||||
|
||||
def main(env_path=".env"):
|
||||
cfg = dotenv_values(env_path)
|
||||
mode = cfg["MODE"]
|
||||
|
||||
if mode is None or mode.lower() == "production":
|
||||
log_level = logging.INFO
|
||||
else:
|
||||
log_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=log_level)
|
||||
|
||||
# FIFO queue for crosst-thread communications
|
||||
q = aioprocessing.AioQueue()
|
||||
handler = Handler()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
# TODO: handle scheduling of the export task
|
||||
# loop.create_task(periodic_export(handler, EXPORT_INTERVAL))
|
||||
# export_task_fut = asyncio.run_coroutine_threadsafe(periodic_export, loop)
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
ws_thread = WebSocketThread(q, shutdown_event)
|
||||
qp_thread = QueueProcessor(q, shutdown_event, handler)
|
||||
|
||||
@ -26,8 +64,12 @@ def main():
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Keyboard interrupt received, shutting down threads.")
|
||||
shutdown_event.set()
|
||||
loop.run_until_complete(shutdown(loop))
|
||||
ws_thread.join()
|
||||
qp_thread.join()
|
||||
finally:
|
||||
loop.stop()
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1 +1,2 @@
|
||||
aioprocessing==2.0.1
|
||||
websockets==12.0
|
||||
|
@ -2,3 +2,6 @@ import json
|
||||
|
||||
WS_ADDR = "wss://ws.blockchain.info/coins"
|
||||
SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
|
||||
|
||||
# EXPORT_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
|
||||
EXPORT_INTERVAL = 30
|
||||
|
50
src/db.py
50
src/db.py
@ -1,13 +1,16 @@
|
||||
import sqlite3
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
|
||||
class Handler:
|
||||
"""Handles all SQLite connections required to create, update, and export the stored addresses."""
|
||||
"""Handle all SQLite connections required to create, update, and export the stored addresses."""
|
||||
|
||||
def __init__(self, database="chainmapper.sqlite3"):
|
||||
self.database = database
|
||||
# Notably `connect` automatically creates the database if it doesn't already exist
|
||||
self.con = sqlite3.connect(self.database)
|
||||
self.con = sqlite3.connect(self.database, check_same_thread=False)
|
||||
self.cursor = self.con.cursor()
|
||||
|
||||
# Initialize the table if necessary
|
||||
@ -24,10 +27,14 @@ class Handler:
|
||||
|
||||
self.con.commit()
|
||||
|
||||
def store(self, address):
|
||||
async def store(self, address):
|
||||
"""Store a new address into the SQLite database, or increments the counter by one if the given address already exists in the database."""
|
||||
await asyncio.to_thread(self._store, address)
|
||||
|
||||
def _store(self, address):
|
||||
self.cursor.execute(
|
||||
"""
|
||||
INSERT INTO AddressTracking (address)
|
||||
INSERT INTO AddressMapping (address)
|
||||
VALUES
|
||||
(?) ON CONFLICT(address) DO
|
||||
UPDATE
|
||||
@ -35,10 +42,37 @@ class Handler:
|
||||
total_tx_count = total_tx_count + 1,
|
||||
last_updated = CURRENT_TIMESTAMP;
|
||||
""",
|
||||
address,
|
||||
(address,),
|
||||
)
|
||||
self.con.commit()
|
||||
|
||||
def export(self):
|
||||
# TODO: handle exporting
|
||||
pass
|
||||
async def export(self, filename="export.json"):
|
||||
"""Export the addresses from the SQLite database in descending order based on the transaction counts."""
|
||||
await asyncio.to_thread(self._export, filename)
|
||||
|
||||
def _export(self, filename="export.json"):
|
||||
self.cursor.execute(
|
||||
"""
|
||||
SELECT address, total_tx_count
|
||||
FROM AddressMapping
|
||||
ORDER BY total_tx_count DESC;
|
||||
"""
|
||||
)
|
||||
records = self.cursor.fetchall()
|
||||
data = [{"address": record[0], "tx_count": record[1]} for record in records]
|
||||
data_json = json.dumps(data, indent=4)
|
||||
|
||||
logging.info("Exporting the database's current state to '%s' (overwriting if an old copy exists)...", filename)
|
||||
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(data_json)
|
||||
|
||||
logging.info("Data exported to '%s'", filename)
|
||||
|
||||
|
||||
async def periodic_export(handler, interval):
|
||||
logging.info("Scheduled export task created")
|
||||
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
await handler.export()
|
||||
|
@ -4,12 +4,15 @@ import threading
|
||||
import logging
|
||||
import websockets
|
||||
|
||||
from const import WS_ADDR, SUB_MSG
|
||||
from src.const import WS_ADDR, SUB_MSG
|
||||
|
||||
|
||||
class WebSocketThread(threading.Thread):
|
||||
"""Handle connection, subscription, and message parsing for the Blockchain.com WebSocket."""
|
||||
|
||||
def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
|
||||
super().__init__()
|
||||
self.name = "WebSocketThread"
|
||||
self.q = q
|
||||
self.shutdown_event = shutdown_event
|
||||
self.sub_msg = sub_msg
|
||||
@ -21,6 +24,9 @@ class WebSocketThread(threading.Thread):
|
||||
await ws.send(self.sub_msg)
|
||||
logging.info("Subscription message sent")
|
||||
|
||||
# Ignores the confirmation message, as it can't be parsed with the same template
|
||||
_ = await ws.recv()
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
try:
|
||||
msg = await ws.recv()
|
||||
@ -29,14 +35,14 @@ class WebSocketThread(threading.Thread):
|
||||
if data is None:
|
||||
continue
|
||||
|
||||
self.q.put(data)
|
||||
await self.q.coro_put(data)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logging.info("WebSocket connection closed")
|
||||
self.shutdown_event.set()
|
||||
break
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
logging.error("WebSocket error: %s", e)
|
||||
logging.error("WebSocket error: %s", str(e))
|
||||
self.shutdown_event.set()
|
||||
break
|
||||
|
||||
@ -46,7 +52,7 @@ class WebSocketThread(threading.Thread):
|
||||
try:
|
||||
tx_sender = msg_json["transaction"]["from"]
|
||||
except KeyError as e:
|
||||
logging.error("Error parsing a WebSocket message: %s", e)
|
||||
logging.error("Error parsing a WebSocket message: %s", str(e))
|
||||
return None
|
||||
|
||||
self.tx_count += 1
|
||||
@ -57,6 +63,7 @@ class WebSocketThread(threading.Thread):
|
||||
return tx_sender
|
||||
|
||||
def run(self):
|
||||
"""Start the WebSocket thread that'll run until it receives a shutdown message or crashes."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
@ -64,26 +71,43 @@ class WebSocketThread(threading.Thread):
|
||||
loop.run_until_complete(self.connect())
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
logging.error("WebSocket thread crashed: %s", e)
|
||||
logging.error("WebSocket thread crashed: %s", str(e))
|
||||
self.shutdown_event.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
class QueueProcessor(threading.Thread):
|
||||
"""Handle processing of items from the cross-thread queue where the WebSocket thread feeds data into."""
|
||||
|
||||
def __init__(self, q, shutdown_event, handler):
|
||||
super().__init__()
|
||||
self.name = "QueueProcessor"
|
||||
self.q = q
|
||||
self.shutdown_event = shutdown_event
|
||||
self.handler = handler
|
||||
|
||||
def run(self):
|
||||
async def process_queue(self):
|
||||
while not self.shutdown_event.is_set():
|
||||
try:
|
||||
tx_sender = self.q.get() # Waits here until new msg is available
|
||||
self.handler.store(tx_sender)
|
||||
tx_sender = await self.q.coro_get() # Waits here until new msg is available
|
||||
await self.handler.store(tx_sender)
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
logging.error("QueueProcessor thread crashed: %s", e)
|
||||
logging.error("QueueProcessor thread crashed: %s", str(e))
|
||||
self.shutdown_event.set()
|
||||
break
|
||||
|
||||
def run(self):
|
||||
"""Start the queue processing thread that'll run until it receives a shutdown message or crashes."""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(self.process_queue())
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
logging.error("QueueProcessor thread crashed: %s", str(e))
|
||||
self.shutdown_event.set()
|
||||
finally:
|
||||
loop.close()
|
||||
|
Loading…
Reference in New Issue
Block a user