Source code for flask_resty.pagination
import base64
from dataclasses import dataclass
from typing import Any, List, Tuple, Union
import flask
import sqlalchemy as sa
from marshmallow import ValidationError, fields
from flask_resty.sorting import FieldOrderings, FieldSortingBase
from flask_resty.view import ModelView
from . import meta
from .exceptions import ApiError
from .utils import if_none
# -----------------------------------------------------------------------------
_NULL = str(sa.null())
class PaginationBase:
"""The base class for pagination components.
Pagination components control the list view fetches individual pages of
data, as opposed to the full collection. They handle limiting the number
of returned records, and fetching additional records after the initial
page.
Subclasses must implement :py:meth:`get_page` to provide the pagination
logic.
"""
def adjust_sort_ordering(
self, view, field_orderings: FieldOrderings
) -> FieldOrderings:
return field_orderings
def get_page(self, query, view):
"""Restrict the specified query to a single page.
:param query: The query to paginate.
:type query: :py:class:`sqlalchemy.orm.query.Query`
:param view: The view with the model we wish to paginate.
:type view: :py:class:`ModelView`
:return: The paginated query
:rtype: :py:class:`sqlalchemy.orm.query.Query`
:raises: A :py:class:`NotImplementedError` if no implementation is
provided.
"""
raise NotImplementedError()
def get_item_meta(self, item, view):
"""Build pagination metadata for a single item.
:param item: An instance of the :py:attr:`ModelView.model`.
:type item: obj
:param view: The view with the :py:attr:`ModelView.model`.
:type view: :py:class:`ModelView`
"""
return None
# -----------------------------------------------------------------------------
class LimitPaginationBase(PaginationBase):
"""The base scheme for pagination components that limit fetched item count.
The pagination metadata will indicate the presence of additional items with
the ``has_next_page`` property.
Subclasses must implement :py:meth:`get_limit` to provide maximum number
of returned items.
"""
def get_page(self, query, view) -> List:
limit = self.get_limit()
if limit is not None:
query = query.limit(limit + 1)
items = query.all()
if limit is not None and len(items) > limit:
has_next_page = True
items = items[:limit]
else:
has_next_page = False
meta.update_response_meta({"has_next_page": has_next_page})
return items
def get_limit(self):
"""Override this method to return the maximum number of returned items.
:rtype: int
"""
raise NotImplementedError()
[docs]class MaxLimitPagination(LimitPaginationBase):
"""Return up to a fixed maximum number of items.
This is not especially useful and is included only for completeness.
:param int max_limit: The maximum number of items to retrieve.
"""
def __init__(self, max_limit):
self._max_limit = max_limit
[docs]class LimitPagination(LimitPaginationBase):
"""A pagination scheme that takes a user-specified limit.
This is not especially useful and is included only for completeness.
This pagination scheme uses the :py:attr:`limit_arg` query parameter to
limit the number of items returned by the query.
If no such limit is explicitly specified, this uses `default_limit`. If
`max_limit` is specified, then the user-specified limit may not exceed
`max_limit`.
:param int default_limit: The default maximum number of items to retrieve,
if the user does not specify an explicit value.
:param int max_limit: The maximum number of items the user is allowed to
request.
"""
#: The name of the query parameter to inspect for the LIMIT value.
limit_arg = "limit"
def __init__(self, default_limit=None, max_limit=None):
self._default_limit = if_none(default_limit, max_limit)
self._max_limit = max_limit
if self._max_limit is not None:
assert (
self._default_limit <= self._max_limit
), "default limit exceeds max limit"
[docs] def get_limit(self):
limit = flask.request.args.get(self.limit_arg)
try:
return self.parse_limit(limit)
except ApiError as e:
raise e.update({"source": {"parameter": self.limit_arg}})
def parse_limit(self, limit):
if limit is None:
return self._default_limit
try:
limit = int(limit)
except ValueError as e:
raise ApiError(400, {"code": "invalid_limit"}) from e
if limit < 0:
raise ApiError(400, {"code": "invalid_limit"})
if self._max_limit is not None:
limit = min(limit, self._max_limit)
return limit
[docs]class LimitOffsetPagination(LimitPagination):
"""A pagination scheme that takes a user-specified limit and offset.
This pagination scheme takes a user-specified limit and offset. It will
retrieve up to the specified number of items, beginning at the specified
offset.
"""
#: The name of the query parameter to inspect for the OFFSET value.
offset_arg = "offset"
[docs] def get_page(self, query, view):
offset = self.get_offset()
query = query.offset(offset)
return super().get_page(query, view)
def get_offset(self):
offset = flask.request.args.get(self.offset_arg)
try:
return self.parse_offset(offset)
except ApiError as e:
raise e.update({"source": {"parameter": self.offset_arg}})
def parse_offset(self, offset):
if offset is None:
return 0
try:
offset = int(offset)
except ValueError as e:
raise ApiError(400, {"code": "invalid_offset"}) from e
if offset < 0:
raise ApiError(400, {"code": "invalid_offset"})
return offset
[docs]class PagePagination(LimitOffsetPagination):
"""A pagination scheme that fetches a particular fixed-size page.
This works similar to `LimitOffsetPagination`. The limit used will always
be the fixed page size. The offset will be page * page_size.
:param int page_size: The fixed number of items per page.
"""
#: The name of the query parameter to inspect for the page value.
page_arg = "page"
def __init__(self, page_size):
super().__init__()
self._page_size = page_size
def get_offset(self):
return self.get_request_page() * self._page_size
def get_request_page(self):
page = flask.request.args.get(self.page_arg)
try:
return self.parse_page(page)
except ApiError as e:
raise e.update({"source": {"parameter": self.page_arg}})
def parse_page(self, page):
if page is None:
return 0
try:
page = int(page)
except ValueError as e:
raise ApiError(400, {"code": "invalid_page"}) from e
if page < 0:
raise ApiError(400, {"code": "invalid_page"})
return page
# -----------------------------------------------------------------------------
Cursor = Tuple[Any, ...]
@dataclass
class CursorInfo:
reversed: bool
cursor: Union[str, None]
cursor_arg: Union[str, None]
limit: Union[str, None]
limit_arg: Union[str, None]
[docs]class CursorPaginationBase(LimitPagination):
"""The base class for pagination schemes that use cursors.
Unlike with offsets that identify items by relative position, cursors
identify position by content. This allows continuous pagination without
concerns about page shear on dynamic collections. This makes cursor-based
pagination especially effective for lists with infinite scroll dynamics,
as offset-based pagination can miss or duplicate items due to inserts or
deletes.
It's also more efficient against the database, as the cursor condition can
be cheaply evaluated as a filter against an index.
:param bool validate: If unset, bypass validation on cursor values. This is
useful if the deserializer field imposes validation that will fail for
on cursor values for items actually present.
"""
#: The name of the query parameter to inspect for the cursor value.
cursor_arg = "cursor"
limit_arg = "limit"
#: the name of the query parameter to inspect for explicit forward pagination
after_arg = "after"
first_arg = "first"
#: the name of the query parameter to inspect for explicit backward pagination
before_arg = "before"
last_arg = "last"
def __init__(self, *args, validate_values=True, **kwargs):
super().__init__(*args, **kwargs)
self._validate_values = validate_values
def try_get_arg(self, arg):
value = flask.request.args.get(arg)
if value is not None:
return (value, arg)
return (None, None)
# There are a number of different cases that this covers in order to be backwards compatible with
def get_cursor_info(self) -> CursorInfo:
cursor = None
cursor_arg = None
limit = None
limit_arg = None
# Unambiguous cases where a cursor is provided.
if self.after_arg in flask.request.args:
reversed = False
cursor, cursor_arg = self.try_get_arg(self.after_arg)
limit, limit_arg = self.try_get_arg(self.first_arg)
elif self.before_arg in flask.request.args:
reversed = True
cursor, cursor_arg = self.try_get_arg(self.before_arg)
limit, limit_arg = self.try_get_arg(self.last_arg)
# Ambiguous cases where limits are provided but not cursors
# Relay sometimes sends both first and after, default to "first"
# in keeping with the cursor precedence
elif self.first_arg in flask.request.args:
reversed = False
limit, limit_arg = self.try_get_arg(self.first_arg)
elif self.last_arg in flask.request.args:
reversed = True
limit, limit_arg = self.try_get_arg(self.last_arg)
# legacy "cursor_arg" config cases always map to after/first
else:
reversed = False
cursor, cursor_arg = self.try_get_arg(self.cursor_arg)
limit, limit_arg = self.try_get_arg(self.limit_arg)
return CursorInfo(reversed, cursor, cursor_arg, limit, limit_arg)
[docs] def get_limit(self):
cursor_info = self.get_cursor_info()
try:
return self.parse_limit(cursor_info.limit)
except ApiError as e:
raise e.update({"source": {"parameter": cursor_info.limit_arg}})
@property
def reversed(self):
return self.get_cursor_info().reversed
[docs] def adjust_sort_ordering(
self, view: ModelView, field_orderings
) -> FieldOrderings:
"""Ensure the query is sorted correctly and get the field orderings.
The implementation of cursor-based pagination in Flask-RESTy requires
that the query be sorted in a fully deterministic manner. The timestamp
columns usually used in sorting do not quite qualify, as two different
rows can have the same timestamp. This method adds the ID fields to the
sorting criterion, then returns the field orderings for use in the
other methods, as in `get_field_orderings` below.
:param view: The view with the model we wish to paginate.
:type view: :py:class:`ModelView`
:return: The field orderings necessary to do cursor pagination deterministically
:rtype: FieldOrderings
"""
# ignore the passed in sort so that it's consistent
# with further calls in get_page
return self.get_field_orderings(view)
def get_field_orderings(self, view: ModelView):
sorting: FieldSortingBase = view.sorting
assert (
sorting is not None
), "sorting must be defined when using cursor pagination"
sorting_field_orderings = sorting.get_request_field_orderings(view)
sorting_ordering_fields = frozenset(
field_name for field_name, _ in sorting_field_orderings
)
sorting_ordering_fields = frozenset(
field_name for field_name, _ in sorting_field_orderings
)
# For convenience, use the ascending setting on the last explicit
# ordering when possible, such that reversing the sort will reverse
# the IDs as well.
if sorting_field_orderings:
last_field_asc = sorting_field_orderings[-1][1]
else:
last_field_asc = True
missing_field_orderings = tuple(
(id_field, last_field_asc)
for id_field in view.id_fields
if id_field not in sorting_ordering_fields
)
field_ordering = sorting_field_orderings + missing_field_orderings
if self.reversed:
field_ordering = tuple(
(field, not order) for field, order in field_ordering
)
return field_ordering
[docs] def get_request_cursor(self, view, field_orderings):
"""Get the cursor value specified in the request.
Given the view and the field_orderings as above, this method will read
the encoded cursor from the query, then return the cursor as a tuple of
the field values in the cursor.
This parsed cursor can then be used in `get_filter`.
:param view: The view with the model we wish to paginate.
:type view: :py:class:`ModelView`
:param field_orderings: A sequence of field_ordering tuples
:type field_orderings: seq
:return: A cursor value
:rtype: str
:raises: :py:class:`ApiError` if an invalid cursor is provided in
`cursor_arg`.
"""
cursor_info = self.get_cursor_info()
if cursor_info.cursor is None:
return None
try:
return self.parse_cursor(view, cursor_info.cursor, field_orderings)
except ApiError as e:
raise e.update({"source": {"parameter": cursor_info.cursor_arg}})
def parse_cursor(
self,
view: ModelView,
cursor: str,
field_orderings: FieldOrderings,
) -> Cursor:
cursor = self.decode_cursor(cursor)
if len(cursor) != len(field_orderings):
raise ApiError(400, {"code": "invalid_cursor.length"})
deserializer = view.deserializer
column_fields = (
deserializer.fields[field_name]
for field_name, _ in field_orderings
)
try:
cursor = tuple(
self.deserialize_value(field, value)
for field, value in zip(column_fields, cursor)
)
except ValidationError as e:
raise ApiError.from_validation_error(
400, e, self.format_validation_error
) from e
return cursor
def decode_cursor(self, cursor: str) -> Tuple[str, ...]:
try:
cursor = cursor.split(".")
cursor = tuple(self.decode_value(value) for value in cursor)
except (TypeError, ValueError) as e:
raise ApiError(400, {"code": "invalid_cursor.encoding"}) from e
return cursor
def decode_value(self, value: str):
value = value.encode("ascii")
value += (3 - ((len(value) + 3) % 4)) * b"=" # Add back padding.
value = base64.urlsafe_b64decode(value).decode()
return None if value == _NULL else value
def deserialize_value(self, field, value):
if value is None:
return None
return (
field.deserialize(value)
if self._validate_values
# Cursors don't need to be fully valid values; they just need to be
# the correct type for sorting, so it can make sense to bypass
# validation.
else field._deserialize(value, None, None)
)
def format_validation_error(self, message, path):
return {"code": "invalid_cursor", "detail": message}
[docs] def get_filter(
self, view, field_orderings: FieldOrderings, cursor: Cursor
):
"""Build the filter clause corresponding to a cursor.
Given the field orderings and the cursor as above, this will construct
a filter clause that can be used to filter a query to return only items
after the specified cursor, per the specified field orderings. Use this
to apply the equivalent of the offset specified by the cursor.
:param view: The view with the model we wish to paginate.
:type view: :py:class:`ModelView`
:param field_orderings: A sequence of field_ordering tuples derived from the view's Sorting with explicit id ordering
:type field_orderings: seq
:param cursor: A set of values corresponding to the fields in
`field_orderings`
:type cursor: seq
:return: A filter clause
"""
sorting: FieldSortingBase = view.sorting
column_cursors = tuple(
(sorting.get_column(view, field_name), asc, value)
for (field_name, asc), value in zip(field_orderings, cursor)
)
return sa.or_(
self.get_filter_clause(column_cursors[: i + 1])
for i in range(len(column_cursors))
)
@staticmethod
def get_previous_clause(column_cursors):
if not column_cursors:
return None
clauses = [
column.isnot_distinct_from(value)
for column, _, value in column_cursors
]
return sa.and_(*clauses)
@staticmethod
def _handle_nullable(column, value, is_nullable):
if is_nullable:
return sa.or_(column.is_(None), (column > value))
else:
return column > value
def _prepare_current_clause(self, column, asc, value):
is_nullable = getattr(column.expression, "nullable", True)
# SQL Alchemy won't let you > or < a boolean, so we convert
# to an integer, the DB's seem to handle this just fine
if isinstance(value, bool):
column = sa.cast(column, sa.Integer)
value = int(value)
if asc:
if value is None:
return None
elif value is not None:
current_clause = self._handle_nullable(
column, value, is_nullable
)
else:
current_clause = column > value
else:
if value is None:
current_clause = column.isnot(None)
else:
current_clause = column < value
return current_clause
def get_filter_clause(self, column_cursors):
previous_clauses = self.get_previous_clause(column_cursors[:-1])
column, asc, value = column_cursors[-1]
current_clause = self._prepare_current_clause(column, asc, value)
if previous_clauses is None:
return current_clause
return sa.and_(previous_clauses, current_clause)
[docs] def make_cursors(self, items, view, field_orderings):
"""Build a cursor for each of many items.
This method creates a cursor for each item in `items`. It produces the
same cursors as :py:meth:`make_cursor`, but is slightly more efficient
in cases where cursors for multiple items are required.
:param seq items: A sequence of instances of :py:attr:`ApiView.model`
:param view: The view we wish to paginate.
:type view: :py:class:`ModelView`
:param seq field_orderings: A sequence of (field, asc?).
:return: A sequence of :py:class:`marshmallow.Field`.
:rtype: seq
"""
column_fields = self.get_column_fields(view, field_orderings)
return tuple(self.render_cursor(item, column_fields) for item in items)
[docs] def make_cursor(self, item, view, field_orderings):
"""Build a cursor for a given item.
Given an item and the field orderings as above, this builds a cursor
for the item. This cursor encodes the value for each field on the item
per the specified field orderings.
This cursor should be returned in page or item metadata to allow
pagination continuing after the cursor for the item.
:param obj item: An instance :py:attr:`ApiView.model`
:param view: The view we wish to paginate.
:type view: :py:class:`ModelView`
:param seq field_orderings: A sequence of (field, asc?).
:return: A sequence of :py:class:`marshmallow.Field`.
:rtype: seq
"""
column_fields = self.get_column_fields(view, field_orderings)
return self.render_cursor(item, column_fields)
def get_column_fields(self, view, field_orderings):
serializer = view.serializer
return tuple(
serializer.fields[field_name] for field_name, _ in field_orderings
)
def render_cursor(self, item, column_fields):
cursor = tuple(
field._serialize(getattr(item, field.name), field.name, item)
for field in column_fields
)
return self.encode_cursor(cursor)
def encode_cursor(self, cursor):
return ".".join(self.encode_value(value) for value in cursor)
def encode_value(self, value):
if value is None:
value = _NULL
value = str(value)
value = value.encode()
value = base64.urlsafe_b64encode(value)
value = value.rstrip(b"=") # Strip padding.
return value.decode("ascii")
[docs]class RelayCursorPagination(CursorPaginationBase):
"""A pagination scheme that works with the Relay specification.
This pagination scheme assigns a cursor to each retrieved item. The page
metadata will contain an array of cursors, one per item. The item metadata
will include the cursor for the fetched item.
For Relay Cursor Connections Specification, see
https://facebook.github.io/relay/graphql/connections.htm.
"""
def __init__(
self,
*args,
page_info_arg=None,
default_include_page_info=False,
**kwargs,
):
super().__init__(*args, **kwargs)
self._default_include_page_info = default_include_page_info
self.page_info_arg = page_info_arg
def get_page_info(self, query, view, field_orderings, cursor, items):
include_page_info = (
self.deserialize_value(
fields.Boolean(),
flask.request.args.get(
self.page_info_arg, self._default_include_page_info
),
)
if self.page_info_arg
else self._default_include_page_info
)
if not include_page_info:
return {}
total = query.count()
index = 0
if cursor:
filter_clause = self.get_filter(
view,
tuple((field, not order) for field, order in field_orderings),
cursor,
)
index = query.filter(filter_clause).count() + 1
# in the reversed case, both the `order by` and sort are inverted.
# so in practice this gives us a reverse index, e.g. distance from
# the end of the list. We normalize it back by subtracting from the total
if self.reversed:
before_index = total - index
index = max(before_index - len(items), 0)
return {"index": index, "total": total}
[docs] def get_page(self, query, view):
field_orderings = self.get_field_orderings(view)
cursor_in = self.get_request_cursor(view, field_orderings)
page_query = query
if cursor_in is not None:
page_query = page_query.filter(
self.get_filter(view, field_orderings, cursor_in)
)
items = super().get_page(page_query, view)
if self.reversed:
items.reverse()
# Relay expects a cursor for each item.
cursors_out = self.make_cursors(items, view, field_orderings)
page_info = self.get_page_info(
query, view, field_orderings, cursor_in, items
)
meta.update_response_meta({"cursors": cursors_out, **page_info})
return items
[docs] def get_item_meta(self, item, view):
cursor = self.make_cursor(item, view, self.get_field_orderings(view))
return {"cursor": cursor}