How to create a validation dataset?
Hello!
I may need to split each client's training dataset into train and validation parts for grid search (for example, tuning the step sizes of a method). How can this be achieved in the framework?
If your client ids are sufficiently randomized (e.g. the datasets under fedjax.datasets are; we appended a random id to the original client id when creating these datasets), you can use FederatedData.slice to take either the head or the tail of the training dataset for use as a validation set. Please see the end of this section of our dataset tutorial, https://fedjax.readthedocs.io/en/latest/notebooks/dataset_tutorial.html#id1, for details.
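A rough sketch of that client-level split (untested; it assumes FederatedData.slice takes start/stop client-id bounds as bytes, as described in the tutorial, and the boundary id b'2' is just an illustrative cut point):

import fedjax

# EMNIST train split; client ids begin with a random hex string.
train_fd, _ = fedjax.datasets.emnist.load_data()

# Because the ids are randomized, cutting at a fixed id gives a roughly
# random partition of the clients.
tune_fd = train_fd.slice(stop=b'2')    # head of the id range, e.g. for validation
rest_fd = train_fd.slice(start=b'2')   # remaining clients for training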
Sorry, I misread your original question. A ClientDataset can be sliced using the [] operator (only slicing is supported):
import fedjax

# First client's dataset from the EMNIST train split; keep its first 1/10.
_, client_dataset = next(fedjax.datasets.emnist.load_data()[0].clients())
client_dataset[:len(client_dataset) // 10]
This creates a new ClientDataset with the first 1/10 of the original.
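For the grid search use case you would typically keep both pieces, e.g. (names are illustrative):

# First 10% of the client's examples for validation, the rest for training.
num_valid = len(client_dataset) // 10
valid_ds = client_dataset[:num_valid]
train_ds = client_dataset[num_valid:]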
Thanks for the suggestion! That already gives me a way to work around the problem.
But can I create a new fedjax.core.FederatedData instance that would contain all clients' (validation) data? I would find that handier to use.
In most cases you can use preprocess_client for this. Check out https://fedjax.readthedocs.io/en/latest/notebooks/dataset_tutorial.html#preprocessing-at-the-client-level. We can append another function to the current preprocessing that slices the partially processed examples as the final stage of "client level preprocessing", and the same original batch-level preprocessing will still be applied afterwards.
This will require some basic knowledge of the existing preprocessing for your dataset, as well as how examples (raw and preprocessed) are represented in FedJAX (see this part of the tutorial). The preprocessing of all prepackaged datasets is fairly simple, so looking at the corresponding .py file under fedjax/datasets should give you a fairly good idea (example).
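For example, a rough sketch for EMNIST (untested; it assumes preprocess_client appends a stage receiving (client_id, examples), where examples is a dict of numpy feature arrays, as in the tutorial):

import fedjax

train_fd, _ = fedjax.datasets.emnist.load_data()

# Keep only the first 10% of each client's examples (validation split).
def keep_first_tenth(client_id, examples):
  del client_id  # unused
  n = len(next(iter(examples.values()))) // 10
  return {k: v[:n] for k, v in examples.items()}

# Keep everything except the first 10% (training split).
def drop_first_tenth(client_id, examples):
  del client_id  # unused
  n = len(next(iter(examples.values()))) // 10
  return {k: v[n:] for k, v in examples.items()}

valid_fd = train_fd.preprocess_client(keep_first_tenth)
tuned_train_fd = train_fd.preprocess_client(drop_first_tenth)

The batch-level preprocessing already attached to the dataset is left untouched, so batches from valid_fd and tuned_train_fd come out in the same format as before.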