Any tff.Computation (like next) will always run the entire specified computation. If your tff.templates.IterativeProcess is, for example, the result of tff.learning.build_federated_averaging_process, its next function will represent one round of the federated averaging algorithm.
The federated averaging algorithm runs training for a fixed number of epochs (let's say 1 for simplicity) over each local dataset, and averages the model updates in a data-weighted manner at the server in order to complete a round--see Algorithm 1 in the original federated averaging paper for a specification of the algorithm.
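To make the data-weighted averaging step concrete, here is a tiny numpy sketch (the example counts and weight vectors are made-up placeholders, not TFF code):

import numpy as np

# Illustrative data-weighted average: n_k is the number of examples held by
# client k, w_k is that client's locally trained weights.
client_updates = [(20, np.array([0.1, 0.2])), (80, np.array([0.3, 0.4]))]
total_examples = sum(n for n, _ in client_updates)
new_global_weights = sum((n / total_examples) * w for n, w in client_updates)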
Now, for how TFF represents and executes this algorithm. From the documentation for build_federated_averaging_process, the next function has the type signature:
(<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>)
TFF's type system represents a dataset as a tff.SequenceType (this is the meaning of the * above), so the second element in the parameter of the type signature represents a set (technically a multiset) of datasets with elements of type B, placed at the clients.
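You can check this signature on your own process directly, since next is itself a tff.Computation and carries its type; a minimal sketch, assuming iterative_process is the object returned by the builder:

# Assumes iterative_process was built elsewhere, e.g. via
# tff.learning.build_federated_averaging_process(model_fn, ...).
print(iterative_process.next.type_signature)
# Prints the signature above, with S, B and T expanded into the concrete
# server-state, batch and metrics types of your model.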
What this means in your example is as follows. You have a list of tf.data.Datasets, each of which represents the local data of a single client--you can think of the list as representing the federated placement. In this context, TFF executing the entire specified computation means: TFF will treat every item in the list as a client to be trained on in this round. In terms of the algorithm linked above, your list of datasets represents the set S_t.
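For instance, such a list might be built along these lines (the random arrays below are placeholders standing in for real client data):

import numpy as np
import tensorflow as tf

# Hypothetical raw data: one (features, labels) pair per client.
client_data = [
    (np.random.rand(20, 784).astype(np.float32),
     np.random.randint(0, 10, size=20))
    for _ in range(100)
]

# One tf.data.Dataset per client; this list plays the role of S_t above.
ls = [
    tf.data.Dataset.from_tensor_slices({'x': x, 'y': y}).batch(10)
    for x, y in client_data
]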
TFF will faithfully execute one round of the federated averaging algorithm, with the Dataset elements of your list representing the clients selected for this round. Training will run for a single epoch on each client (in parallel); if the datasets have different amounts of data, you are correct that training on each client is likely to finish at different times. However, this is the correct semantics of a single round of the federated averaging algorithm, as opposed to a parameterization of a similar algorithm such as Reptile, which runs for a fixed number of steps on each client.
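Relatedly, if you want each client to take more than one local pass over its data per round, that is usually expressed in the dataset preprocessing rather than in the algorithm itself; a minimal sketch, where NUM_EPOCHS and BATCH_SIZE are illustrative names:

import tensorflow as tf

NUM_EPOCHS = 1   # illustrative; >1 means more local passes per client per round
BATCH_SIZE = 10  # illustrative

def preprocess(dataset):
  # The number of local epochs a client runs in one round is governed by how
  # many times its dataset is repeated before it is passed to next().
  return dataset.repeat(NUM_EPOCHS).batch(BATCH_SIZE)

example_ds = tf.data.Dataset.from_tensor_slices({'x': [[0.0], [1.0]], 'y': [0, 1]})
preprocessed_ds = preprocess(example_ds)  # would be applied to each client dataset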
If you wish to select a subset of clients to run a round of training on, this should be done in Python, before calling into TFF, e.g.:
import random

# ls is the list of per-client tf.data.Datasets; N_CLIENTS is how many
# clients to sample for this round.
state = iterative_process.initialize()
sampled_clients = random.sample(ls, N_CLIENTS)
state = iterative_process.next(state, sampled_clients)
Generally, you can think of the Python runtime as an "experiment driver" layer--any selection of clients, for example, should happen at this layer. See the beginning of this answer for further detail on this.
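Putting the pieces together, a typical driver loop re-samples the client list at every round, entirely in Python; a sketch, with NUM_ROUNDS and N_CLIENTS as illustrative names:

import random

NUM_ROUNDS = 10  # illustrative
N_CLIENTS = 10   # illustrative

state = iterative_process.initialize()
for round_num in range(NUM_ROUNDS):
  # Client selection happens here, at the Python "experiment driver" layer.
  sampled_clients = random.sample(ls, N_CLIENTS)
  state, metrics = iterative_process.next(state, sampled_clients)
  print('round {}, metrics: {}'.format(round_num, metrics))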