Hi Dom!
Great to hear you like the look of the forum - we’re continuing to work on it so all feedback is very welcome!
Thanks for your questions - let me tackle these one by one:
- Would it be possible to do this already in the Bitfount platform?
Unfortunately not at the moment. We do have the notion of weighted updates ‘under the hood’, but we don’t yet expose it to users. We’ve filed a ticket on our end to look at exposing it so that it’s easier for users to experiment with. In the meantime, we do have a work-around if you’d like to try it right now - I’ll DM you the code for that.
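To give a feel for what ‘weighted updates’ means here, this is a minimal pure-Python sketch of weighting each pod’s parameter update by, say, its dataset size. This is purely illustrative - it is not the Bitfount internal implementation, and the function name is hypothetical:

```python
def weighted_average(updates, weights):
    """Average parameter vectors, weighting each pod's update.

    updates: list of parameter lists (one per pod)
    weights: list of non-negative weights (e.g. pod dataset sizes)
    """
    total = sum(weights)
    n_params = len(updates[0])
    return [
        sum(w * u[i] for u, w in zip(updates, weights)) / total
        for i in range(n_params)
    ]

# Two pods: the second holds three times as much data, so its update
# counts three times as much towards the global parameters.
averaged = weighted_average([[1.0, 2.0], [5.0, 6.0]], weights=[1, 3])
print(averaged)  # [4.0, 5.0]
```

With equal weights this reduces to plain Federated Averaging; unequal weights let pods with more (or more trusted) data pull the global model further towards their update.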
- Would it also be advised to update the parameters on a per-batch rather than per-epoch basis?
We generally advise updating parameters per epoch, as this tends to lead to faster convergence and shorter training times. It is also the default setting in the system.
However, addressing the question more broadly, there are cases where you may want to experiment with per-batch synchronisation - for example, where you want to more closely match the training process of ‘regular’ centralised training, such as when comparing federated and central training baselines. Bitfount therefore supports enabling per-batch synchronisation. Enabling this setting lets you carry out ‘regular’ batch SGD, but with the gradients communicated between pod and modeller through the platform, rather than computed centrally on a single CPU/GPU machine. As you may already guess, though, synchronising after every batch introduces a lot of communication overhead, which slows training down. If you’d like to try this out, you would configure the FederatedAveraging protocol to set the number of steps (i.e. batches) between parameter updates, and run it as follows:
```python
# Assumes the relevant classes are importable from the bitfount package,
# and that `model` and `pod_identifiers` are already defined.
from bitfount import FederatedAveraging, FederatedModelTraining

protocol = FederatedAveraging(
    algorithm=FederatedModelTraining(model=model),
    # Synchronise parameters after every batch (step)
    steps_between_parameter_updates=1,
)
protocol.run(pod_identifiers=pod_identifiers)
```
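To make the communication overhead concrete, here is a quick back-of-envelope count of synchronisation rounds under each setting. The numbers are hypothetical, chosen purely for illustration:

```python
# Illustrative numbers only - compare how many pod<->modeller
# communication rounds each synchronisation setting requires.
n_examples = 10_000
batch_size = 32
n_epochs = 10

# Ceiling division: batches (steps) needed to cover the dataset once.
steps_per_epoch = -(-n_examples // batch_size)  # 313

rounds_per_epoch_sync = n_epochs                     # 10 rounds
rounds_per_batch_sync = n_epochs * steps_per_epoch   # 3130 rounds

print(rounds_per_epoch_sync, rounds_per_batch_sync)  # 10 3130
```

So for this (made-up) configuration, per-batch synchronisation sends over 300x as many updates across the network as per-epoch synchronisation, for the same number of passes over the data.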
In support of our recommendation to synchronise every epoch: it has generally been found that Federated Averaging converges faster than Federated SGD over batches, purely from the perspective of the learning problem and before even accounting for the communication overhead (which, if included, would make synchronising less often even more advisable). See, for example, Table 2 in the work by McMahan et al.:
[Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/pdf/1602.05629.pdf)
I hope that all makes sense, and please do feel free to continue the discussion here if you have any further thoughts/questions!