import json import logging import uuid from peewee import * from peewee import ColumnBase from peewee import Expression from peewee import FieldDatabaseHook from peewee import Node from peewee import NodeList from peewee import Psycopg2Adapter from peewee import Psycopg3Adapter from peewee import __exception_wrapper__ from playhouse.pool import _PooledPostgresqlDatabase try: from psycopg2cffi import compat compat.register() except ImportError: pass try: from psycopg2.extras import register_hstore except ImportError: def register_hstore(*args): pass try: from psycopg.types import TypeInfo from psycopg.types.hstore import register_hstore as register_hstore_pg3 except ImportError: def register_hstore_pg3(*args): pass logger = logging.getLogger('peewee') HCONTAINS_DICT = '@>' HCONTAINS_KEYS = '?&' HCONTAINS_KEY = '?' HCONTAINS_ANY_KEY = '?|' HKEY = '->' HUPDATE = '||' ACONTAINS = '@>' ACONTAINED_BY = '<@' ACONTAINS_ANY = '&&' TS_MATCH = '@@' JSONB_CONTAINS = '@>' JSONB_CONTAINED_BY = '<@' JSONB_CONTAINS_KEY = '?' JSONB_CONTAINS_ANY_KEY = '?|' JSONB_CONTAINS_ALL_KEYS = '?&' JSONB_EXISTS = '?' JSONB_REMOVE = '-' JSONB_PATH_REMOVE = '#-' JSONB_PATH = '#>' JSONB_PATH_EXISTS = '@?' JSONB_PATH_MATCH = '@@' class Json(Node): # Fallback JSON handler. __slots__ = ('value',) def __init__(self, value, dumps=None): self.value = value self.dumps = dumps or json.dumps def __sql__(self, ctx): return ctx.value(self.value, self.dumps) class _LookupNode(ColumnBase): def __init__(self, node, parts): self.node = node self.parts = parts super(_LookupNode, self).__init__() def clone(self): return type(self)(self.node, list(self.parts)) def __hash__(self): return hash((self.__class__.__name__, id(self))) class ObjectSlice(_LookupNode): @classmethod def create(cls, node, value): if isinstance(value, slice): stop = value.stop - 1 if value.stop is not None else None parts = [value.start or 0, stop] elif isinstance(value, int): parts = [value] elif isinstance(value, Node): parts = value else: # Assumes colon-separated integer indexes. parts = [int(i) for i in value.split(':')] return cls(node, parts) def __sql__(self, ctx): ctx.sql(self.node) if isinstance(self.parts, Node): ctx.literal('[').sql(self.parts).literal(']') else: ctx.literal('[%s]' % ':'.join([str(p + 1) if p is not None else '' for p in self.parts])) return ctx def __getitem__(self, value): return ObjectSlice.create(self, value) class IndexedFieldMixin(object): default_index_type = 'GIN' def __init__(self, *args, **kwargs): kwargs.setdefault('index', True) # By default, use an index. super(IndexedFieldMixin, self).__init__(*args, **kwargs) class ArrayField(IndexedFieldMixin, Field): passthrough = True def __init__(self, field_class=IntegerField, field_kwargs=None, dimensions=1, convert_values=False, *args, **kwargs): self.__field = field_class(**(field_kwargs or {})) self.dimensions = dimensions self.convert_values = convert_values self.field_type = self.__field.field_type super(ArrayField, self).__init__(*args, **kwargs) def bind(self, model, name, set_attribute=True): ret = super(ArrayField, self).bind(model, name, set_attribute) self.__field.bind(model, '__array_%s' % name, False) return ret def ddl_datatype(self, ctx): data_type = self.__field.ddl_datatype(ctx) return NodeList((data_type, SQL('[]' * self.dimensions)), glue='') def db_value(self, value): if value is None or isinstance(value, Node): return value elif self.convert_values: return self._process(self.__field.db_value, value, self.dimensions) else: return value if isinstance(value, list) else list(value) def python_value(self, value): if self.convert_values and value is not None: conv = self.__field.python_value if isinstance(value, list): return self._process(conv, value, self.dimensions) else: return conv(value) else: return value def _process(self, conv, value, dimensions): dimensions -= 1 if dimensions == 0: return [conv(v) for v in value] else: return [self._process(conv, v, dimensions) for v in value] def __getitem__(self, value): return ObjectSlice.create(self, value) def _e(op): def inner(self, rhs): return Expression(self, op, ArrayValue(self, rhs)) return inner __eq__ = _e(OP.EQ) __ne__ = _e(OP.NE) __gt__ = _e(OP.GT) __ge__ = _e(OP.GTE) __lt__ = _e(OP.LT) __le__ = _e(OP.LTE) __hash__ = Field.__hash__ def contains(self, *items): return Expression(self, ACONTAINS, ArrayValue(self, items)) def contains_any(self, *items): return Expression(self, ACONTAINS_ANY, ArrayValue(self, items)) def contained_by(self, *items): return Expression(self, ACONTAINED_BY, ArrayValue(self, items)) class ArrayValue(Node): def __init__(self, field, value): self.field = field self.value = value def __sql__(self, ctx): return (ctx .sql(AsIs(self.value)) .literal('::') .sql(self.field.ddl_datatype(ctx))) class DateTimeTZField(DateTimeField): field_type = 'TIMESTAMPTZ' class HStoreField(IndexedFieldMixin, Field): field_type = 'HSTORE' __hash__ = Field.__hash__ def __getitem__(self, key): return Expression(self, HKEY, Value(key)) def keys(self): return fn.akeys(self) def values(self): return fn.avals(self) def items(self): return fn.hstore_to_matrix(self) def slice(self, *args): return fn.slice(self, AsIs(list(args))) def exists(self, key): return fn.exist(self, key) def defined(self, key): return fn.defined(self, key) def update(self, __data=None, **data): if __data is not None: data.update(__data) return Expression(self, HUPDATE, data) def delete(self, *keys): value = Cast(AsIs(list(keys)), 'text[]') return fn.delete(self, value) def contains(self, value): if isinstance(value, dict): rhs = AsIs(value) return Expression(self, HCONTAINS_DICT, rhs) elif isinstance(value, (list, tuple)): rhs = AsIs(value) return Expression(self, HCONTAINS_KEYS, rhs) return Expression(self, HCONTAINS_KEY, value) def contains_any(self, *keys): return Expression(self, HCONTAINS_ANY_KEY, AsIs(list(keys))) def _path_array(parts): # Build a Postgres `text[]` literal from a sequence of path components. # Used by the jsonb_set/jsonb_insert wrappers on JSON lookups and fields. return Cast(AsIs([str(p) for p in parts], False), 'text[]') class _JsonLookupBase(_LookupNode): def __init__(self, node, parts, as_json=False): super(_JsonLookupBase, self).__init__(node, parts) self._jsonb = getattr(node, '_json_type', 'jsonb') == 'jsonb' self._as_json = as_json def clone(self): return type(self)(self.node, list(self.parts), self._as_json) @Node.copy def as_json(self, as_json=True): self._as_json = as_json def concat(self, rhs): if not isinstance(rhs, Node): rhs = self.node.json_type(rhs) return Expression(self.as_json(True), OP.CONCAT, rhs) def contains(self, other): if not isinstance(other, Node): other = self.node.json_type(other) return Expression(self.as_json(True), JSONB_CONTAINS, other) def contained_by(self, other): if not isinstance(other, Node): other = self.node.json_type(other) return Expression(self.as_json(True), JSONB_CONTAINED_BY, other) def contains_any(self, *keys): return Expression( self.as_json(True), JSONB_CONTAINS_ANY_KEY, AsIs(list(keys), False)) def contains_all(self, *keys): return Expression( self.as_json(True), JSONB_CONTAINS_ALL_KEYS, AsIs(list(keys), False)) def has_key(self, key): return Expression(self.as_json(True), JSONB_CONTAINS_KEY, key) def remove(self): parts = [str(p) if isinstance(p, int) else p for p in self.parts] value = AsIs(parts, False) return Expression(self.node, JSONB_PATH_REMOVE, value) def length(self): func = fn.jsonb_array_length if self._jsonb else fn.json_array_length return func(self.as_json(True)) def extract(self, *path): path = [str(p) if isinstance(p, int) else p for p in path] func = fn.jsonb_extract_path if self._jsonb else fn.json_extract_path return func(self.as_json(True), *path) def path(self, *keys): return JsonPath(self.as_json(True), keys, as_json=True) def _resolve_root(self): # Walk up through chained lookups (e.g. lookup.path(...) wraps the # lookup as the inner node) to find the underlying field and the full # list of path parts relative to it. node = self parts = [] while isinstance(node, _JsonLookupBase): parts = list(node.parts) + parts node = node.node return node, parts def set(self, value): # jsonb_set(field, '{path}', value, create_missing=true) field, parts = self._resolve_root() return fn.jsonb_set(field, _path_array(parts), field._wrap_value(value), True) def replace(self, value): # jsonb_set(field, '{path}', value, create_missing=false): no-op when # the path doesn't exist. field, parts = self._resolve_root() return fn.jsonb_set(field, _path_array(parts), field._wrap_value(value), False) def insert(self, value): # SQLite-style "insert if missing": no-op when the path exists. # Postgres has no single-call equivalent, so we wrap jsonb_set in a # CASE that checks whether the path resolves. field, parts = self._resolve_root() current = JsonPath(field, parts, as_json=True) return Case(None, [ (current.is_null(), fn.jsonb_set(field, _path_array(parts), field._wrap_value(value), True)) ], field) def append(self, value): # Append to the array at this path. '-1' is the index of the last # element; with insert_after=true, jsonb_insert places `value` after # it. field, parts = self._resolve_root() path = _path_array(list(parts) + ['-1']) return fn.jsonb_insert(field, path, field._wrap_value(value), True) def update(self, value): # Postgres lacks a real merge-patch (e.g. what SQLite offers, following # RFC-7396), so this is just a shallow copy using '||'. field, parts = self._resolve_root() current = JsonPath(field, parts, as_json=True) merged = Expression(current, OP.CONCAT, field._wrap_value(value)) return fn.jsonb_set(field, _path_array(parts), merged, True) # SQL/JSON path expression operators applied to the value at this lookup. # These produce jsonb-typed expressions; usable on lookups off jsonb # columns. Requires PostgreSQL 12+. def path_exists(self, expr): return Expression(self.as_json(True), JSONB_PATH_EXISTS, _jsonpath(expr)) def path_match(self, expr): return Expression(self.as_json(True), JSONB_PATH_MATCH, _jsonpath(expr)) def path_query(self, expr): return fn.jsonb_path_query(self.as_json(True), _jsonpath(expr)) def path_query_array(self, expr): return fn.jsonb_path_query_array(self.as_json(True), _jsonpath(expr)) def path_query_first(self, expr): return fn.jsonb_path_query_first(self.as_json(True), _jsonpath(expr)) class JsonLookup(_JsonLookupBase): def __getitem__(self, value): return JsonLookup(self.node, self.parts + [value], self._as_json) def __sql__(self, ctx): ctx.sql(self.node) for part in self.parts[:-1]: ctx.literal('->').sql(part) if self.parts: (ctx .literal('->' if self._as_json else '->>') .sql(self.parts[-1])) return ctx class JsonPath(_JsonLookupBase): def __sql__(self, ctx): return (ctx .sql(self.node) .literal('#>' if self._as_json else '#>>') .sql(Value('{%s}' % ','.join(map(str, self.parts))))) class JSONField(FieldDatabaseHook, Field): field_type = 'JSON' _json_datatype = 'json' def __init__(self, dumps=None, **kwargs): self._dumps = dumps super(JSONField, self).__init__(**kwargs) def _db_hook(self, database): if database is None or not hasattr(database, '_adapter'): self.json_type = Json self.cast_json_case = True else: self.json_type = database._adapter.json_type self.cast_json_case = database._adapter.cast_json_case if self._dumps: dumps = self._dumps class _Json(self.json_type): def __init__(self, value): super(_Json, self).__init__(value, dumps=dumps) self.json_type = _Json def db_value(self, value): if value is None or isinstance(value, (Node, self.json_type)): return value return self.json_type(value) def to_value(self, value, case=False): # CASE WHEN id = 123 THEN x.json_data fails because the expression is # untyped, so we need an explicit cast with psycopg2. if case and self.cast_json_case: return Cast(self.json_type(value), self._json_datatype) return self.db_value(value) def __getitem__(self, value): return JsonLookup(self, [value]) def path(self, *keys): return JsonPath(self, keys, as_json=True) def concat(self, value): if not isinstance(value, Node): value = self.json_type(value) return super(JSONField, self).concat(value) def length(self): return fn.json_array_length(self) def extract(self, *path): path = [str(p) if isinstance(p, int) else p for p in path] return fn.json_extract_path(self, *path) def _wrap_value(self, value): if isinstance(value, Node): return value return self.json_type(value) def append(self, value): # Append to the document when it's an array. return fn.jsonb_insert(self, _path_array(['-1']), self._wrap_value(value), True) def update(self, value): # Shallow merge `value` into the root object via `||`. return Expression(self, OP.CONCAT, self._wrap_value(value)) def _jsonpath(expr): if isinstance(expr, Node): return expr return Cast(Value(expr, False), 'jsonpath') class BinaryJSONField(IndexedFieldMixin, JSONField): field_type = 'JSONB' _json_datatype = 'jsonb' __hash__ = Field.__hash__ def _db_hook(self, database): if database is None or not hasattr(database, '_adapter'): self.json_type = Json self.cast_json_case = True else: self.json_type = database._adapter.jsonb_type self.cast_json_case = database._adapter.cast_json_case if self._dumps: dumps = self._dumps class _Json(self.json_type): def __init__(self, value): super(_Json, self).__init__(value, dumps=dumps) self.json_type = _Json def contains(self, other): if not isinstance(other, Node): other = self.json_type(other) return Expression(self, JSONB_CONTAINS, other) def contained_by(self, other): if not isinstance(other, Node): other = self.json_type(other) return Expression(self, JSONB_CONTAINED_BY, other) def contains_any(self, *items): return Expression( self, JSONB_CONTAINS_ANY_KEY, AsIs(list(items), False)) def contains_all(self, *items): return Expression( self, JSONB_CONTAINS_ALL_KEYS, AsIs(list(items), False)) def has_key(self, key): return Expression(self, JSONB_CONTAINS_KEY, Value(key, False)) def remove(self, *items): value = Cast(AsIs(list(items), False), 'text[]') return Expression(self, JSONB_REMOVE, value) def length(self): return fn.jsonb_array_length(self) def extract(self, *path): path = [str(p) if isinstance(p, int) else p for p in path] return fn.jsonb_extract_path(self, *path) def path_exists(self, expr): # field @? jsonpath - true if the path matches anything in the value. return Expression(self, JSONB_PATH_EXISTS, _jsonpath(expr)) def path_match(self, expr): # field @@ jsonpath - true if the path's predicate evaluates to true. return Expression(self, JSONB_PATH_MATCH, _jsonpath(expr)) def path_query(self, expr): # jsonb_path_query is set-returning, use in .from_() with an alias. return fn.jsonb_path_query(self, _jsonpath(expr)) def path_query_array(self, expr): return fn.jsonb_path_query_array(self, _jsonpath(expr)) def path_query_first(self, expr): return fn.jsonb_path_query_first(self, _jsonpath(expr)) class TSVectorField(IndexedFieldMixin, TextField): field_type = 'TSVECTOR' __hash__ = Field.__hash__ def match(self, query, language=None, plain=False): params = (language, query) if language is not None else (query,) func = fn.plainto_tsquery if plain else fn.to_tsquery return Expression(self, TS_MATCH, func(*params)) def Match(field, query, language=None): params = (language, query) if language is not None else (query,) field_params = (language, field) if language is not None else (field,) return Expression( fn.to_tsvector(*field_params), TS_MATCH, fn.to_tsquery(*params)) class IntervalField(Field): field_type = 'INTERVAL' class FetchManyCursor(object): __slots__ = ('cursor', 'array_size', 'exhausted', 'iterable') def __init__(self, cursor, array_size=None): self.cursor = cursor self.array_size = array_size or cursor.itersize self.exhausted = False self.iterable = self.row_gen() @property def description(self): return self.cursor.description def close(self): if self.cursor is not None and not self.cursor.closed: self.cursor.close() def row_gen(self): try: while True: rows = self.cursor.fetchmany(self.array_size) if not rows: return for row in rows: yield row finally: self.close() def fetchone(self): if self.exhausted: return try: return next(self.iterable) except StopIteration: self.exhausted = True class ServerSideQuery(Node): def __init__(self, query, array_size=None): self.query = query self.array_size = array_size self._cursor_wrapper = None def __sql__(self, ctx): return self.query.__sql__(ctx) def __iter__(self): if self._cursor_wrapper is None: self._execute(self.query._database) return iter(self._cursor_wrapper.iterator()) def close(self): if self._cursor_wrapper is not None: self._cursor_wrapper.cursor.close() self._cursor_wrapper = None return True return False def iterator(self): if self._cursor_wrapper is None: self._execute(self.query._database) return self._cursor_wrapper.iterator() def _execute(self, database): if self._cursor_wrapper is None: cursor = database.execute(self.query, named_cursor=True, array_size=self.array_size) self._cursor_wrapper = self.query._get_cursor_wrapper(cursor) return self._cursor_wrapper def ServerSide(query, array_size=None): server_side_query = ServerSideQuery(query, array_size=array_size) for row in server_side_query: yield row class _empty_object(object): __slots__ = () def __nonzero__(self): return False __bool__ = __nonzero__ class Psycopg2ExtAdapter(Psycopg2Adapter): def register_hstore(self, conn): register_hstore(conn) def server_side_cursor(self, conn): # psycopg2 does not allow us to use these in autocommit, even if we ARE # inside a transaction - so specify withhold (not desirable!). return conn.cursor(name=str(uuid.uuid1()), withhold=True) class Psycopg3ExtAdapter(Psycopg3Adapter): def register_hstore(self, conn): info = TypeInfo.fetch(conn, 'hstore') register_hstore_pg3(info, conn) def server_side_cursor(self, conn): return conn.cursor(name=str(uuid.uuid1())) class PostgresqlExtDatabase(PostgresqlDatabase): psycopg2_adapter = Psycopg2ExtAdapter psycopg3_adapter = Psycopg3ExtAdapter def __init__(self, *args, **kwargs): self._register_hstore = kwargs.pop('register_hstore', False) self._server_side_cursors = kwargs.pop('server_side_cursors', False) super(PostgresqlExtDatabase, self).__init__(*args, **kwargs) def _connect(self): conn = super(PostgresqlExtDatabase, self)._connect() if self._register_hstore: self._adapter.register_hstore(conn) return conn def cursor(self, named_cursor=None): if self.is_closed(): if self.autoconnect: self.connect() else: raise InterfaceError('Error, database connection not opened.') if named_cursor: return self._adapter.server_side_cursor(self._state.conn) return self._state.conn.cursor() def execute(self, query, named_cursor=False, array_size=None, **context_options): ctx = self.get_sql_context(**context_options) sql, params = ctx.sql(query).query() named_cursor = named_cursor or (self._server_side_cursors and sql[:6].lower() == 'select') cursor = self.execute_sql(sql, params, named_cursor=named_cursor) if named_cursor: cursor = FetchManyCursor(cursor, array_size) return cursor def execute_sql(self, sql, params=None, named_cursor=None): logger.debug((sql, params)) with __exception_wrapper__: cursor = self.cursor(named_cursor=named_cursor) cursor.execute(sql, params or ()) return cursor class PooledPostgresqlExtDatabase(_PooledPostgresqlDatabase, PostgresqlExtDatabase): pass class Psycopg3Database(PostgresqlExtDatabase): def __init__(self, *args, **kwargs): kwargs['prefer_psycopg3'] = True super(Psycopg3Database, self).__init__(*args, **kwargs) class PooledPsycopg3Database(_PooledPostgresqlDatabase, Psycopg3Database): pass