If your use case is caching the results of computationally intensive functions in your pytest test suites, note that pytest already ships with a file-based cache (exposed as `config.cache`). See the pytest docs for more info.
That said, I had a few extra requirements:
- I wanted to be able to call the cached function directly in the test instead of from a fixture
- I wanted to cache complex Python objects, not just simple Python primitives/containers (the built-in cache only stores JSON-serializable values)
- I wanted an implementation that could refresh the cache intelligently (or be forced to invalidate only a single key)
Thus I came up with my own wrapper for the pytest cache, which you
can find below. The implementation is fully documented, but if you
need more info let me know and I'll be happy to edit this answer :)
Enjoy:
from base64 import b64encode, b64decode
import hashlib
import inspect
import pickle
from typing import Any, Optional
import pytest
__all__ = ['cached']
@pytest.fixture
def cached(request):
    """Fixture returning a helper that caches call results across test runs.

    Results are stored through pytest's cache (``request.config.cache``) as
    base64-encoded pickles, so arbitrary picklable objects can be cached.
    """

    def _cached(
        func: callable,
        *args,
        _invalidate_cache: bool = False,
        _refresh_key: Optional[Any] = None,
        **kwargs
    ):
        """Caches the result of func(*args, **kwargs) cross-testrun.

        Cache invalidation can be performed by passing _invalidate_cache=True or a _refresh_key can
        be passed for improved control on invalidation policy.

        For example, given a function that executes a side effect such as querying a database:

            result = query(sql)

        can be cached as follows:

            refresh_key = query(sql=fast_refresh_sql)
            result = cached(query, sql=slow_or_expensive_sql, _refresh_key=refresh_key)

        or can be directly invalidated if you are doing rapid iteration of your test:

            result = cached(query, sql=sql, _invalidate_cache=True)

        Args:
            func (callable): Callable that will be called
            _invalidate_cache (bool, optional): Whether or not to invalidate_cache. Defaults to False.
            _refresh_key (Optional[Any], optional): Refresh key to provide a programmatic way to
                invalidate cache. Must be picklable. Defaults to None.
            *args: Positional args to pass to func
            **kwargs: Keyword args to pass to func

        Returns:
            Any: The result of func(*args, **kwargs), either freshly computed (and then
            stored in the cache) or unpickled from the cache on a hit.
        """
        # Debug info is best-effort only: failing to resolve the function name
        # or the call site must never prevent the call itself.
        # see https://stackoverflow.com/a/24439444/4442749
        try:
            func_name = getattr(func, '__name__', repr(func))
        except Exception:
            func_name = '<function>'
        try:
            caller = inspect.getframeinfo(inspect.stack()[1][0])
            call_site = "%s:%d" % (caller.filename, caller.lineno)
        except Exception:
            # Fall back to a placeholder location; previously this branch
            # overwrote func_name and left the call site undefined.
            call_site = '<file>:<lineno>'

        call_key = _create_call_key(func, None, *args, **kwargs)
        cached_value = request.config.cache.get(call_key, {"refresh_key": None, "value": None})
        value = cached_value["value"]

        # The refresh key is stored as text so it can be compared across runs.
        current_refresh_key = str(b64encode(pickle.dumps(_refresh_key)), encoding='utf8')
        cached_refresh_key = cached_value.get("refresh_key")

        if (
            _invalidate_cache  # force invalidate
            or cached_refresh_key is None  # first time caching this call
            or current_refresh_key != cached_refresh_key  # refresh_key has changed
        ):
            print("Cache invalidated for '%s' @ %s" % (func_name, call_site))
            result = func(*args, **kwargs)
            value = str(b64encode(pickle.dumps(result)), encoding='utf8')
            request.config.cache.set(
                key=call_key,
                value={
                    "refresh_key": current_refresh_key,
                    "value": value
                }
            )
        else:
            print("Cache hit for '%s' @ %s" % (func_name, call_site))
            # NOTE: unpickling can execute arbitrary code; this assumes the
            # pytest cache directory is trusted and not writable by others.
            result = pickle.loads(b64decode(bytes(value, encoding='utf8')))
        return result

    return _cached
_args_marker = object()
_kwargs_marker = object()
def _create_call_key(func: callable, refresh_key: Any, *args, **kwargs):
"""Produces a hex hash str of the call func(*args, **kwargs)"""
# producing a key from func + args
# see https://stackoverflow.com/a/10220908/4442749
call_key = pickle.dumps(
(func, refresh_key) +
(_args_marker, ) +
tuple(args) +
(_kwargs_marker,) +
tuple(sorted(kwargs.items()))
)
# create a hex digest of the key for the filename
m = hashlib.sha256()
m.update(bytes(call_key))
return m.digest().hex()