fa45d8aa5f
- health_checklist.json: 192.168.1.122→node122
- ocr_client.py: docstring IP→node122
- docs/market-data-requirements.md: IP→node122
- 所有API调用通过ProxyHandler({})绕过系统代理
Privoxy对node122:18003返回500,直连正常
1159 lines
38 KiB
Python
1159 lines
38 KiB
Python
import asyncio
|
|
import collections
|
|
import contextvars
|
|
import itertools
|
|
import json
|
|
import logging
|
|
import re
|
|
|
|
from greenlet import greenlet, getcurrent
|
|
from peewee import *
|
|
from peewee import _atomic, _savepoint, _transaction
|
|
from peewee import _callable_context_manager
|
|
from peewee import __exception_wrapper__
|
|
from peewee import Node
|
|
from peewee import Psycopg3Adapter
|
|
from playhouse.postgres_ext import Json
|
|
|
|
try:
|
|
import aiosqlite
|
|
except ImportError:
|
|
aiosqlite = None
|
|
|
|
try:
|
|
import asyncpg
|
|
except ImportError:
|
|
asyncpg = None
|
|
|
|
try:
|
|
import aiomysql
|
|
except ImportError:
|
|
aiomysql = None
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MissingGreenletBridge(RuntimeError):
|
|
pass
|
|
|
|
|
|
_BRIDGE_ERR_HINT = (
|
|
' Hint: in async code, run queries through the async API, e.g. '
|
|
'`await Model.aget(...)`, `await query.aexecute()`, `await '
|
|
'db.list(query)`, or wrap synchronous code with `await db.run(fn)`. '
|
|
'For lazy foreign-key access use `await obj.afetch(Model.rel_field)`. See '
|
|
'https://docs.peewee-orm.com/en/latest/peewee/asyncio.html#sharp-corners')
|
|
|
|
|
|
async def greenlet_spawn(fn, *args, **kwargs):
|
|
parent = getcurrent()
|
|
result = None
|
|
error = None
|
|
|
|
def runner():
|
|
nonlocal result, error
|
|
try:
|
|
result = fn(*args, **kwargs)
|
|
except BaseException as exc:
|
|
error = exc
|
|
|
|
# Run the sync code in a greenlet - the sync code must use await_()
|
|
# whenever blocking would occur. await_() transfers a coroutine and control
|
|
# back up to this runner, which can safely `await` the coroutine before
|
|
# switching back to the sync code.
|
|
g = greenlet(runner, parent=parent)
|
|
g.gr_context = parent.gr_context
|
|
value = g.switch()
|
|
while not g.dead:
|
|
try:
|
|
value = g.switch(await value)
|
|
except BaseException as exc:
|
|
value = g.throw(exc)
|
|
|
|
if error:
|
|
raise error
|
|
return result
|
|
|
|
|
|
def await_(awaitable):
|
|
current = getcurrent()
|
|
parent = current.parent
|
|
if parent is None:
|
|
if asyncio.iscoroutine(awaitable):
|
|
awaitable.close() # Avoid a "never awaited" RuntimeWarning.
|
|
errmsg = 'await_() called outside greenlet_spawn()' + _BRIDGE_ERR_HINT
|
|
raise MissingGreenletBridge(errmsg)
|
|
return parent.switch(awaitable)
|
|
|
|
|
|
class _State(object):
|
|
__slots__ = ('conn', 'closed', 'transactions', 'ctx', '_task_id')
|
|
|
|
def __init__(self):
|
|
self._task_id = None
|
|
self.reset()
|
|
|
|
def reset(self):
|
|
self.conn = None
|
|
self.closed = True
|
|
self.transactions = []
|
|
self.ctx = []
|
|
|
|
|
|
class _ConnectionState(object):
|
|
def __init__(self):
|
|
self._cv = contextvars.ContextVar('pwasyncio_state')
|
|
# Central registry: task-id -> _State. Allows close_pool() to
|
|
# enumerate *all* live states and release their connections.
|
|
self._states = {}
|
|
self._orphaned_conns = []
|
|
|
|
def _current(self):
|
|
task = asyncio.current_task()
|
|
if task is None:
|
|
raise RuntimeError('Cannot determine current task')
|
|
tid = id(task)
|
|
|
|
try:
|
|
state = self._cv.get()
|
|
if state._task_id == tid:
|
|
# Re-register if evicted (e.g. by close_pool clearing _states).
|
|
if tid not in self._states:
|
|
self._states[tid] = state
|
|
# Unnecessary to register the callback; task is still
|
|
# running so the original callback should be present.
|
|
# task.add_done_callback(self._on_task_done)
|
|
return state
|
|
except LookupError:
|
|
pass
|
|
|
|
if tid in self._states:
|
|
state = self._states[tid]
|
|
else:
|
|
state = _State()
|
|
state._task_id = tid
|
|
self._states[tid] = state
|
|
task.add_done_callback(self._on_task_done)
|
|
|
|
# Cache in the contextvar for subsequent calls for task.
|
|
self._cv.set(state)
|
|
return state
|
|
|
|
def _on_task_done(self, task):
|
|
tid = id(task)
|
|
state = self._states.pop(tid, None)
|
|
if state is not None and state.conn is not None and not state.closed:
|
|
self._orphaned_conns.append(state.conn)
|
|
state.reset()
|
|
|
|
@property
|
|
def conn(self):
|
|
return self._current().conn
|
|
|
|
@property
|
|
def closed(self):
|
|
return self._current().closed
|
|
|
|
@property
|
|
def transactions(self):
|
|
return self._current().transactions
|
|
|
|
@property
|
|
def ctx(self):
|
|
return self._current().ctx
|
|
|
|
def reset(self):
|
|
try:
|
|
state = self._current()
|
|
except RuntimeError:
|
|
return
|
|
state.reset()
|
|
|
|
def set_connection(self, conn):
|
|
state = self._current()
|
|
state.conn = conn
|
|
state.closed = False
|
|
|
|
|
|
class _async_transaction_helper(object):
|
|
async def __aenter__(self):
|
|
return await self.db.run(self.__enter__)
|
|
|
|
async def __aexit__(self, exc_typ, exc, tb):
|
|
return await self.db.run(self.__exit__, exc_typ, exc, tb)
|
|
|
|
async def acommit(self):
|
|
return await self.db.run(self.commit)
|
|
|
|
async def arollback(self):
|
|
return await self.db.run(self.rollback)
|
|
|
|
|
|
class async_atomic(_async_transaction_helper, _atomic): pass
|
|
class async_transaction(_async_transaction_helper, _transaction): pass
|
|
class async_savepoint(_async_transaction_helper, _savepoint): pass
|
|
|
|
|
|
class AsyncDatabaseMixin(object):
|
|
def __init__(self, database, **kwargs):
|
|
self._pool_size = kwargs.pop('pool_size', 10)
|
|
self._pool_min_size = kwargs.pop('pool_min_size', 1)
|
|
self._acquire_timeout = kwargs.pop('acquire_timeout', 10)
|
|
super(AsyncDatabaseMixin, self).__init__(database, **kwargs)
|
|
|
|
self._state = _ConnectionState()
|
|
self._pool = None
|
|
self._pool_lock = asyncio.Lock()
|
|
self._closing = False # Guard against use during shutdown.
|
|
|
|
def execute_sql(self, sql, params=None):
|
|
try:
|
|
return await_(self.aexecute_sql(sql, params or ()))
|
|
except MissingGreenletBridge as exc:
|
|
errmsg = f'Attempted query outside greenlet runner: {sql}.'
|
|
raise MissingGreenletBridge(errmsg + _BRIDGE_ERR_HINT) from exc
|
|
|
|
async def aexecute_sql(self, sql, params=None):
|
|
conn = await self.aconnect()
|
|
with __exception_wrapper__:
|
|
return await conn.execute(sql, params)
|
|
|
|
def connect(self):
|
|
return await_(self.aconnect())
|
|
|
|
async def aconnect(self):
|
|
if self._closing:
|
|
raise InterfaceError('Database pool is shutting down.')
|
|
|
|
# Drain any connections orphaned by dead tasks.
|
|
while self._state._orphaned_conns:
|
|
orphan = self._state._orphaned_conns.pop()
|
|
await self._pool_release(orphan)
|
|
|
|
conn = self._state.conn
|
|
if conn is None or conn.conn is None:
|
|
if conn is not None:
|
|
# Previous connection was invalidated, release it.
|
|
await self._pool_release(conn)
|
|
conn = await self._acquire_conn_async()
|
|
self._state.set_connection(conn)
|
|
return conn
|
|
|
|
def close(self):
|
|
return await_(self.aclose())
|
|
|
|
async def aclose(self):
|
|
if self.in_transaction():
|
|
raise OperationalError('Attempting to close database while '
|
|
'transaction is open.')
|
|
conn = self._state.conn
|
|
if conn:
|
|
self._state.reset()
|
|
logger.debug('Releasing connection %s to pool.', id(conn))
|
|
await self._pool_release(conn)
|
|
|
|
async def _acquire_conn_async(self):
|
|
async with self._pool_lock:
|
|
if self._pool is None:
|
|
self._pool = await self._create_pool_async()
|
|
|
|
try:
|
|
conn = await self._pool_acquire()
|
|
except asyncio.TimeoutError:
|
|
raise OperationalError(
|
|
'Timed out acquiring connection from pool '
|
|
'(acquire_timeout=%s).' % self._acquire_timeout) from None
|
|
logger.debug('Acquired connection %s from pool.', id(conn))
|
|
return conn
|
|
|
|
async def _create_pool_async(self):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
async def _pool_acquire(self):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
async def _pool_release(self, conn):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
async def close_pool(self):
|
|
self._closing = True
|
|
try:
|
|
if self._pool:
|
|
# Release connections held by any task still in the registry.
|
|
# We must clear each state BEFORE releasing the connection,
|
|
# because the await in _pool_release can let the event loop
|
|
# run pending task-done callbacks. If the callback sees
|
|
# state.conn still set it will orphan the same connection,
|
|
# leading to a double-release that overfills the pool queue.
|
|
for state in list(self._state._states.values()):
|
|
if state.conn and not state.closed:
|
|
conn = state.conn
|
|
state.reset()
|
|
try:
|
|
await self._pool_release(conn)
|
|
except Exception:
|
|
logger.warning(
|
|
'Error releasing connection during pool close',
|
|
exc_info=True)
|
|
self._state._states.clear()
|
|
|
|
# Drain any connections orphaned by completed tasks.
|
|
while self._state._orphaned_conns:
|
|
orphan = self._state._orphaned_conns.pop()
|
|
try:
|
|
await self._pool_release(orphan)
|
|
except Exception:
|
|
logger.warning('Error releasing orphaned connection',
|
|
exc_info=True)
|
|
|
|
await self._pool_close()
|
|
self._pool = None
|
|
finally:
|
|
self._closing = False
|
|
|
|
async def _pool_close(self):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
async def __aenter__(self):
|
|
await self.run(self.connect)
|
|
return self
|
|
|
|
async def __aexit__(self, exc_typ, exc, tb):
|
|
await self.run(self.close)
|
|
|
|
def atomic(self, *args, **kwargs):
|
|
return async_atomic(self, *args, **kwargs)
|
|
|
|
def transaction(self, *args, **kwargs):
|
|
return async_transaction(self, *args, **kwargs)
|
|
|
|
def savepoint(self):
|
|
return async_savepoint(self)
|
|
|
|
async def acreate_tables(self, *args, **kwargs):
|
|
return await greenlet_spawn(self.create_tables, *args, **kwargs)
|
|
|
|
async def adrop_tables(self, *args, **kwargs):
|
|
return await greenlet_spawn(self.drop_tables, *args, **kwargs)
|
|
|
|
async def aexecute(self, query):
|
|
query.bind(self)
|
|
return await self.run(query.execute)
|
|
|
|
async def get(self, query):
|
|
return await self.run(query.get)
|
|
|
|
async def first(self, query, n=1):
|
|
return await self.run(query.first, n=n)
|
|
|
|
async def list(self, query):
|
|
return await self.run(list, query)
|
|
|
|
async def scalar(self, query):
|
|
return await self.run(query.scalar)
|
|
|
|
async def count(self, query):
|
|
return await self.run(query.count)
|
|
|
|
async def exists(self, query):
|
|
return await self.run(query.exists)
|
|
|
|
async def aprefetch(self, query, *subqueries):
|
|
return await self.run(prefetch, query, *subqueries)
|
|
|
|
async def iterate(self, query, buffer_size=None):
|
|
# Use similar approach to postgres_ext server-side query impl.
|
|
query.bind(self)
|
|
sql, params = query.sql()
|
|
conn = await self.aconnect()
|
|
with __exception_wrapper__:
|
|
cursor = await conn.execute_iter(sql, params or ())
|
|
if buffer_size is not None:
|
|
cursor._buffer_size = buffer_size
|
|
|
|
try:
|
|
wrapper = query._get_cursor_wrapper(cursor)
|
|
row_iter = wrapper.iterator()
|
|
_sentinel = object()
|
|
|
|
# Cursor wrapper `iterator()` calls fetchone() to grab rows from
|
|
# the internal buffer. `fetchone()` may dispatch do the event loop
|
|
# to refill buffer (async).
|
|
while True:
|
|
row = await greenlet_spawn(next, row_iter, _sentinel)
|
|
if row is _sentinel:
|
|
break
|
|
yield row
|
|
finally:
|
|
await cursor.aclose()
|
|
|
|
async def run(self, fn, *args, **kwargs):
|
|
return await greenlet_spawn(fn, *args, **kwargs)
|
|
|
|
def is_closed(self):
|
|
try:
|
|
return self._state.closed
|
|
except RuntimeError:
|
|
return True
|
|
|
|
@property
|
|
def Model(self):
|
|
if not hasattr(self, '_Model'):
|
|
class Meta: database = self
|
|
self._Model = type('AsyncBaseModel', (AsyncModel,), {'Meta': Meta})
|
|
return self._Model
|
|
|
|
|
|
def _aio_database(model):
|
|
# Ensure we have an asyncio-friendly db.
|
|
db = model._meta.database
|
|
if isinstance(db, Proxy):
|
|
db = db.obj
|
|
if db is None:
|
|
raise InterfaceError('%s is not bound to a database. Async methods '
|
|
'require an Async database.' % model.__name__)
|
|
elif not isinstance(db, AsyncDatabaseMixin):
|
|
raise InterfaceError('%s is not bound to an asyncio-compatible '
|
|
'database (%s). Async methods require an Async '
|
|
'database.' % (model.__name__, type(db).__name__))
|
|
return db
|
|
|
|
|
|
class AsyncModelMixin(object):
|
|
@classmethod
|
|
async def acreate(cls, **query):
|
|
return await _aio_database(cls).run(cls.create, **query)
|
|
|
|
@classmethod
|
|
async def aget(cls, *query, **filters):
|
|
return await _aio_database(cls).run(cls.get, *query, **filters)
|
|
|
|
@classmethod
|
|
async def aget_or_none(cls, *query, **filters):
|
|
return await _aio_database(cls).run(cls.get_or_none, *query, **filters)
|
|
|
|
@classmethod
|
|
async def aget_by_id(cls, pk):
|
|
return await _aio_database(cls).run(cls.get_by_id, pk)
|
|
|
|
@classmethod
|
|
async def aget_or_create(cls, **kwargs):
|
|
# Delegates to Model.get_or_create: atomic() + IntegrityError race
|
|
# recovery run inside the bridge.
|
|
return await _aio_database(cls).run(cls.get_or_create, **kwargs)
|
|
|
|
@classmethod
|
|
async def aset_by_id(cls, key, value):
|
|
return await _aio_database(cls).run(cls.set_by_id, key, value)
|
|
|
|
@classmethod
|
|
async def adelete_by_id(cls, pk):
|
|
return await _aio_database(cls).run(cls.delete_by_id, pk)
|
|
|
|
@classmethod
|
|
async def abulk_create(cls, model_list, batch_size=None):
|
|
return await _aio_database(cls).run(
|
|
cls.bulk_create,
|
|
model_list,
|
|
batch_size)
|
|
|
|
@classmethod
|
|
async def abulk_update(cls, model_list, fields, batch_size=None):
|
|
return await _aio_database(cls).run(
|
|
cls.bulk_update,
|
|
model_list,
|
|
fields,
|
|
batch_size)
|
|
|
|
async def asave(self, force_insert=False, only=None):
|
|
# resolve MRO, e.g. playhouse.signals overrides running in bridge.
|
|
return await _aio_database(type(self)).run(
|
|
self.save,
|
|
force_insert,
|
|
only)
|
|
|
|
async def adelete_instance(self, recursive=False, delete_nullable=False):
|
|
return await _aio_database(type(self)).run(
|
|
self.delete_instance,
|
|
recursive,
|
|
delete_nullable)
|
|
|
|
async def afetch(self, field):
|
|
# await tweet.afetch(Tweet.user), lazy foreign-key helper.
|
|
if isinstance(field, str):
|
|
field = self._meta.combined[field]
|
|
if not isinstance(field, ForeignKeyField):
|
|
raise ValueError('afetch() accepts a foreign-key field.')
|
|
if field.name in self.__rel__:
|
|
return self.__rel__[field.name] # Load cached.
|
|
if not field.lazy_load:
|
|
raise ValueError('%s.%s is declared with lazy_load=False.' %
|
|
(type(self).__name__, field.name))
|
|
return await _aio_database(type(self)).run(getattr, self, field.name)
|
|
|
|
|
|
class AsyncModel(AsyncModelMixin, Model):
|
|
pass
|
|
|
|
|
|
class CursorAdapter(object):
|
|
DEFAULT_BUFFER_SIZE = 100
|
|
|
|
def __init__(self, rows=None, lastrowid=None, rowcount=None,
|
|
description=None, fetch_many=None, cleanup=None,
|
|
buffer_size=None):
|
|
self._rows = rows or []
|
|
self._idx = 0
|
|
self.lastrowid = lastrowid
|
|
self.rowcount = rowcount if rowcount is not None else len(self._rows)
|
|
self.description = description or []
|
|
|
|
# Async server-side cursor support.
|
|
self._fetch_many = fetch_many
|
|
self._cleanup = cleanup
|
|
self._buffer_size = buffer_size or self.DEFAULT_BUFFER_SIZE
|
|
self._buffer = collections.deque()
|
|
self._exhausted = False
|
|
|
|
def fetchone(self):
|
|
if self._fetch_many is not None:
|
|
return self._lazy_fetchone()
|
|
if self._idx >= len(self._rows):
|
|
return
|
|
row = self._rows[self._idx]
|
|
self._idx += 1
|
|
return row
|
|
|
|
def _lazy_fetchone(self):
|
|
if not self._buffer:
|
|
if self._exhausted:
|
|
return None
|
|
with __exception_wrapper__:
|
|
rows = await_(self._fetch_many(self._buffer_size))
|
|
if not rows:
|
|
self._exhausted = True
|
|
return None
|
|
self._buffer.extend(rows)
|
|
return self._buffer.popleft()
|
|
|
|
def fetchall(self):
|
|
if self._fetch_many is not None:
|
|
return list(self)
|
|
return self._rows
|
|
|
|
def __iter__(self):
|
|
if self._fetch_many is not None:
|
|
return _lazy_cursor_iter(self)
|
|
return iter(self._rows)
|
|
|
|
def close(self):
|
|
pass
|
|
|
|
async def aclose(self):
|
|
if self._cleanup is not None:
|
|
try:
|
|
await self._cleanup()
|
|
finally:
|
|
self._cleanup = None
|
|
self._fetch_many = None
|
|
|
|
|
|
def _lazy_cursor_iter(cursor):
|
|
while True:
|
|
row = cursor.fetchone()
|
|
if row is None:
|
|
return
|
|
yield row
|
|
|
|
|
|
class DummyCursor(object):
|
|
def __init__(self, conn):
|
|
self.conn = conn
|
|
|
|
def execute(self, sql, params=None):
|
|
return await_(self._async_execute(sql, params))
|
|
|
|
async def _async_execute(self, sql, params):
|
|
return await self.conn.execute(sql, params)
|
|
|
|
|
|
class AsyncConnectionWrapper(object):
|
|
# Grace period for an abandoned iterate() generator to finalize (e.g.
|
|
# the caller broke out of the async-for) before a competing query on
|
|
# this connection gives up instead of deadlocking.
|
|
streaming_timeout = 5.0
|
|
|
|
def __init__(self, conn):
|
|
self.conn = conn
|
|
self._lock = asyncio.Lock()
|
|
self._streaming = False # Lock is held by an open iterate() cursor.
|
|
|
|
async def _acquire_lock(self):
|
|
# When an iterate() cursor holds the lock, wait briefly for it to
|
|
# finalize rather than deadlocking - this covers plain queries AND
|
|
# a second iterate() on the same connection.
|
|
if self._streaming:
|
|
try:
|
|
await asyncio.wait_for(self._lock.acquire(),
|
|
self.streaming_timeout)
|
|
except asyncio.TimeoutError:
|
|
raise InterfaceError(
|
|
'Connection is busy streaming results from iterate(). '
|
|
'Run the query from another task, or exhaust or '
|
|
'aclose() the iterator.') from None
|
|
else:
|
|
await self._lock.acquire()
|
|
|
|
async def execute(self, sql, params=None):
|
|
await self._acquire_lock()
|
|
try:
|
|
return await self._execute(sql, params)
|
|
finally:
|
|
self._lock.release()
|
|
|
|
async def _execute(self, sql, params):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
def cursor(self):
|
|
return DummyCursor(self)
|
|
|
|
async def execute_iter(self, sql, params=None):
|
|
raise NotImplementedError('Subclasses must implement.')
|
|
|
|
async def close(self):
|
|
if self.conn:
|
|
await self.conn.close()
|
|
self.conn = None
|
|
|
|
|
|
class AsyncSqlitePool(object):
|
|
def __init__(self, database, pool_size=5, on_connect=None,
|
|
**connect_params):
|
|
self._database = database
|
|
self._pool_size = pool_size
|
|
self._on_connect = on_connect
|
|
self._connect_params = connect_params
|
|
self._queue = asyncio.Queue(maxsize=pool_size)
|
|
self._all_connections = []
|
|
self._closed = False
|
|
|
|
async def initialize(self):
|
|
for _ in range(self._pool_size):
|
|
conn = await self._create_connection()
|
|
self._queue.put_nowait(conn)
|
|
return self
|
|
|
|
async def _create_connection(self):
|
|
conn = await aiosqlite.connect(
|
|
self._database,
|
|
isolation_level=None,
|
|
**self._connect_params)
|
|
if self._on_connect is not None:
|
|
await self._on_connect(conn)
|
|
wrapped = AsyncSqliteConnection(conn)
|
|
self._all_connections.append(wrapped)
|
|
return wrapped
|
|
|
|
async def acquire(self, timeout=None):
|
|
if self._closed:
|
|
raise InterfaceError('Pool is closed.')
|
|
return await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
|
|
|
def _conn_is_valid(self, conn):
|
|
driver_conn = conn.conn
|
|
if driver_conn is None:
|
|
return False
|
|
# aiosqlite private attrs - tolerate their absence in new versions.
|
|
if not getattr(driver_conn, '_running', True):
|
|
return False
|
|
if not getattr(driver_conn, '_connection', True):
|
|
return False
|
|
return True
|
|
|
|
async def release(self, conn):
|
|
if self._closed:
|
|
return
|
|
valid = self._conn_is_valid(conn)
|
|
if valid and conn.conn.in_transaction:
|
|
# Roll back any transaction left open, e.g. by a dead task, so
|
|
# the next acquirer gets a clean connection.
|
|
try:
|
|
await conn.conn.rollback()
|
|
except Exception:
|
|
logger.warning('Error rolling back connection', exc_info=True)
|
|
valid = False
|
|
if valid:
|
|
await self._queue.put(conn)
|
|
else:
|
|
try:
|
|
self._all_connections.remove(conn)
|
|
except ValueError:
|
|
pass
|
|
await self._queue.put(await self._create_connection())
|
|
|
|
async def close(self):
|
|
self._closed = True
|
|
conns, self._all_connections = list(self._all_connections), []
|
|
for conn in conns:
|
|
try:
|
|
await conn.close()
|
|
except Exception:
|
|
logger.warning('Error closing pooled connection',
|
|
exc_info=True)
|
|
|
|
|
|
class AsyncSqliteConnection(AsyncConnectionWrapper):
|
|
async def _execute(self, sql, params=None):
|
|
params = params or ()
|
|
cursor = await self.conn.execute(sql, params)
|
|
rows = await cursor.fetchall()
|
|
lastrowid = cursor.lastrowid
|
|
rowcount = cursor.rowcount
|
|
description = cursor.description
|
|
await cursor.close()
|
|
return CursorAdapter(rows, lastrowid=lastrowid, rowcount=rowcount,
|
|
description=description)
|
|
|
|
async def execute_iter(self, sql, params=None):
|
|
await self._acquire_lock()
|
|
self._streaming = True
|
|
try:
|
|
cursor = await self.conn.execute(sql, params or ())
|
|
except BaseException:
|
|
self._streaming = False
|
|
self._lock.release()
|
|
raise
|
|
|
|
lock = self._lock
|
|
|
|
async def fetch_many(count):
|
|
return await cursor.fetchmany(count)
|
|
|
|
async def cleanup():
|
|
try:
|
|
await cursor.close()
|
|
finally:
|
|
self._streaming = False
|
|
lock.release()
|
|
|
|
return CursorAdapter(
|
|
description=cursor.description,
|
|
fetch_many=fetch_many,
|
|
cleanup=cleanup)
|
|
|
|
|
|
class AsyncSqliteDatabase(AsyncDatabaseMixin, SqliteDatabase):
|
|
async def _create_pool_async(self):
|
|
if aiosqlite is None:
|
|
raise ImproperlyConfigured('aiosqlite is not installed')
|
|
if self.database == ':memory:':
|
|
# Pooled in-memory connections would each be a separate, empty
|
|
# database - use a single shared connection instead.
|
|
pool_size = 1
|
|
else:
|
|
pool_size = self._pool_size
|
|
pool = AsyncSqlitePool(self.database, pool_size=pool_size,
|
|
on_connect=self._add_conn_hooks,
|
|
timeout=self._timeout,
|
|
**self.connect_params)
|
|
return await pool.initialize()
|
|
|
|
async def _add_conn_hooks(self, conn):
|
|
if self._attached:
|
|
await self._attach_databases(conn)
|
|
if self._pragmas:
|
|
await self._set_pragmas(conn)
|
|
if self._aggregates:
|
|
await self._load_aggregates(conn)
|
|
if self._collations:
|
|
await self._load_collations(conn)
|
|
if self._functions:
|
|
await self._load_functions(conn)
|
|
if self._window_functions and \
|
|
aiosqlite.sqlite_version_info >= (3, 25, 0):
|
|
await self._load_window_functions(conn)
|
|
if self._extensions:
|
|
await self._load_extensions(conn)
|
|
|
|
async def _attach_databases(self, conn):
|
|
for name, db in self._attached.items():
|
|
await conn.execute('ATTACH DATABASE "%s" AS "%s"' % (db, name))
|
|
|
|
async def _set_pragmas(self, conn):
|
|
for pragma, value in self._pragmas:
|
|
await conn.execute('PRAGMA %s = %s;' % (pragma, value))
|
|
|
|
async def _load_aggregates(self, conn):
|
|
# aiosqlite exposes no create_aggregate - run it on the worker
|
|
# thread against the raw sqlite3 connection.
|
|
for name, (klass, num_params) in self._aggregates.items():
|
|
await conn._execute(
|
|
conn._conn.create_aggregate, name, num_params, klass)
|
|
|
|
async def _load_collations(self, conn):
|
|
for name, fn in self._collations.items():
|
|
await conn._execute(conn._conn.create_collation, name, fn)
|
|
|
|
async def _load_functions(self, conn):
|
|
for name, (fn, n_params, deterministic) in self._functions.items():
|
|
kwargs = {'deterministic': deterministic} if deterministic else {}
|
|
await conn.create_function(name, n_params, fn, **kwargs)
|
|
|
|
async def _load_window_functions(self, conn):
|
|
for name, (klass, num_params) in self._window_functions.items():
|
|
await conn._execute(
|
|
conn._conn.create_window_function, name, num_params, klass)
|
|
|
|
async def _load_extensions(self, conn):
|
|
await conn.enable_load_extension(True)
|
|
for extension in self._extensions:
|
|
await conn.load_extension(extension)
|
|
|
|
async def _pool_acquire(self):
|
|
return await self._pool.acquire(timeout=self._acquire_timeout)
|
|
|
|
async def _pool_release(self, conn):
|
|
if conn is not None:
|
|
await self._pool.release(conn)
|
|
|
|
async def _pool_close(self):
|
|
if self._pool:
|
|
await self._pool.close()
|
|
|
|
|
|
class AsyncMySQLConnection(AsyncConnectionWrapper):
|
|
async def _execute(self, sql, params=None):
|
|
params = params or ()
|
|
cursor = await self.conn.cursor()
|
|
try:
|
|
await cursor.execute(sql, params)
|
|
rows = await cursor.fetchall()
|
|
lastrowid = cursor.lastrowid
|
|
rowcount = cursor.rowcount
|
|
description = cursor.description
|
|
finally:
|
|
await cursor.close()
|
|
return CursorAdapter(rows, lastrowid=lastrowid, rowcount=rowcount,
|
|
description=description)
|
|
|
|
async def execute_iter(self, sql, params=None):
|
|
await self._acquire_lock()
|
|
self._streaming = True
|
|
try:
|
|
# Server-side cursor for unbuffered streaming.
|
|
cursor = await self.conn.cursor(aiomysql.SSCursor)
|
|
await cursor.execute(sql, params or ())
|
|
except BaseException:
|
|
self._streaming = False
|
|
self._lock.release()
|
|
raise
|
|
|
|
lock = self._lock
|
|
|
|
async def fetch_many(count):
|
|
return await cursor.fetchmany(count)
|
|
|
|
async def cleanup():
|
|
try:
|
|
await cursor.close()
|
|
finally:
|
|
self._streaming = False
|
|
lock.release()
|
|
|
|
return CursorAdapter(
|
|
description=cursor.description,
|
|
fetch_many=fetch_many,
|
|
cleanup=cleanup)
|
|
|
|
|
|
class AsyncMySQLDatabase(AsyncDatabaseMixin, MySQLDatabase):
|
|
async def _create_pool_async(self):
|
|
if aiomysql is None:
|
|
raise ImproperlyConfigured('aiomysql is not installed')
|
|
return await aiomysql.create_pool(
|
|
db=self.database,
|
|
autocommit=True,
|
|
minsize=self._pool_min_size,
|
|
maxsize=self._pool_size,
|
|
**self.connect_params)
|
|
|
|
async def _pool_acquire(self):
|
|
conn = await asyncio.wait_for(
|
|
self._pool.acquire(),
|
|
timeout=self._acquire_timeout)
|
|
if self.server_version is None:
|
|
# Distinguishes MySQL from MariaDB, e.g. for JSONField SQL.
|
|
self.server_version = self._extract_server_version(
|
|
conn.get_server_info())
|
|
return AsyncMySQLConnection(conn)
|
|
|
|
async def _pool_release(self, conn):
|
|
if conn and conn.conn:
|
|
if conn.conn.get_transaction_status():
|
|
# Roll back any transaction left open, e.g. by a dead task,
|
|
# so the next acquirer gets a clean connection (aiomysql
|
|
# destroys connections released mid-transaction).
|
|
try:
|
|
await conn.conn.rollback()
|
|
except Exception:
|
|
logger.warning('Error rolling back connection',
|
|
exc_info=True)
|
|
self._pool.release(conn.conn)
|
|
|
|
async def _pool_close(self):
|
|
self._pool.close()
|
|
await self._pool.wait_closed()
|
|
|
|
|
|
class AsyncPostgresqlConnection(AsyncConnectionWrapper):
|
|
async def _execute(self, sql, params=None):
|
|
# asyncpg uses $1, $2 positional params instead of %s.
|
|
if params:
|
|
sql = self._translate_placeholders(sql)
|
|
|
|
stmt = await self.conn.prepare(sql)
|
|
records = await stmt.fetch(*(params or ()))
|
|
if records:
|
|
description = [(k,) for k in records[0].keys()]
|
|
else:
|
|
description = []
|
|
|
|
# asyncpg exposes no rowcount; parse the command-status tail, e.g.
|
|
# "UPDATE 3" / "DELETE 2" / "INSERT 0 3".
|
|
status = (stmt.get_statusmsg() or '').rsplit(' ', 1)
|
|
if len(status) == 2 and status[1].isdigit():
|
|
rowcount = int(status[1])
|
|
else:
|
|
rowcount = len(records)
|
|
|
|
return CursorAdapter(records, rowcount=rowcount,
|
|
description=description)
|
|
|
|
async def execute_iter(self, sql, params=None):
|
|
if params:
|
|
sql = self._translate_placeholders(sql)
|
|
await self._acquire_lock()
|
|
self._streaming = True
|
|
tr = None
|
|
try:
|
|
# NB: asyncpg cursors require an active transaction.
|
|
# Right now we cannot use peewee-managed transactions because
|
|
# asyncpg's Cursor._check_ready() requires an asyncpg-managed
|
|
# transaction be active.
|
|
# See: https://github.com/MagicStack/asyncpg/issues/1311
|
|
tr = self.conn.transaction()
|
|
await tr.start()
|
|
stmt = await self.conn.prepare(sql)
|
|
cursor = await stmt.cursor(*(params or ()))
|
|
except BaseException:
|
|
if tr is not None:
|
|
# Don't leave the connection inside an open transaction.
|
|
try:
|
|
await tr.rollback()
|
|
except Exception:
|
|
pass
|
|
self._streaming = False
|
|
self._lock.release()
|
|
raise
|
|
|
|
lock = self._lock
|
|
|
|
async def fetch_many(count):
|
|
return await cursor.fetch(count)
|
|
|
|
async def cleanup():
|
|
try:
|
|
await tr.rollback()
|
|
except Exception:
|
|
pass
|
|
finally:
|
|
self._streaming = False
|
|
lock.release()
|
|
|
|
return CursorAdapter(
|
|
fetch_many=fetch_many,
|
|
cleanup=cleanup,
|
|
description=[(a.name,) for a in stmt.get_attributes()])
|
|
|
|
@staticmethod
|
|
def _translate_placeholders(sql):
|
|
# %s is treated as a placeholder wherever it appears, including
|
|
# inside quoted strings, and %% as an escaped literal percent -
|
|
# mirroring psycopg. Pass literal values as parameters.
|
|
if '%' not in sql:
|
|
return sql
|
|
counter = itertools.count(1)
|
|
def replace(match):
|
|
if match.group(0) == '%%':
|
|
return '%'
|
|
return '$%d' % next(counter)
|
|
return re.sub('%%|%s', replace, sql)
|
|
|
|
|
|
class AsyncPgAdapter(Psycopg3Adapter):
|
|
def __init__(self):
|
|
super(AsyncPgAdapter, self).__init__()
|
|
self.json_type = Json
|
|
self.jsonb_type = Json
|
|
|
|
|
|
class AsyncPgAtomic(_callable_context_manager):
|
|
def __init__(self, db, *args, **kwargs):
|
|
self.db = db
|
|
self._begin_args = (args, kwargs)
|
|
|
|
def __enter__(self):
|
|
await_(self._abegin())
|
|
self.db._state.transactions.append(self)
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
self.db._state.transactions.pop()
|
|
if exc_type:
|
|
self.rollback(False)
|
|
else:
|
|
try:
|
|
self.commit(False)
|
|
except Exception:
|
|
# asyncpg marks the transaction FAILED when commit errors,
|
|
# making rollback raise too - don't mask the original.
|
|
try:
|
|
self.rollback(False)
|
|
except Exception:
|
|
pass
|
|
raise
|
|
|
|
def commit(self, begin=True):
|
|
await_(self.acommit(begin))
|
|
|
|
def rollback(self, begin=True):
|
|
await_(self.arollback(begin))
|
|
|
|
async def _abegin(self):
|
|
a, k = self._begin_args
|
|
conn = await self.db.aconnect()
|
|
with __exception_wrapper__:
|
|
self._tx = conn.conn.transaction(*a, **k)
|
|
await self._tx.start()
|
|
return self._tx
|
|
|
|
async def acommit(self, begin=True):
|
|
with __exception_wrapper__:
|
|
await self._tx.commit()
|
|
if begin:
|
|
await self._abegin()
|
|
|
|
async def arollback(self, begin=True):
|
|
with __exception_wrapper__:
|
|
await self._tx.rollback()
|
|
if begin:
|
|
await self._abegin()
|
|
|
|
async def __aenter__(self):
|
|
await self._abegin()
|
|
self.db._state.transactions.append(self)
|
|
return self
|
|
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
self.db._state.transactions.pop()
|
|
if exc_type:
|
|
await self.arollback(False)
|
|
else:
|
|
try:
|
|
await self.acommit(False)
|
|
except Exception:
|
|
# asyncpg marks the transaction FAILED when commit errors,
|
|
# making rollback raise too - don't mask the original.
|
|
try:
|
|
await self.arollback(False)
|
|
except Exception:
|
|
pass
|
|
raise
|
|
|
|
|
|
class AsyncPostgresqlDatabase(AsyncDatabaseMixin, PostgresqlDatabase):
|
|
psycopg2_adapter = psycopg3_adapter = AsyncPgAdapter
|
|
|
|
def init(self, database, **kwargs):
|
|
# asyncpg has no psycopg-style isolation-level constants; keep the
|
|
# raw value and apply it per-connection in register_adapters().
|
|
self._async_isolation_level = kwargs.pop('isolation_level', None)
|
|
super(AsyncPostgresqlDatabase, self).init(database, **kwargs)
|
|
|
|
async def register_adapters(self, conn):
|
|
def encode_json(val):
|
|
return val if isinstance(val, bytes) else val.encode('utf8')
|
|
|
|
def decode_json(bval):
|
|
return json.loads(bval.decode())
|
|
|
|
await conn.set_type_codec(
|
|
'json', encoder=encode_json, decoder=decode_json,
|
|
schema='pg_catalog', format='binary')
|
|
|
|
def encode_jsonb(val):
|
|
if isinstance(val, bytes):
|
|
return b'\x01' + val
|
|
return b'\x01' + val.encode('utf8')
|
|
|
|
def decode_jsonb(bval):
|
|
return json.loads(bval[1:].decode())
|
|
|
|
await conn.set_type_codec(
|
|
'jsonb', encoder=encode_jsonb, decoder=decode_jsonb,
|
|
schema='pg_catalog', format='binary')
|
|
|
|
if self._async_isolation_level:
|
|
await conn.execute(
|
|
'SET SESSION CHARACTERISTICS AS TRANSACTION '
|
|
'ISOLATION LEVEL %s' % self._async_isolation_level)
|
|
|
|
async def _create_pool_async(self):
|
|
if asyncpg is None:
|
|
raise ImproperlyConfigured('asyncpg is not installed')
|
|
if self.database and self.database.startswith(
|
|
('postgresql://', 'postgres://')):
|
|
db_params = {'dsn': self.database}
|
|
else:
|
|
db_params = {'database': self.database}
|
|
return await asyncpg.create_pool(
|
|
min_size=self._pool_min_size,
|
|
max_size=self._pool_size,
|
|
init=self.register_adapters,
|
|
**db_params,
|
|
**self.connect_params)
|
|
|
|
async def _pool_acquire(self):
|
|
conn = await asyncio.wait_for(
|
|
self._pool.acquire(),
|
|
timeout=self._acquire_timeout)
|
|
return AsyncPostgresqlConnection(conn)
|
|
|
|
async def _pool_release(self, conn):
|
|
if conn and conn.conn:
|
|
# Roll back any transaction left open, e.g. by a dead task. asyncpg
|
|
# records the started transaction in conn._top_xact; rolling back
|
|
# through it clears that bookkeeping. A raw "ROLLBACK" only resets
|
|
# the server and leaves _top_xact set, so the pool's own reset would
|
|
# still log "Resetting connection with an active transaction".
|
|
top_xact = getattr(conn.conn, '_top_xact', None)
|
|
try:
|
|
if top_xact is not None:
|
|
await top_xact.rollback()
|
|
elif conn.conn.is_in_transaction():
|
|
await conn.conn.execute('ROLLBACK')
|
|
except Exception:
|
|
logger.warning('Error rolling back connection', exc_info=True)
|
|
await self._pool.release(conn.conn)
|
|
|
|
async def _pool_close(self):
|
|
await self._pool.close()
|
|
|
|
def atomic(self, *args, **kwargs):
|
|
return AsyncPgAtomic(self, *args, **kwargs)
|
|
def transaction(self, *args, **kwargs):
|
|
return AsyncPgAtomic(self, *args, **kwargs)
|
|
def savepoint(self, *args, **kwargs):
|
|
return AsyncPgAtomic(self, *args, **kwargs)
|