fix: bugs caused by missing awaits & misused async features
This commit is contained in:
parent
f13d581c66
commit
e262f40290
52
main.py
Normal file → Executable file
52
main.py
Normal file → Executable file
@ -3,17 +3,55 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import logging
|
import logging
|
||||||
|
import aioprocessing
|
||||||
|
from dotenv import dotenv_values
|
||||||
|
|
||||||
from src.mempool import WebSocketThread, QueueProcessor
|
from src.mempool import WebSocketThread, QueueProcessor
|
||||||
from src.db import Handler
|
from src.db import Handler, periodic_export
|
||||||
|
from src.const import EXPORT_INTERVAL
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def shutdown(loop, signal=None):
|
||||||
# FIFO queue for cross-thread communications
|
"""Cleanup tasks tied to the service's shutdown."""
|
||||||
q = asyncio.Queue()
|
if signal:
|
||||||
shutdown_event = threading.Event()
|
logging.info("Received exit signal %s", signal.name)
|
||||||
|
|
||||||
|
logging.info("Napping for a bit before shutting down...")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
|
|
||||||
|
for t in tasks:
|
||||||
|
t.cancel()
|
||||||
|
|
||||||
|
logging.info("Cancelling %d outstanding tasks", len(tasks))
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
logging.info("Flushing metrics")
|
||||||
|
loop.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def main(env_path=".env"):
|
||||||
|
cfg = dotenv_values(env_path)
|
||||||
|
mode = cfg["MODE"]
|
||||||
|
|
||||||
|
if mode is None or mode.lower() == "production":
|
||||||
|
log_level = logging.INFO
|
||||||
|
else:
|
||||||
|
log_level = logging.DEBUG
|
||||||
|
|
||||||
|
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=log_level)
|
||||||
|
|
||||||
|
# FIFO queue for crosst-thread communications
|
||||||
|
q = aioprocessing.AioQueue()
|
||||||
handler = Handler()
|
handler = Handler()
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
# TODO: handle scheduling of the export task
|
||||||
|
# loop.create_task(periodic_export(handler, EXPORT_INTERVAL))
|
||||||
|
# export_task_fut = asyncio.run_coroutine_threadsafe(periodic_export, loop)
|
||||||
|
shutdown_event = threading.Event()
|
||||||
|
|
||||||
ws_thread = WebSocketThread(q, shutdown_event)
|
ws_thread = WebSocketThread(q, shutdown_event)
|
||||||
qp_thread = QueueProcessor(q, shutdown_event, handler)
|
qp_thread = QueueProcessor(q, shutdown_event, handler)
|
||||||
|
|
||||||
@ -26,8 +64,12 @@ def main():
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("Keyboard interrupt received, shutting down threads.")
|
logging.info("Keyboard interrupt received, shutting down threads.")
|
||||||
shutdown_event.set()
|
shutdown_event.set()
|
||||||
|
loop.run_until_complete(shutdown(loop))
|
||||||
ws_thread.join()
|
ws_thread.join()
|
||||||
qp_thread.join()
|
qp_thread.join()
|
||||||
|
finally:
|
||||||
|
loop.stop()
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1 +1,2 @@
|
|||||||
|
aioprocessing==2.0.1
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
|
@ -2,3 +2,6 @@ import json
|
|||||||
|
|
||||||
WS_ADDR = "wss://ws.blockchain.info/coins"
|
WS_ADDR = "wss://ws.blockchain.info/coins"
|
||||||
SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
|
SUB_MSG = json.dumps({"coin": "eth", "command": "subscribe", "entity": "confirmed_transaction"})
|
||||||
|
|
||||||
|
# EXPORT_INTERVAL = 24 * 60 * 60 # 24 hours in seconds
|
||||||
|
EXPORT_INTERVAL = 30
|
||||||
|
50
src/db.py
50
src/db.py
@ -1,13 +1,16 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
"""Handles all SQLite connections required to create, update, and export the stored addresses."""
|
"""Handle all SQLite connections required to create, update, and export the stored addresses."""
|
||||||
|
|
||||||
def __init__(self, database="chainmapper.sqlite3"):
|
def __init__(self, database="chainmapper.sqlite3"):
|
||||||
self.database = database
|
self.database = database
|
||||||
# Notably `connect` automatically creates the database if it doesn't already exist
|
# Notably `connect` automatically creates the database if it doesn't already exist
|
||||||
self.con = sqlite3.connect(self.database)
|
self.con = sqlite3.connect(self.database, check_same_thread=False)
|
||||||
self.cursor = self.con.cursor()
|
self.cursor = self.con.cursor()
|
||||||
|
|
||||||
# Initialize the table if necessary
|
# Initialize the table if necessary
|
||||||
@ -24,10 +27,14 @@ class Handler:
|
|||||||
|
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
|
|
||||||
def store(self, address):
|
async def store(self, address):
|
||||||
|
"""Store a new address into the SQLite database, or increments the counter by one if the given address already exists in the database."""
|
||||||
|
await asyncio.to_thread(self._store, address)
|
||||||
|
|
||||||
|
def _store(self, address):
|
||||||
self.cursor.execute(
|
self.cursor.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO AddressTracking (address)
|
INSERT INTO AddressMapping (address)
|
||||||
VALUES
|
VALUES
|
||||||
(?) ON CONFLICT(address) DO
|
(?) ON CONFLICT(address) DO
|
||||||
UPDATE
|
UPDATE
|
||||||
@ -35,10 +42,37 @@ class Handler:
|
|||||||
total_tx_count = total_tx_count + 1,
|
total_tx_count = total_tx_count + 1,
|
||||||
last_updated = CURRENT_TIMESTAMP;
|
last_updated = CURRENT_TIMESTAMP;
|
||||||
""",
|
""",
|
||||||
address,
|
(address,),
|
||||||
)
|
)
|
||||||
self.con.commit()
|
self.con.commit()
|
||||||
|
|
||||||
def export(self):
|
async def export(self, filename="export.json"):
|
||||||
# TODO: handle exporting
|
"""Export the addresses from the SQLite database in descending order based on the transaction counts."""
|
||||||
pass
|
await asyncio.to_thread(self._export, filename)
|
||||||
|
|
||||||
|
def _export(self, filename="export.json"):
|
||||||
|
self.cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT address, total_tx_count
|
||||||
|
FROM AddressMapping
|
||||||
|
ORDER BY total_tx_count DESC;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
records = self.cursor.fetchall()
|
||||||
|
data = [{"address": record[0], "tx_count": record[1]} for record in records]
|
||||||
|
data_json = json.dumps(data, indent=4)
|
||||||
|
|
||||||
|
logging.info("Exporting the database's current state to '%s' (overwriting if an old copy exists)...", filename)
|
||||||
|
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
f.write(data_json)
|
||||||
|
|
||||||
|
logging.info("Data exported to '%s'", filename)
|
||||||
|
|
||||||
|
|
||||||
|
async def periodic_export(handler, interval):
|
||||||
|
logging.info("Scheduled export task created")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(interval)
|
||||||
|
await handler.export()
|
||||||
|
@ -4,12 +4,15 @@ import threading
|
|||||||
import logging
|
import logging
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from const import WS_ADDR, SUB_MSG
|
from src.const import WS_ADDR, SUB_MSG
|
||||||
|
|
||||||
|
|
||||||
class WebSocketThread(threading.Thread):
|
class WebSocketThread(threading.Thread):
|
||||||
|
"""Handle connection, subscription, and message parsing for the Blockchain.com WebSocket."""
|
||||||
|
|
||||||
def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
|
def __init__(self, q, shutdown_event, sub_msg=SUB_MSG):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.name = "WebSocketThread"
|
||||||
self.q = q
|
self.q = q
|
||||||
self.shutdown_event = shutdown_event
|
self.shutdown_event = shutdown_event
|
||||||
self.sub_msg = sub_msg
|
self.sub_msg = sub_msg
|
||||||
@ -21,6 +24,9 @@ class WebSocketThread(threading.Thread):
|
|||||||
await ws.send(self.sub_msg)
|
await ws.send(self.sub_msg)
|
||||||
logging.info("Subscription message sent")
|
logging.info("Subscription message sent")
|
||||||
|
|
||||||
|
# Ignores the confirmation message, as it can't be parsed with the same template
|
||||||
|
_ = await ws.recv()
|
||||||
|
|
||||||
while not self.shutdown_event.is_set():
|
while not self.shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
msg = await ws.recv()
|
msg = await ws.recv()
|
||||||
@ -29,14 +35,14 @@ class WebSocketThread(threading.Thread):
|
|||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.q.put(data)
|
await self.q.coro_put(data)
|
||||||
except websockets.exceptions.ConnectionClosed:
|
except websockets.exceptions.ConnectionClosed:
|
||||||
logging.info("WebSocket connection closed")
|
logging.info("WebSocket connection closed")
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
break
|
break
|
||||||
# pylint: disable=broad-exception-caught
|
# pylint: disable=broad-exception-caught
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("WebSocket error: %s", e)
|
logging.error("WebSocket error: %s", str(e))
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -46,7 +52,7 @@ class WebSocketThread(threading.Thread):
|
|||||||
try:
|
try:
|
||||||
tx_sender = msg_json["transaction"]["from"]
|
tx_sender = msg_json["transaction"]["from"]
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logging.error("Error parsing a WebSocket message: %s", e)
|
logging.error("Error parsing a WebSocket message: %s", str(e))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.tx_count += 1
|
self.tx_count += 1
|
||||||
@ -57,6 +63,7 @@ class WebSocketThread(threading.Thread):
|
|||||||
return tx_sender
|
return tx_sender
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
"""Start the WebSocket thread that'll run until it receives a shutdown message or crashes."""
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
@ -64,26 +71,43 @@ class WebSocketThread(threading.Thread):
|
|||||||
loop.run_until_complete(self.connect())
|
loop.run_until_complete(self.connect())
|
||||||
# pylint: disable=broad-exception-caught
|
# pylint: disable=broad-exception-caught
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("WebSocket thread crashed: %s", e)
|
logging.error("WebSocket thread crashed: %s", str(e))
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
finally:
|
finally:
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
class QueueProcessor(threading.Thread):
|
class QueueProcessor(threading.Thread):
|
||||||
|
"""Handle processing of items from the cross-thread queue where the WebSocket thread feeds data into."""
|
||||||
|
|
||||||
def __init__(self, q, shutdown_event, handler):
|
def __init__(self, q, shutdown_event, handler):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.name = "QueueProcessor"
|
||||||
self.q = q
|
self.q = q
|
||||||
self.shutdown_event = shutdown_event
|
self.shutdown_event = shutdown_event
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
|
||||||
def run(self):
|
async def process_queue(self):
|
||||||
while not self.shutdown_event.is_set():
|
while not self.shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
tx_sender = self.q.get() # Waits here until new msg is available
|
tx_sender = await self.q.coro_get() # Waits here until new msg is available
|
||||||
self.handler.store(tx_sender)
|
await self.handler.store(tx_sender)
|
||||||
# pylint: disable=broad-exception-caught
|
# pylint: disable=broad-exception-caught
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("QueueProcessor thread crashed: %s", e)
|
logging.error("QueueProcessor thread crashed: %s", str(e))
|
||||||
self.shutdown_event.set()
|
self.shutdown_event.set()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""Start the queue processing thread that'll run until it receives a shutdown message or crashes."""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(self.process_queue())
|
||||||
|
# pylint: disable=broad-exception-caught
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("QueueProcessor thread crashed: %s", str(e))
|
||||||
|
self.shutdown_event.set()
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
Loading…
Reference in New Issue
Block a user