There is no mechanism for this built into jax.tree_util. In a way, the question is ill-posed: tree flattening applies to a far more general class of objects than the nested dicts in your example; you can even define pytree flattening for any arbitrary object (see https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees), and it's not clear to me how you'd construct labels for flattened objects in this general case.
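To make the general case concrete, here is a minimal sketch of a custom pytree registered with jax.tree_util.register_pytree_node. The Point class and the point_flatten/point_unflatten helpers are hypothetical, made up for illustration. The flatten function only returns the children plus some auxiliary data, so nothing in the output of tree_flatten records that 1.0 came from x and 2.0 from y:

```python
import jax


class Point:
    """Hypothetical user-defined class, used only to illustrate custom pytrees."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_flatten(pt):
    # Return (children, aux_data); the children carry no names or labels.
    return (pt.x, pt.y), None


def point_unflatten(aux_data, children):
    # Rebuild a Point from its children, in the same order they were flattened.
    return Point(*children)


jax.tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

tree = {'a': Point(1.0, 2.0)}
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(leaves)  # [1.0, 2.0]
```

Any labeling scheme would have to be invented on top of this, since the flattened leaves are just a positional list.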
If you're only concerned with nested dicts and you want to generate these kinds of flattened labels, your best bet is probably to write your own Python code to construct the flattened keys and values; for example, something like this might work:
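p = {'a': {'b': 1.0, 'c': 2.0}}

def flatten(p, label=None):
    """Recursively yield (dotted_label, leaf) pairs from a nested dict."""
    if isinstance(p, dict):
        for k, v in p.items():
            # Extend the label with this key, joining nested keys with '.'
            yield from flatten(v, k if label is None else f"{label}.{k}")
    else:
        # Anything that isn't a dict is treated as a leaf
        yield (label, p)

print(dict(flatten(p)))
# {'a.b': 1.0, 'a.c': 2.0}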
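One caveat if you want to line these labels up with the output of jax.tree_util: JAX flattens dicts in sorted-key order, while the generator above follows insertion order. A minimal variant that visits keys in sorted order (flatten_sorted is a hypothetical name, not a JAX API) keeps the labels aligned with jax.tree_util.tree_leaves, assuming the tree contains only dicts and non-None leaves:

```python
import jax


def flatten_sorted(p, label=None):
    """Like flatten, but visits dict keys in sorted order to match JAX's dict flattening."""
    if isinstance(p, dict):
        for k in sorted(p):
            # Same dotted-label scheme, but in sorted key order
            yield from flatten_sorted(p[k], k if label is None else f"{label}.{k}")
    else:
        yield (label, p)


p = {'a': {'c': 2.0, 'b': 1.0}}  # keys deliberately inserted out of order
labels = [label for label, _ in flatten_sorted(p)]
leaves = jax.tree_util.tree_leaves(p)
print(list(zip(labels, leaves)))  # [('a.b', 1.0), ('a.c', 2.0)]
```

Note that JAX treats None as an empty subtree rather than a leaf, so a None value would get a label here but no entry in tree_leaves, and the two lists would fall out of step.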