import json

try:
    import mysql.connector as mysql_connector
except ImportError:
    mysql_connector = None
try:
    import mariadb
except ImportError:
    mariadb = None

from peewee import Expression
from peewee import ImproperlyConfigured
from peewee import Insert
from peewee import InterfaceError
from peewee import JSONField
from peewee import JSONPath
from peewee import MySQLDatabase
from peewee import Node
from peewee import NodeList
from peewee import OP
from peewee import SQL
from peewee import TextField
from peewee import Value
from peewee import fn
from playhouse.pool import _PooledMySQLDatabase


class MySQLConnectorDatabase(MySQLDatabase):
    """MySQLDatabase that connects through the ``mysql.connector`` driver."""

    def _connect(self):
        """Open a new autocommit connection via ``mysql.connector``.

        Raises:
            ImproperlyConfigured: if ``mysql.connector`` could not be imported.
        """
        if mysql_connector is None:
            raise ImproperlyConfigured('MySQL connector not installed!')
        return mysql_connector.connect(db=self.database, autocommit=True,
                                       **self.connect_params)

    def cursor(self, named_cursor=None):
        """Return a buffered cursor on the current connection.

        If the database is closed it is opened first when ``autoconnect`` is
        set. ``named_cursor`` is accepted but not used by this method.

        Raises:
            InterfaceError: if the database is closed and ``autoconnect`` is
                disabled.
        """
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        return self._state.conn.cursor(buffered=True)

    def get_binary_type(self):
        """Return the driver's ``Binary`` wrapper."""
        return mysql_connector.Binary


class PooledMySQLConnectorDatabase(_PooledMySQLDatabase,
                                   MySQLConnectorDatabase):
    """``MySQLConnectorDatabase`` combined with the pooled-MySQL mixin."""
    pass


class MariaDBConnectorDatabase(MySQLDatabase):
    """MySQLDatabase that connects through the ``mariadb`` driver."""

    # Class-level flag. Inside the methods below the bare name ``mariadb``
    # still resolves to the module-level import, not to this attribute.
    mariadb = True

    def _connect(self):
        """Open a new autocommit connection via the ``mariadb`` driver.

        The ``charset``, ``sql_mode`` and ``use_unicode`` entries are removed
        from ``self.connect_params`` (in place) before connecting.

        Raises:
            ImproperlyConfigured: if ``mariadb`` could not be imported.
        """
        if mariadb is None:
            raise ImproperlyConfigured('mariadb connector not installed!')
        self.connect_params.pop('charset', None)
        self.connect_params.pop('sql_mode', None)
        self.connect_params.pop('use_unicode', None)
        return mariadb.connect(db=self.database, autocommit=True,
                               **self.connect_params)

    def cursor(self, named_cursor=None):
        """Return a buffered cursor on the current connection.

        If the database is closed it is opened first when ``autoconnect`` is
        set. ``named_cursor`` is accepted but not used by this method.

        Raises:
            InterfaceError: if the database is closed and ``autoconnect`` is
                disabled.
        """
        if self.is_closed():
            if self.autoconnect:
                self.connect()
            else:
                raise InterfaceError('Error, database connection not opened.')
        return self._state.conn.cursor(buffered=True)

    def _set_server_version(self, conn):
        """Decode ``conn.server_version`` and store it as a 3-tuple.

        The integer is split base-100 from the right, so ``100504`` becomes
        ``(10, 5, 4)``. ``returning_clause`` is switched on for servers at or
        above ``(10, 5, 0)``.
        """
        version = conn.server_version
        version, point = divmod(version, 100)
        version, minor = divmod(version, 100)
        self.server_version = (version, minor, point)
        if self.server_version >= (10, 5, 0):
            self.returning_clause = True

    def last_insert_id(self, cursor, query_type=None):
        """Return the id produced by an INSERT.

        * Without ``returning_clause``: ``cursor.lastrowid``.
        * With it, for ``Insert.SIMPLE`` queries: the first column of the
          first row, falling back to ``cursor.lastrowid`` if that lookup
          raises ``AttributeError`` or ``IndexError``.
        * Otherwise: the cursor itself.
        """
        if not self.returning_clause:
            return cursor.lastrowid
        elif query_type == Insert.SIMPLE:
            try:
                return cursor[0][0]
            except (AttributeError, IndexError):
                return cursor.lastrowid
        return cursor

    def get_binary_type(self):
        """Return the driver's ``Binary`` wrapper."""
        return mariadb.Binary


class PooledMariaDBConnectorDatabase(_PooledMySQLDatabase,
                                     MariaDBConnectorDatabase):
    """``MariaDBConnectorDatabase`` combined with the pooled-MySQL mixin."""
    pass


def Match(columns, expr, modifier=None):
    """Build a ``MATCH(...) AGAINST(...)`` node list.

    Args:
        columns: a single column/field, or a list/tuple of them.
        expr: the search expression passed to ``AGAINST``.
        modifier: optional modifier text appended after ``expr``. It is
            wrapped in ``SQL()`` and therefore emitted as raw SQL, so it must
            not come from untrusted input.
    """
    if isinstance(columns, (list, tuple)):
        match = fn.MATCH(*columns)  # Tuple of one or more columns / fields.
    else:
        match = fn.MATCH(columns)  # Single column / field.
    args = expr if modifier is None else NodeList((expr, SQL(modifier)))
    return NodeList((match, fn.AGAINST(args)))


class _MySQLJSONPath(JSONPath):
    """JSON path node that adds a ``JSON_OVERLAPS``-based ``contains_any``."""

    def contains_any(self, value):
        """Return ``JSON_OVERLAPS(json_extract(field, path), value) = 1``."""
        f = self._field
        lhs = fn.json_extract(f, f._helper._path(self._keys))
        # The value is serialized by the field here, so the default converter
        # is disabled.
        rhs = Value(f._dumps(value), converter=False)
        return Expression(fn.JSON_OVERLAPS(lhs, rhs), OP.EQ, 1)


class MySQLJSONField(JSONField):
    """JSONField whose path lookups produce ``_MySQLJSONPath`` nodes."""

    def contains_any(self, value):
        """Return ``JSON_OVERLAPS(field, value) = 1``."""
        # The value is serialized by the field here, so the default converter
        # is disabled.
        rhs = Value(self._dumps(value), converter=False)
        return Expression(fn.JSON_OVERLAPS(self, rhs), OP.EQ, 1)

    def __getitem__(self, key):
        """Return a path node for the single key ``key``."""
        return _MySQLJSONPath(self, (key,))

    def path(self, *keys):
        """Return a path node for the given sequence of keys."""
        return _MySQLJSONPath(self, keys)