Source code for flask_resty.filtering

import copy
import functools

import flask
import marshmallow
import sqlalchemy as sa
from marshmallow import ValidationError
from sqlalchemy import sql

from .exceptions import ApiError

# -----------------------------------------------------------------------------

# Field.missing is deprecated in favor of Field.load_default in marshmallow 3.13.0
# This flag selects which of the two attributes the filters below read and
# write, based on the installed marshmallow version.
_USE_LOAD_DEFAULT = marshmallow.__version_info__ >= (3, 13)


class ArgFilterBase:
    """Abstract description of a filter driven by one query argument.

    Subclasses are expected to override both :py:meth:`maybe_set_arg_name`
    and :py:meth:`filter_query`; the versions defined here only raise.
    """

    def maybe_set_arg_name(self, arg_name):
        """Bind this filter to the name of its query argument.

        :param str arg_name: The name of the field to filter against.
        :raises: :py:class:`NotImplementedError` unless overridden.
        """
        raise NotImplementedError

    def filter_query(self, query, view, arg_value):
        """Apply this filter to a query.

        :param query: The query to filter.
        :type query: :py:class:`sqlalchemy.orm.query.Query`
        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        :param str arg_value: The filter specification
        :return: The filtered query
        :rtype: :py:class:`sqlalchemy.orm.query.Query`
        :raises: :py:class:`NotImplementedError` unless overridden.
        """
        raise NotImplementedError
# -----------------------------------------------------------------------------
class FieldFilterBase(ArgFilterBase):
    """A filter that uses a marshmallow field to deserialize its value.

    Implementing classes must provide :py:meth:`get_field` and
    :py:meth:`get_filter_clause`.

    :param str separator: Character that separates individual elements in the
        query value.
    :param bool allow_empty: If set, allow filtering for empty values;
        otherwise, filter out all items on an empty value.
    :param bool skip_invalid: If set, ignore invalid filter values instead of
        throwing an API error.
    """

    def __init__(
        self, *, separator=",", allow_empty=False, skip_invalid=False
    ):
        self._separator = separator
        self._allow_empty = allow_empty
        self._skip_invalid = skip_invalid

    def maybe_set_arg_name(self, arg_name):
        """Do nothing; this base class does not use the argument name.

        :param str arg_name: The name of the query argument (ignored).
        """

    def filter_query(self, query, view, arg_value):
        """Filter the query using the clause built from `arg_value`.

        :param query: The query to filter.
        :type query: :py:class:`sqlalchemy.orm.query.Query`
        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        :param str arg_value: The raw query argument value, or `None` if the
            argument was not supplied.
        :return: The query unchanged when there is no clause to apply,
            otherwise the filtered query.
        """
        filter_clause = self.get_filter(view, arg_value)
        if filter_clause is None:
            return query

        return query.filter(filter_clause)

    def get_filter(self, view, arg_value):
        """Build the filter clause for a raw argument value.

        :param view: The view with the model we wish to filter for.
        :param arg_value: The raw query argument value, or `None`.
        :return: A filter clause, or `None` if no filter applies.
        """
        if arg_value is None:
            return self.get_default_filter(view)

        # An empty value matches nothing unless empty values are allowed.
        if not arg_value and not self._allow_empty:
            return sql.false()

        if not self._separator or self._separator not in arg_value:
            return self.get_element_filter(view, arg_value)

        # The separator is present, so the split yields at least two
        # elements. Unpack them: or_() takes its clauses as positional
        # arguments.
        return sa.or_(
            *(
                self.get_element_filter(view, value_raw)
                for value_raw in arg_value.split(self._separator)
            )
        )

    def get_default_filter(self, view):
        """Build the filter clause to use when the argument is absent.

        :param view: The view with the model we wish to filter for.
        :return: A filter clause for the field's default value, or `None` if
            the field has no default.
        :raises: :py:class:`ApiError` with status 400 if the field is
            required.
        """
        field = self.get_field(view)
        if field.required:
            raise ApiError(400, {"code": "invalid_filter.missing"})

        load_default = (
            field.load_default if _USE_LOAD_DEFAULT else field.missing
        )
        value = load_default() if callable(load_default) else load_default
        if value is marshmallow.missing:
            return None

        return self.get_element_filter(view, value)

    def get_element_filter(self, view, value):
        """Deserialize a single element and build its filter clause.

        :param view: The view with the model we wish to filter for.
        :param value: The raw element value.
        :return: The clause from :py:meth:`get_filter_clause`, or a false
            clause if the value is invalid and `skip_invalid` is set.
        :raises: :py:class:`ApiError` with status 400 if the value is invalid
            and `skip_invalid` is not set.
        """
        field = self.get_field(view)

        try:
            value = self.deserialize(field, value)
        except ValidationError as e:
            if self._skip_invalid:
                return sql.false()

            raise ApiError.from_validation_error(
                400, e, self.format_validation_error
            ) from e

        return self.get_filter_clause(view, value)

    def deserialize(self, field, value_raw):
        """Overridable hook for deserializing a value.

        :param field: The marshmallow field.
        :type field: :py:class:`marshmallow.fields.Field`
        :param value_raw: The value to deserialize.
        :return: The deserialized value.
        """
        return field.deserialize(value_raw)

    def format_validation_error(self, message, path):
        """Format one validation message as an error dict.

        :param message: The validation message, returned under `detail`.
        :param path: Unused by this implementation.
        :return: A dict with the `invalid_filter` code and the message.
        """
        return {"code": "invalid_filter", "detail": message}

    def get_field(self, view):
        """Get the marshmallow field for deserializing filter values.

        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        :raises: :py:class:`NotImplementedError` if no implementation is
            provided.
        """
        raise NotImplementedError()

    def get_filter_clause(self, view, value):
        """Build the filter clause for the deserialized value.

        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        :param str value: The right-hand side of the WHERE clause.
        :raises: :py:class:`NotImplementedError` if no implementation is
            provided.
        """
        raise NotImplementedError()
class ColumnFilter(FieldFilterBase):
    """A filter that operates on the value of a database column.

    The schema field for the column is used to deserialize query argument
    values. `ColumnFilter` cannot normally be used for columns that do not
    appear on the schema, but such columns can be added to the schema with
    fields that have both `load_only` and `dump_only` set.

    :param str column_name: The name of the column to filter against.
    :param func operator: A callable that returns the filter expression given
        the column and the filter value.
    :param bool required: If set, fail if this filter is not specified.
    :param missing: Default value installed on the copied field, used when
        the filter is not specified.
    :param bool validate: If unset, bypass validation on the field. This is
        useful if the field specifies validation rule for inputs that are not
        relevant for filters.
    :param dict kwargs: Passed to :py:class:`FieldFilterBase`.
    """

    def __init__(
        self,
        column_name=None,
        operator=None,
        *,
        required=False,
        missing=marshmallow.missing,
        validate=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Support `ColumnFilter(operator)`: a lone callable in the first
        # position is the operator, not a column name.
        if operator is None and callable(column_name):
            column_name, operator = None, column_name

        if not operator:
            raise TypeError("must specify operator")

        self._column_name = column_name
        self._has_explicit_column_name = column_name is not None
        self._operator = operator

        self._required = required
        self._missing = missing
        self._validate = validate

        # Cache of adjusted field copies, keyed by the schema's own field.
        self._fields = {}

    def maybe_set_arg_name(self, arg_name):
        """Use `arg_name` as the column name unless one was given explicitly.

        :param str arg_name: The name of the column to filter against.
        :raises: :py:class:`TypeError` if a different name was already
            adopted from an earlier call.
        """
        if not self._has_explicit_column_name:
            already_bound = self._column_name
            if already_bound and already_bound != arg_name:
                raise TypeError(
                    "cannot use ColumnFilter without explicit column name for multiple arg names"
                )

            self._column_name = arg_name

    def get_field(self, view):
        """Construct the marshmallow field for deserializing filter values.

        The field is looked up on the view's deserializer and copied, with
        the copy carrying this filter's own handling of missing values.

        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        """
        base_field = view.deserializer.fields[self._column_name]

        if base_field not in self._fields:
            # The original field's default value handling is only relevant
            # for object deserialization, so override it on a private copy.
            filter_field = copy.deepcopy(base_field)
            filter_field.required = self._required

            if _USE_LOAD_DEFAULT:
                filter_field.load_default = self._missing
            else:
                filter_field.missing = self._missing

            self._fields[base_field] = filter_field

        return self._fields[base_field]

    def get_filter_clause(self, view, value):
        """Apply the operator to the model column and the filter value.

        :param view: The view with the model we wish to filter for.
        :param value: The deserialized filter value.
        :return: Whatever the operator callable returns.
        """
        return self._operator(getattr(view.model, self._column_name), value)

    def deserialize(self, field, value_raw):
        """Deserialize `value_raw`, optionally skipping validation.

        :param field: The marshmallow field.
        :type field: :py:class:`marshmallow.fields.Field`
        :param value_raw: The value to deserialize.
        :return: The deserialized value.
        """
        if self._validate:
            return super().deserialize(field, value_raw)

        # Filters may not want the validation that applies to model fields.
        # Calling the private hook directly bypasses the handling of missing
        # and None values and skips the validation check.
        return field._deserialize(value_raw, None, None)
class ModelFilter(FieldFilterBase):
    """An arbitrary filter against the model.

    :param field: A marshmallow field for deserializing filter values.
    :type field: :py:class:`marshmallow.fields.Field`
    :param filter: A callable that returns the filter expression given the
        model and the filter value.
    :param dict kwargs: Passed to :py:class:`FieldFilterBase`.
    """

    def __init__(self, field, filter, **kwargs):
        super().__init__(**kwargs)

        self._field, self._filter = field, filter

    def get_field(self, view):
        """Return the field supplied at construction; `view` is not used."""
        return self._field

    def get_filter_clause(self, view, value):
        """Call the filter callable with the view's model and the value."""
        model = view.model
        return self._filter(model, value)
# -----------------------------------------------------------------------------
def model_filter(field, **kwargs):
    """A convenience decorator for building a `ModelFilter`.

    Decorating a named function turns it into a `ModelFilter` that uses the
    function as its filter callable::

        @model_filter(fields.String(required=True))
        def filter_color(model, value):
            return model.color == value

    :param field: A marshmallow field for deserializing filter values.
    :type field: :py:class:`marshmallow.fields.Field`
    :param dict kwargs: Passed to :py:class:`ModelFilter`.
    """

    def wrapper(func):
        # Copy the function's metadata (name, docstring, ...) onto the filter
        # object that takes its place.
        return functools.update_wrapper(
            ModelFilter(field, func, **kwargs), func
        )

    return wrapper
# -----------------------------------------------------------------------------
class Filtering:
    """Container for the arg filters on a :py:class:`ModelView`.

    :param dict kwargs: A mapping from filter field names to filters.
    """

    def __init__(self, **kwargs):
        arg_filters = {}
        for arg_name, arg_filter in kwargs.items():
            arg_filters[arg_name] = self.make_arg_filter(arg_name, arg_filter)

        self._arg_filters = arg_filters

    def make_arg_filter(self, arg_name, arg_filter):
        """Normalize one entry into a filter bound to `arg_name`.

        A bare callable is treated as an operator and wrapped in a
        `ColumnFilter` on the column named `arg_name`.
        """
        resolved = (
            ColumnFilter(arg_name, arg_filter)
            if callable(arg_filter)
            else arg_filter
        )
        resolved.maybe_set_arg_name(arg_name)
        return resolved

    def filter_query(self, query, view):
        """Filter a query using the configured filters and the request args.

        :param query: The query to filter.
        :type query: :py:class:`sqlalchemy.orm.query.Query`
        :param view: The view with the model we wish to filter for.
        :type view: :py:class:`ModelView`
        :return: The filtered query
        :rtype: :py:class:`sqlalchemy.orm.query.Query`
        """
        request_args = flask.request.args

        for arg_name, arg_filter in self._arg_filters.items():
            # Filters receive None when their argument is absent.
            arg_value = request_args.get(arg_name)

            try:
                query = arg_filter.filter_query(query, view, arg_value)
            except ApiError as error:
                # Tag the error with the query parameter that caused it.
                raise error.update({"source": {"parameter": arg_name}})

        return query

    def __or__(self, other):
        """Combine two `Filtering` instances.

        `Filtering` supports view inheritance by implementing the `|`
        operator. For example, `Filtering(foo=..., bar=...) |
        Filtering(baz=...)` will create a new `Filtering` instance with
        filters for each `foo`, `bar` and `baz`. Filters on the right-hand
        side take precedence where each `Filtering` instance has the same key.
        """
        if not isinstance(other, Filtering):
            return NotImplemented

        combined = dict(self._arg_filters)
        combined.update(other._arg_filters)
        return self.__class__(**combined)