import logging
import warnings
import threading
from banal import ensure_list
from sqlalchemy import func, select, false
from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError
from dataset.types import Types, MYSQL_LENGTH_TYPES
from dataset.util import index_name
from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns
from dataset.util import normalize_column_name, normalize_column_key
log = logging.getLogger(__name__)
class Table(object):
"""Represents a table in a database and exposes common operations."""
PRIMARY_DEFAULT = "id"
def __init__(
self,
database,
table_name,
primary_id=None,
primary_type=None,
primary_increment=None,
auto_create=False,
):
"""Initialise the table from database schema."""
self.db = database
self.name = normalize_table_name(table_name)
self._table = None
self._columns = None
self._indexes = []
self._primary_id = (
primary_id if primary_id is not None else self.PRIMARY_DEFAULT
)
self._primary_type = primary_type if primary_type is not None else Types.integer
if primary_increment is None:
primary_increment = self._primary_type in (Types.integer, Types.bigint)
self._primary_increment = primary_increment
self._auto_create = auto_create
@property
def exists(self):
"""Check to see if the table currently exists in the database."""
if self._table is not None:
return True
return self.name in self.db
@property
def table(self):
"""Get a reference to the table, which may be reflected or created."""
if self._table is None:
self._sync_table(())
return self._table
@property
def _column_keys(self):
"""Get a dictionary of all columns and their case mapping."""
if not self.exists:
return {}
with self.db.lock:
if self._columns is None:
# Initialise the table if it doesn't exist
table = self.table
self._columns = {}
for column in table.columns:
name = normalize_column_name(column.name)
key = normalize_column_key(name)
if key in self._columns:
log.warning("Duplicate column: %s", name)
self._columns[key] = name
return self._columns
@property
def columns(self):
"""Get a listing of all columns that exist in the table."""
return list(self._column_keys.values())
    def has_column(self, column):
"""Check if a column with the given name exists on this table."""
key = normalize_column_key(normalize_column_name(column))
return key in self._column_keys
def _get_column_name(self, name):
"""Find the best column name with case-insensitive matching."""
name = normalize_column_name(name)
key = normalize_column_key(name)
return self._column_keys.get(key, name)
    def insert(self, row, ensure=None, types=None):
        """Add a ``row`` dict by inserting it into the table.
        If ``ensure`` is set, any keys of the row that are not table
        columns will be created automatically.
During column creation, ``types`` will be checked for a key
matching the name of a column to be created, and the given
SQLAlchemy column type will be used. Otherwise, the type is
guessed from the row value, defaulting to a simple unicode
field.
::
data = dict(title='I am a banana!')
table.insert(data)
Returns the inserted row's primary key.
"""
row = self._sync_columns(row, ensure, types=types)
res = self.db.executable.execute(self.table.insert(row))
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
return True
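    # A usage sketch for ``insert`` (illustrative, not part of the module;
    # the connection URL and table/column names are assumptions):
    #
    #   import dataset
    #   db = dataset.connect("sqlite:///:memory:")
    #   table = db["people"]
    #   # ``types`` pins a column type instead of guessing from the value.
    #   pk = table.insert({"name": "Dolly", "age": 33},
    #                     types={"age": db.types.integer})
    #   assert table.find_one(id=pk)["name"] == "Dolly"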
    def insert_ignore(self, row, keys, ensure=None, types=None):
        """Add a ``row`` dict into the table if the row does not exist.
        If rows with matching ``keys`` exist, no change is made.
        Setting ``ensure`` results in automatically creating missing
        columns, i.e. keys of the row that are not table columns.
During column creation, ``types`` will be checked for a key
matching the name of a column to be created, and the given
SQLAlchemy column type will be used. Otherwise, the type is
guessed from the row value, defaulting to a simple unicode
field.
::
data = dict(id=10, title='I am a banana!')
table.insert_ignore(data, ['id'])
"""
row = self._sync_columns(row, ensure, types=types)
if self._check_ensure(ensure):
self.create_index(keys)
args, _ = self._keys_to_args(row, keys)
if self.count(**args) == 0:
return self.insert(row, ensure=False)
return False
    def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
        """Add many rows at a time.
        This is significantly faster than adding them one by one. By default
        the rows are processed in chunks of 1000 per commit, unless you
        specify a different ``chunk_size``.
See :py:meth:`insert() <dataset.Table.insert>` for details on
the other parameters.
::
rows = [dict(name='Dolly')] * 10000
table.insert_many(rows)
"""
# Sync table before inputting rows.
sync_row = {}
for row in rows:
# Only get non-existing columns.
sync_keys = list(sync_row.keys())
for key in [k for k in row.keys() if k not in sync_keys]:
# Get a sample of the new column(s) from the row.
sync_row[key] = row[key]
self._sync_columns(sync_row, ensure, types=types)
        # Get the list of column names, used for padding later.
columns = sync_row.keys()
chunk = []
for index, row in enumerate(rows):
chunk.append(row)
# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk)
chunk = []
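    # A usage sketch for ``insert_many`` (illustrative, not part of the
    # module). A smaller ``chunk_size`` trades throughput for memory:
    #
    #   rows = [dict(name="row-%d" % i) for i in range(10000)]
    #   table.insert_many(rows, chunk_size=500)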
    def update(self, row, keys, ensure=None, types=None, return_count=False):
"""Update a row in the table.
The update is managed via the set of column names stated in ``keys``:
they will be used as filters for the data to be updated, using the
values in ``row``.
::
# update all entries with id matching 10, setting their title
# columns
data = dict(id=10, title='I am a banana!')
table.update(data, ['id'])
If keys in ``row`` update columns not present in the table, they will
be created based on the settings of ``ensure`` and ``types``, matching
the behavior of :py:meth:`insert() <dataset.Table.insert>`.
"""
row = self._sync_columns(row, ensure, types=types)
args, row = self._keys_to_args(row, keys)
clause = self._args_to_clause(args)
if not len(row):
return self.count(clause)
stmt = self.table.update(whereclause=clause, values=row)
rp = self.db.executable.execute(stmt)
if rp.supports_sane_rowcount():
return rp.rowcount
if return_count:
return self.count(clause)
    def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
        """Update many rows in the table at a time.
        This is significantly faster than updating them one by one. By default
        the rows are processed in chunks of 1000 per commit, unless you
        specify a different ``chunk_size``.
See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters.
"""
keys = ensure_list(keys)
chunk = []
columns = []
for index, row in enumerate(rows):
columns.extend(
col for col in row.keys() if (col not in columns) and (col not in keys)
)
# bindparam requires names to not conflict (cannot be "id" for id)
for key in keys:
row["_%s" % key] = row[key]
row.pop(key)
chunk.append(row)
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = self.table.update(
whereclause=and_(True, *cl),
values={col: bindparam(col, required=False) for col in columns},
)
self.db.executable.execute(stmt, chunk)
chunk = []
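    # A usage sketch for ``update_many`` (illustrative, not part of the
    # module). Each row carries the ``keys`` columns used for matching,
    # plus the values to set:
    #
    #   rows = [
    #       {"id": 1, "title": "first"},
    #       {"id": 2, "title": "second"},
    #   ]
    #   table.update_many(rows, "id")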
    def upsert(self, row, keys, ensure=None, types=None):
"""An UPSERT is a smart combination of insert and update.
If rows with matching ``keys`` exist they will be updated, otherwise a
new row is inserted in the table.
::
data = dict(id=10, title='I am a banana!')
table.upsert(data, ['id'])
"""
row = self._sync_columns(row, ensure, types=types)
if self._check_ensure(ensure):
self.create_index(keys)
row_count = self.update(row, keys, ensure=False, return_count=True)
if row_count == 0:
return self.insert(row, ensure=False)
return True
    def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
        """
        Sort multiple input rows into upserts and inserts: rows with
        matching ``keys`` are updated, all others are inserted.
        See :py:meth:`upsert() <dataset.Table.upsert>` and
        :py:meth:`insert_many() <dataset.Table.insert_many>`.
        """
        # The bulk implementation was removed in 5e09aba401. Doing this one
        # by one is incredibly slow, but doesn't run into issues with column
        # creation.
for row in rows:
self.upsert(row, keys, ensure=ensure, types=types)
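    # A usage sketch for ``upsert_many`` (illustrative, not part of the
    # module): the first row updates an existing record, the second is
    # inserted if no row with ``id=99`` exists.
    #
    #   table.upsert_many(
    #       [{"id": 1, "title": "new"}, {"id": 99, "title": "created"}],
    #       ["id"],
    #   )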
    def delete(self, *clauses, **filters):
"""Delete rows from the table.
Keyword arguments can be used to add column-based filters. The filter
criterion will always be equality:
::
table.delete(place='Berlin')
If no arguments are given, all records are deleted.
"""
if not self.exists:
return False
clause = self._args_to_clause(filters, clauses=clauses)
stmt = self.table.delete(whereclause=clause)
rp = self.db.executable.execute(stmt)
return rp.rowcount > 0
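    # A usage sketch for ``delete`` (illustrative, not part of the module).
    # The dict-based comparison filters accepted by ``find()`` work here
    # as well:
    #
    #   table.delete(place="Berlin")        # equality filter
    #   table.delete(year={"lt": 1950})     # comparison operator
    #   table.delete()                      # remove all rows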
def _reflect_table(self):
"""Load the tables definition from the database."""
with self.db.lock:
self._columns = None
try:
self._table = SQLATable(
self.name, self.db.metadata, schema=self.db.schema, autoload=True
)
except NoSuchTableError:
self._table = None
def _threading_warn(self):
if self.db.in_transaction and threading.active_count() > 1:
warnings.warn(
"Changing the database schema inside a transaction "
"in a multi-threaded environment is likely to lead "
"to race conditions and synchronization issues.",
RuntimeWarning,
)
def _sync_table(self, columns):
"""Lazy load, create or adapt the table structure in the database."""
if self._table is None:
# Load an existing table from the database.
self._reflect_table()
if self._table is None:
# Create the table with an initial set of columns.
if not self._auto_create:
raise DatasetException("Table does not exist: %s" % self.name)
# Keep the lock scope small because this is run very often.
with self.db.lock:
self._threading_warn()
self._table = SQLATable(
self.name, self.db.metadata, schema=self.db.schema
)
if self._primary_id is not False:
                    # This can go wrong on a DBMS like MySQL or SQLite,
                    # where a table cannot be created without any columns.
column = Column(
self._primary_id,
self._primary_type,
primary_key=True,
autoincrement=self._primary_increment,
)
self._table.append_column(column)
for column in columns:
if not column.name == self._primary_id:
self._table.append_column(column)
self._table.create(self.db.executable, checkfirst=True)
self._columns = None
elif len(columns):
with self.db.lock:
self._reflect_table()
self._threading_warn()
for column in columns:
if not self.has_column(column.name):
self.db.op.add_column(self.name, column, schema=self.db.schema)
self._reflect_table()
def _sync_columns(self, row, ensure, types=None):
"""Create missing columns (or the table) prior to writes.
If automatic schema generation is disabled (``ensure`` is ``False``),
this will remove any keys from the ``row`` for which there is no
matching column.
"""
ensure = self._check_ensure(ensure)
types = types or {}
types = {self._get_column_name(k): v for (k, v) in types.items()}
out = {}
sync_columns = {}
for name, value in row.items():
name = self._get_column_name(name)
if self.has_column(name):
out[name] = value
elif ensure:
_type = types.get(name)
if _type is None:
_type = self.db.types.guess(value)
sync_columns[name] = Column(name, _type)
out[name] = value
self._sync_table(sync_columns.values())
return out
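    # A sketch of how ``ensure`` affects writes (illustrative, not part of
    # the module). With schema generation enabled, unknown keys become
    # columns; with it disabled, they are silently dropped:
    #
    #   table.insert({"name": "a", "score": 1.0})   # creates "score"
    #   db.ensure_schema = False
    #   table.insert({"name": "b", "rank": 2})      # "rank" is dropped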
def _check_ensure(self, ensure):
if ensure is None:
return self.db.ensure_schema
return ensure
def _generate_clause(self, column, op, value):
if op in ("like",):
return self.table.c[column].like(value)
if op in ("ilike",):
return self.table.c[column].ilike(value)
if op in ("notlike",):
return self.table.c[column].notlike(value)
if op in ("notilike",):
return self.table.c[column].notilike(value)
if op in (">", "gt"):
return self.table.c[column] > value
if op in ("<", "lt"):
return self.table.c[column] < value
if op in (">=", "gte"):
return self.table.c[column] >= value
if op in ("<=", "lte"):
return self.table.c[column] <= value
if op in ("=", "==", "is"):
return self.table.c[column] == value
if op in ("!=", "<>", "not"):
return self.table.c[column] != value
if op in ("in",):
return self.table.c[column].in_(value)
if op in ("notin",):
return self.table.c[column].notin_(value)
if op in ("between", ".."):
start, end = value
return self.table.c[column].between(start, end)
if op in ("startswith",):
return self.table.c[column].like(value + "%")
if op in ("endswith",):
return self.table.c[column].like("%" + value)
return false()
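    # A sketch of how these operators surface in queries (illustrative, not
    # part of the module). Dict-valued filters map onto the clauses above:
    #
    #   table.find(year={">=": 1980})
    #   table.find(year={"between": [1980, 1990]})
    #   table.find(name={"ilike": "%doe%"})
    #   table.find(id=[1, 2, 3])    # list values become an IN (...) clause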
def _args_to_clause(self, args, clauses=()):
clauses = list(clauses)
for column, value in args.items():
column = self._get_column_name(column)
if not self.has_column(column):
clauses.append(false())
elif isinstance(value, (list, tuple, set)):
clauses.append(self._generate_clause(column, "in", value))
elif isinstance(value, dict):
for op, op_value in value.items():
clauses.append(self._generate_clause(column, op, op_value))
else:
clauses.append(self._generate_clause(column, "=", value))
return and_(True, *clauses)
def _args_to_order_by(self, order_by):
orderings = []
for ordering in ensure_list(order_by):
if ordering is None:
continue
column = ordering.lstrip("-")
column = self._get_column_name(column)
if not self.has_column(column):
continue
if ordering.startswith("-"):
orderings.append(self.table.c[column].desc())
else:
orderings.append(self.table.c[column].asc())
return orderings
def _keys_to_args(self, row, keys):
keys = [self._get_column_name(k) for k in ensure_list(keys)]
row = row.copy()
args = {k: row.pop(k, None) for k in keys}
return args, row
    def create_column(self, name, type, **kwargs):
        """Create a new column ``name`` of a specified type.
        ::
            table.create_column('created_at', db.types.datetime)
        `type` corresponds to an SQLAlchemy type as described by
        `dataset.db.Types`. Additional keyword arguments are passed
        to the constructor of `Column`, so that default values and
        options like `nullable` and `unique` can be set.
::
table.create_column('key', unique=True, nullable=False)
table.create_column('food', default='banana')
"""
name = self._get_column_name(name)
if self.has_column(name):
log.debug("Column exists: %s" % name)
return
self._sync_table((Column(name, type, **kwargs),))
    def create_column_by_example(self, name, value):
"""
Explicitly create a new column ``name`` with a type that is appropriate
to store the given example ``value``. The type is guessed in the same
way as for the insert method with ``ensure=True``.
::
table.create_column_by_example('length', 4.2)
If a column of the same name already exists, no action is taken, even
if it is not of the type we would have created.
"""
type_ = self.db.types.guess(value)
self.create_column(name, type_)
    def drop_column(self, name):
"""
Drop the column ``name``.
::
table.drop_column('created_at')
"""
if self.db.engine.dialect.name == "sqlite":
raise RuntimeError("SQLite does not support dropping columns.")
name = self._get_column_name(name)
with self.db.lock:
if not self.exists or not self.has_column(name):
log.debug("Column does not exist: %s", name)
return
self._threading_warn()
self.db.op.drop_column(self.table.name, name, schema=self.table.schema)
self._reflect_table()
    def drop(self):
"""Drop the table from the database.
Deletes both the schema and all the contents within it.
"""
with self.db.lock:
if self.exists:
self._threading_warn()
self.table.drop(self.db.executable, checkfirst=True)
self._table = None
self._columns = None
self.db._tables.pop(self.name, None)
    def has_index(self, columns):
"""Check if an index exists to cover the given ``columns``."""
if not self.exists:
return False
columns = set([self._get_column_name(c) for c in ensure_list(columns)])
if columns in self._indexes:
return True
for column in columns:
if not self.has_column(column):
return False
indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema)
for index in indexes:
idx_columns = index.get("column_names", [])
if len(columns.intersection(idx_columns)) == len(columns):
self._indexes.append(columns)
return True
if self.table.primary_key is not None:
pk_columns = [c.name for c in self.table.primary_key.columns]
if len(columns.intersection(pk_columns)) == len(columns):
self._indexes.append(columns)
return True
return False
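    # A usage sketch for ``has_index`` (illustrative, not part of the
    # module). Columns covered by the primary key count as indexed:
    #
    #   if not table.has_index(["year", "country"]):
    #       table.create_index(["year", "country"])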
    def create_index(self, columns, name=None, **kw):
"""Create an index to speed up queries on a table.
If no ``name`` is given a random name is created.
::
table.create_index(['name', 'country'])
"""
columns = [self._get_column_name(c) for c in ensure_list(columns)]
with self.db.lock:
if not self.exists:
raise DatasetException("Table has not been created yet.")
for column in columns:
if not self.has_column(column):
return
if not self.has_index(columns):
self._threading_warn()
name = name or index_name(self.name, columns)
columns = [self.table.c[c] for c in columns]
                # MySQL refuses to index very long text fields unless the
                # index is limited to a key-length prefix. This sets (a
                # somewhat arbitrary) prefix of 10 characters to be captured
                # by the index, beyond which the engine presumably falls
                # back to a more linear scan.
mysql_length = {}
for col in columns:
if isinstance(col.type, MYSQL_LENGTH_TYPES):
mysql_length[col.name] = 10
kw["mysql_length"] = mysql_length
idx = Index(name, *columns, **kw)
idx.create(self.db.executable)
    def find(self, *_clauses, **kwargs):
        """Perform a simple search on the table.
        Simply pass keyword arguments as filters.
::
results = table.find(country='France')
results = table.find(country='France', year=1980)
Using ``_limit``::
# just return the first 10 rows
results = table.find(country='France', _limit=10)
You can sort the results by single or multiple columns. Append a minus
sign to the column name for descending order::
# sort results by a column 'year'
results = table.find(country='France', order_by='year')
# return all rows sorted by multiple columns (descending by year)
results = table.find(order_by=['country', '-year'])
You can also submit filters based on criteria other than equality,
see :ref:`advanced_filters` for details.
To run more complex queries with JOINs, or to perform GROUP BY-style
aggregation, you can also use :py:meth:`db.query() <dataset.Database.query>`
to run raw SQL queries instead.
"""
if not self.exists:
return iter([])
_limit = kwargs.pop("_limit", None)
_offset = kwargs.pop("_offset", 0)
order_by = kwargs.pop("order_by", None)
_streamed = kwargs.pop("_streamed", False)
_step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0:
_step = None
order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
if len(order_by):
query = query.order_by(*order_by)
conn = self.db.executable
if _streamed:
conn = self.db.engine.connect()
conn = conn.execution_options(stream_results=True)
return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step)
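    # A usage sketch for streamed reads (illustrative, not part of the
    # module; ``process`` is a placeholder). ``_streamed`` opens a separate
    # connection and fetches rows in batches of ``_step``:
    #
    #   for row in table.find(country="France", _streamed=True, _step=1000):
    #       process(row)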
    def find_one(self, *args, **kwargs):
"""Get a single result from the table.
Works just like :py:meth:`find() <dataset.Table.find>` but returns one
result, or ``None``.
::
row = table.find_one(country='United States')
"""
if not self.exists:
return None
kwargs["_limit"] = 1
kwargs["_step"] = None
resiter = self.find(*args, **kwargs)
try:
for row in resiter:
return row
finally:
resiter.close()
    def count(self, *_clauses, **kwargs):
"""Return the count of results for the given filter set."""
# NOTE: this does not have support for limit and offset since I can't
# see how this is useful. Still, there might be compatibility issues
# with people using these flags. Let's see how it goes.
if not self.exists:
return 0
args = self._args_to_clause(kwargs, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = query.select_from(self.table)
rp = self.db.executable.execute(query)
return rp.fetchone()[0]
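    # A usage sketch for ``count`` (illustrative, not part of the module).
    # It accepts the same filters as ``find()``:
    #
    #   total = table.count()
    #   recent = table.count(year={">=": 2000})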
    def __len__(self):
"""Return the number of rows in the table."""
return self.count()
    def distinct(self, *args, **_filter):
"""Return all the unique (distinct) values for the given ``columns``.
::
# returns only one row per year, ignoring the rest
table.distinct('year')
# works with multiple columns, too
table.distinct('year', 'country')
# you can also combine this with a filter
table.distinct('year', country='China')
"""
if not self.exists:
return iter([])
columns = []
clauses = []
for column in args:
if isinstance(column, ClauseElement):
clauses.append(column)
else:
if not self.has_column(column):
raise DatasetException("No such column: %s" % column)
columns.append(self.table.c[column])
clause = self._args_to_clause(_filter, clauses=clauses)
if not len(columns):
return iter([])
q = expression.select(
columns,
distinct=True,
whereclause=clause,
order_by=[c.asc() for c in columns],
)
return self.db.query(q)
# Legacy methods for running find queries.
all = find
    def __iter__(self):
"""Return all rows of the table as simple dictionaries.
Allows for iterating over all rows in the table without explicitly
calling :py:meth:`find() <dataset.Table.find>`.
::
for row in table:
print(row)
"""
return self.find()
def __repr__(self):
"""Get table representation."""
return "<Table(%s)>" % self.table.name