MNIST Example: Workflow API
Welcome to the first OpenFL Experimental Workflow Interface tutorial! This notebook introduces the API to get up and running with your first horizontal federated learning workflow. This work has the following goals:
Simplify the federated workflow representation
Help users better understand the steps in federated learning (weight extraction, compression, etc.)
Maintain data privacy by design
Aim for syntactic consistency with the Netflix Metaflow project, and reuse infrastructure where possible
What is it?
The workflow interface is a new way of composing federated learning experiments with OpenFL. It was born out of conversations with researchers and existing users whose novel use cases didn’t quite fit the standard horizontal federated learning paradigm.
Getting Started
First, we install the dependencies required by the workflow interface:
!fx experimental activate
!pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
We begin with the quintessential example of a small PyTorch CNN trained on the MNIST dataset. Let’s start by defining our data loaders, model, optimizer, and some helper functions, just as we would for any other deep learning experiment.
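Both MNIST datasets in the next cell are preprocessed with ToTensor() followed by Normalize((0.1307,), (0.3081,)), where 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training pixels. The plain-Python sketch below shows what that pipeline does to a single raw pixel value; `normalize_pixel` is an illustrative helper, not a torchvision function.

```python
# Illustrative sketch of ToTensor() + Normalize((0.1307,), (0.3081,))
# applied to one raw MNIST pixel (uint8 value in 0..255).
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def normalize_pixel(raw: int) -> float:
    x = raw / 255.0                      # ToTensor(): scale 0..255 into [0, 1]
    return (x - MNIST_MEAN) / MNIST_STD  # Normalize(): (x - mean) / std


print(round(normalize_pixel(0), 4))    # -0.4242 (black background pixel)
print(round(normalize_pixel(255), 4))  # 2.8215 (fully white stroke pixel)
```

After normalization the inputs are roughly zero-centered with unit variance, which tends to make SGD with a fixed learning rate behave more predictably.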
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import numpy as np
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
mnist_train = torchvision.datasets.MNIST(
    "/tmp/files/",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)

mnist_test = torchvision.datasets.MNIST(
    "/tmp/files/",
    train=False,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit dim avoids PyTorch's implicit-dimension deprecation warning
        return F.log_softmax(x, dim=1)


def inference(network, test_loader):
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum the per-sample losses; size_average=False is deprecated
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    accuracy = float(correct / len(test_loader.dataset))
    return accuracy
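The x.view(-1, 320) call in Net.forward only works because the two convolution and pooling stages shrink each 28x28 image to a 20x4x4 feature map. The worked arithmetic below traces that using the standard output-size formula for square, unpadded convolution and pooling windows; `conv2d_out` and `net_flat_features` are illustrative helpers, not PyTorch functions.

```python
# Trace the spatial size of one 28x28 MNIST image through Net to see
# why fc1 expects 320 input features.

def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Side length after a square Conv2d / max-pool window (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def net_flat_features(image_size: int = 28) -> int:
    size = conv2d_out(image_size, kernel=5)      # conv1: 28 -> 24
    size = conv2d_out(size, kernel=2, stride=2)  # max_pool2d(..., 2): 24 -> 12
    size = conv2d_out(size, kernel=5)            # conv2: 12 -> 8
    size = conv2d_out(size, kernel=2, stride=2)  # max_pool2d(..., 2): 8 -> 4
    out_channels = 20                            # conv2 produces 20 channels
    return out_channels * size * size            # 20 * 4 * 4


print(net_flat_features())  # 320
```

If you change the kernel sizes or the input resolution, recomputing this value tells you the new in_features for fc1.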
Next, we import FLSpec, the Aggregator and Collaborator participant classes, LocalRuntime, and the placement decorators.
FLSpec: Defines the flow specification. User-defined flows are subclasses of this class.
Runtime: Defines where the flow runs and the infrastructure for task transitions (how information gets sent). The LocalRuntime runs the flow on a single node.
aggregator/collaborator: Placement decorators that define where each task will be assigned.
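To build intuition for placement decorators, here is a toy version in plain Python: each decorator just tags a method with where it should run, and a runtime could read those tags to dispatch tasks. This is a hypothetical sketch of the idea only; `toy_aggregator`, `toy_collaborator`, and `placements` are invented names and not how OpenFL implements its decorators.

```python
# Toy placement decorators: they only record a placement tag on the function.
# NOT OpenFL's implementation; an illustration of tag-based dispatch.

def toy_aggregator(func):
    func.placement = "aggregator"    # task should run on the aggregator
    return func


def toy_collaborator(func):
    func.placement = "collaborator"  # task should run on each collaborator
    return func


class ToyFlow:
    @toy_aggregator
    def start(self):
        pass

    @toy_collaborator
    def train(self):
        pass


def placements(flow_cls):
    """Map each tagged method name to its recorded placement."""
    return {
        name: attr.placement
        for name, attr in vars(flow_cls).items()
        if callable(attr) and hasattr(attr, "placement")
    }


print(placements(ToyFlow))  # {'start': 'aggregator', 'train': 'collaborator'}
```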
from copy import deepcopy
from openfl.experimental.workflow.interface import FLSpec, Aggregator, Collaborator
from openfl.experimental.workflow.runtime import LocalRuntime
from openfl.experimental.workflow.placement import aggregator, collaborator


def FedAvg(models, weights=None):
    """Average the parameters of `models` (optionally weighted) into models[0]."""
    new_model = models[0]
    state_dicts = [model.state_dict() for model in models]
    state_dict = new_model.state_dict()
    # Iterate over the first model's keys so that a single model also works
    for key in state_dict:
        state_dict[key] = torch.from_numpy(np.average([state[key].numpy() for state in state_dicts],
                                                      axis=0,
                                                      weights=weights))
    new_model.load_state_dict(state_dict)
    return new_model
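FedAvg takes an optional weights argument that it forwards to np.average. The dependency-free sketch below shows the same element-wise weighted averaging on "state dicts" represented as plain dicts of float lists. Weighting each collaborator by its local sample count is a common choice, but that is an assumption here: the flow in this tutorial calls FedAvg without weights, i.e. a simple mean. `weighted_average` and the toy parameter values are illustrative.

```python
# Weighted parameter averaging on plain Python dicts of float lists.

def weighted_average(state_dicts, weights=None):
    if weights is None:
        weights = [1.0] * len(state_dicts)  # no weights: simple mean
    total = float(sum(weights))
    averaged = {}
    for key in state_dicts[0]:
        # Element-wise weighted sum across collaborators, then normalize
        averaged[key] = [
            sum(w * sd[key][i] for sd, w in zip(state_dicts, weights)) / total
            for i in range(len(state_dicts[0][key]))
        ]
    return averaged


collab_a = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
collab_b = {"fc.weight": [3.0, 6.0], "fc.bias": [4.0]}

print(weighted_average([collab_a, collab_b]))
# {'fc.weight': [2.0, 4.0], 'fc.bias': [2.0]}
print(weighted_average([collab_a, collab_b], weights=[3, 1]))
# {'fc.weight': [1.5, 3.0], 'fc.bias': [1.0]}
```

With equal-size data shards, sample-count weights reduce to the simple mean, so the unweighted call is reasonable for evenly split data.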
Now we come to the flow definition. The OpenFL Workflow Interface adopts the conventions set by Metaflow: every workflow begins with the start task and concludes with the end task. The aggregator holds a model and optimizer, which can optionally be passed in when the flow is created, and it begins the flow with the start task. There, the list of collaborators is extracted from the runtime (self.collaborators = self.runtime.collaborators) and used as the list of participants that run the task named in self.next, aggregated_model_validation. The model, the optimizer, and any other attribute not explicitly excluded in the next call are passed from the start task on the aggregator to the aggregated_model_validation task on each collaborator. Where each task runs is determined by the placement decorator that precedes its definition (@aggregator or @collaborator). Once each collaborator (as defined in the runtime) completes the aggregated_model_validation task, it passes its current state on to the train task, then from train to local_model_validation, and finally to join on the aggregator. In join, the model weights are averaged and the next round can begin.
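The task order described above can be traced with a few lines of plain Python. This is a simplified simulation under the assumption that each collaborator finishes its chain before the next one starts; it is not how LocalRuntime is implemented and it ignores state transfer, exclusion, and parallel execution. `trace_flow` is a hypothetical helper.

```python
# Simplified trace of FederatedFlow's task order (illustration only).

def trace_flow(collaborators, rounds):
    steps = ["aggregator: start"]
    for _ in range(rounds):
        # foreach='collaborators': every collaborator runs the same task chain
        for name in collaborators:
            for task in ("aggregated_model_validation", "train",
                         "local_model_validation"):
                steps.append(f"{name}: {task}")
        # join averages the models and increments current_round; it loops
        # back while current_round < rounds, otherwise it moves on to end
        steps.append("aggregator: join")
    steps.append("aggregator: end")
    return steps


for step in trace_flow(["Portland", "Seattle"], rounds=2):
    print(step)
```

With rounds=2, join runs twice: after the first round it sends the averaged model back to the collaborators, and after the second it transitions to end.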

class FederatedFlow(FLSpec):
    def __init__(self, model=None, optimizer=None, rounds=3, **kwargs):
        super().__init__(**kwargs)
        if model is not None:
            self.model = model
            self.optimizer = optimizer
        else:
            self.model = Net()
            self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                       momentum=momentum)
        self.rounds = rounds

    @aggregator
    def start(self):
        print(f'Performing initialization for model')
        self.collaborators = self.runtime.collaborators
        self.private = 10
        self.current_round = 0
        self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])

    @collaborator
    def aggregated_model_validation(self):
        print(f'Performing aggregated model validation for collaborator {self.input}')
        self.agg_validation_score = inference(self.model, self.test_loader)
        print(f'{self.input} value of {self.agg_validation_score}')
        self.next(self.train)

    @collaborator
    def train(self):
        self.model.train()
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate,
                                   momentum=momentum)
        train_losses = []
        for batch_idx, (data, target) in enumerate(self.train_loader):
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            self.optimizer.step()
            if batch_idx % log_interval == 0:
                print('Train Epoch: 1 [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                self.loss = loss.item()
                torch.save(self.model.state_dict(), 'model.pth')
                torch.save(self.optimizer.state_dict(), 'optimizer.pth')
        self.training_completed = True
        self.next(self.local_model_validation)

    @collaborator
    def local_model_validation(self):
        self.local_validation_score = inference(self.model, self.test_loader)
        print(
            f'Doing local model validation for collaborator {self.input}: {self.local_validation_score}')
        self.next(self.join, exclude=['training_completed'])

    @aggregator
    def join(self, inputs):
        self.average_loss = sum(input.loss for input in inputs) / len(inputs)
        self.aggregated_model_accuracy = sum(
            input.agg_validation_score for input in inputs) / len(inputs)
        self.local_model_accuracy = sum(
            input.local_validation_score for input in inputs) / len(inputs)
        print(f'Average aggregated model validation values = {self.aggregated_model_accuracy}')
        print(f'Average training loss = {self.average_loss}')
        print(f'Average local model validation values = {self.local_model_accuracy}')
        self.model = FedAvg([input.model for input in inputs])
        self.optimizer = [input.optimizer for input in inputs][0]
        self.current_round += 1
        if self.current_round < self.rounds:
            self.next(self.aggregated_model_validation,
                      foreach='collaborators', exclude=['private'])
        else:
            self.next(self.end)

    @aggregator
    def end(self):
        print(f'This is the end of the flow')
Aggregator step "start" registered
Collaborator step "aggregated_model_validation" registered
Collaborator step "train" registered
Collaborator step "local_model_validation" registered
Aggregator step "join" registered
Aggregator step "end" registered
You’ll notice that the FederatedFlow definition above never initializes certain attributes, namely the train_loader and test_loader of each collaborator. These are private attributes of a particular participant and, as the name suggests, are accessible ONLY to that participant through its tasks. Additionally, private attributes are always filtered out of the current state when it is transferred from collaborator to aggregator, and vice versa.
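Conceptually, what crosses a participant boundary is the public state minus the participant's private attributes minus anything listed in exclude. The toy model below makes that rule concrete with a plain dict; it is an illustrative assumption, not OpenFL's actual serialization logic, and `transferable_state` plus the placeholder values are hypothetical.

```python
# Toy model of state filtering when a task hands off to the next participant.

def transferable_state(state, private_names, exclude=()):
    blocked = set(private_names) | set(exclude)
    return {k: v for k, v in state.items() if k not in blocked}


collaborator_state = {
    "model": "<weights>",
    "loss": 0.42,
    "training_completed": True,
    "train_loader": "<private DataLoader>",
    "test_loader": "<private DataLoader>",
}

sent = transferable_state(
    collaborator_state,
    private_names=["train_loader", "test_loader"],  # never leave the collaborator
    exclude=["training_completed"],  # mirrors self.next(self.join, exclude=[...])
)
print(sorted(sent))  # ['loss', 'model']
```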
Users can directly specify a collaborator’s private attributes via collaborator.private_attributes which is a dictionary where key is name of the attribute and value is the object that is made accessible to collaborator. In this example, we segment shards of the MNIST dataset for four collaborators: Portland, Seattle, Chandler and Bangalore. Each shard / slice of the dataset is assigned to collaborator’s private_attributes.
Note that private attributes are flexible. Users can pass a completely different type of object, under an arbitrary name, to any of the collaborators or to the aggregator.
Subsequent tutorials show how to assign private_attributes to the aggregator, and how to specify private attributes in another way, via a callable.
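The sketch below makes the filtering rule concrete. It assumes flow state can be modeled as a plain dictionary that maps attribute names to values. The `Participant` class and the `export_state` helper are hypothetical and only illustrate the behavior described above. OpenFL's runtime performs this filtering internally, and its actual implementation may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class Participant:
    """Hypothetical participant: a name plus its private attributes."""
    name: str
    private_attributes: Dict[str, Any] = field(default_factory=dict)


def export_state(state: Dict[str, Any], sender: Participant) -> Dict[str, Any]:
    """Return the state that may leave `sender`, dropping any attribute whose
    name matches one of the sender's private attributes."""
    return {k: v for k, v in state.items() if k not in sender.private_attributes}


if __name__ == "__main__":
    portland = Participant("Portland", {"train_loader": "<shard 0>", "test_loader": "<shard 0>"})
    # While a task runs, private attributes sit alongside the shared flow state.
    local_state = {"model": "weights", "loss": 1.01, **portland.private_attributes}
    print(sorted(export_state(local_state, portland)))  # ['loss', 'model']
```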
# Setup participants
aggregator = Aggregator()
aggregator.private_attributes = {}
# Setup collaborators with private attributes
collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore']
collaborators = [Collaborator(name=name) for name in collaborator_names]
for idx, collaborator in enumerate(collaborators):
    # Copy the full datasets, then keep every len(collaborators)-th example
    # starting at offset idx, so the shards do not overlap.
    local_train = deepcopy(mnist_train)
    local_test = deepcopy(mnist_test)
    local_train.data = mnist_train.data[idx::len(collaborators)]
    local_train.targets = mnist_train.targets[idx::len(collaborators)]
    local_test.data = mnist_test.data[idx::len(collaborators)]
    local_test.targets = mnist_test.targets[idx::len(collaborators)]
    # These loaders are visible only to this collaborator's tasks.
    collaborator.private_attributes = {
        'train_loader': torch.utils.data.DataLoader(local_train, batch_size=batch_size_train, shuffle=True),
        'test_loader': torch.utils.data.DataLoader(local_test, batch_size=batch_size_train, shuffle=True)
    }
local_runtime = LocalRuntime(aggregator=aggregator, collaborators=collaborators, backend='single_process')
print(f'Local runtime collaborators = {local_runtime.collaborators}')
Local runtime collaborators = ['Portland', 'Seattle', 'Chandler', 'Bangalore']
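The slice `data[idx::len(collaborators)]` gives each collaborator every fourth example, starting at a different offset. With the standard MNIST split of 60,000 training and 10,000 test images, this produces disjoint shards of 15,000 and 2,500 examples. Those sizes match the `[0/15000]` and `/2500` counts in the training and validation logs below. The sketch uses index ranges in place of the real tensors, and `strided_shards` is a hypothetical helper written for illustration.

```python
from typing import List


def strided_shards(num_examples: int, num_collaborators: int) -> List[List[int]]:
    """Split example indices the same way as data[idx::n]: collaborator idx
    receives indices idx, idx + n, idx + 2n, ..."""
    indices = list(range(num_examples))
    return [indices[idx::num_collaborators] for idx in range(num_collaborators)]


if __name__ == "__main__":
    collaborator_names = ['Portland', 'Seattle', 'Chandler', 'Bangalore']
    train_shards = strided_shards(60_000, len(collaborator_names))
    test_shards = strided_shards(10_000, len(collaborator_names))
    for name, tr, te in zip(collaborator_names, train_shards, test_shards):
        print(name, len(tr), len(te))  # each line: <name> 15000 2500
```

Strided slicing also roughly preserves the overall class balance in each shard. Contiguous slicing could behave differently if the underlying data happened to be ordered.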
Now that we have our flow and runtime defined, let’s run the experiment!
model = None
best_model = None
optimizer = None
flflow = FederatedFlow(model, optimizer, rounds=2, checkpoint=True)
flflow.runtime = local_runtime
flflow.run()
Created flow FederatedFlow
Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Portland
Test set: Avg. loss: 2.3264, Accuracy: 309/2500 (12%)
Portland value of 0.12359999865293503
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 2.370591
Train Epoch: 1 [640/15000 (4%)] Loss: 2.265008
Train Epoch: 1 [1280/15000 (9%)] Loss: 2.300407
Train Epoch: 1 [1920/15000 (13%)] Loss: 2.249448
Train Epoch: 1 [2560/15000 (17%)] Loss: 2.251498
Train Epoch: 1 [3200/15000 (21%)] Loss: 2.267806
Train Epoch: 1 [3840/15000 (26%)] Loss: 2.201275
Train Epoch: 1 [4480/15000 (30%)] Loss: 2.181914
Train Epoch: 1 [5120/15000 (34%)] Loss: 2.115410
Train Epoch: 1 [5760/15000 (38%)] Loss: 2.086649
Train Epoch: 1 [6400/15000 (43%)] Loss: 1.970717
Train Epoch: 1 [7040/15000 (47%)] Loss: 1.829772
Train Epoch: 1 [7680/15000 (51%)] Loss: 1.933031
Train Epoch: 1 [8320/15000 (55%)] Loss: 1.816630
Train Epoch: 1 [8960/15000 (60%)] Loss: 1.799785
Train Epoch: 1 [9600/15000 (64%)] Loss: 1.677242
Train Epoch: 1 [10240/15000 (68%)] Loss: 1.540448
Train Epoch: 1 [10880/15000 (72%)] Loss: 1.904726
Train Epoch: 1 [11520/15000 (77%)] Loss: 1.326228
Train Epoch: 1 [12160/15000 (81%)] Loss: 1.266977
Train Epoch: 1 [12800/15000 (85%)] Loss: 1.248409
Train Epoch: 1 [13440/15000 (89%)] Loss: 1.293248
Train Epoch: 1 [14080/15000 (94%)] Loss: 1.274603
Train Epoch: 1 [14720/15000 (98%)] Loss: 1.013117
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.7313, Accuracy: 2043/2500 (82%)
Doing local model validation for collaborator Portland: 0.8172000050544739
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Seattle
Test set: Avg. loss: 2.3319, Accuracy: 272/2500 (11%)
Seattle value of 0.1088000014424324
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 2.352533
Train Epoch: 1 [640/15000 (4%)] Loss: 2.339137
Train Epoch: 1 [1280/15000 (9%)] Loss: 2.271505
Train Epoch: 1 [1920/15000 (13%)] Loss: 2.289989
Train Epoch: 1 [2560/15000 (17%)] Loss: 2.271177
Train Epoch: 1 [3200/15000 (21%)] Loss: 2.246639
Train Epoch: 1 [3840/15000 (26%)] Loss: 2.195283
Train Epoch: 1 [4480/15000 (30%)] Loss: 2.062479
Train Epoch: 1 [5120/15000 (34%)] Loss: 2.093780
Train Epoch: 1 [5760/15000 (38%)] Loss: 2.041380
Train Epoch: 1 [6400/15000 (43%)] Loss: 1.820046
Train Epoch: 1 [7040/15000 (47%)] Loss: 1.836269
Train Epoch: 1 [7680/15000 (51%)] Loss: 1.683574
Train Epoch: 1 [8320/15000 (55%)] Loss: 1.467967
Train Epoch: 1 [8960/15000 (60%)] Loss: 1.540522
Train Epoch: 1 [9600/15000 (64%)] Loss: 1.263291
Train Epoch: 1 [10240/15000 (68%)] Loss: 1.366162
Train Epoch: 1 [10880/15000 (72%)] Loss: 1.164680
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.912429
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.970741
Train Epoch: 1 [12800/15000 (85%)] Loss: 1.132654
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.851275
Train Epoch: 1 [14080/15000 (94%)] Loss: 1.117103
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.931540
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.6495, Accuracy: 2031/2500 (81%)
Doing local model validation for collaborator Seattle: 0.8123999834060669
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Chandler
Test set: Avg. loss: 2.3338, Accuracy: 284/2500 (11%)
Chandler value of 0.1136000007390976
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 2.352149
Train Epoch: 1 [640/15000 (4%)] Loss: 2.302715
Train Epoch: 1 [1280/15000 (9%)] Loss: 2.315893
Train Epoch: 1 [1920/15000 (13%)] Loss: 2.304854
Train Epoch: 1 [2560/15000 (17%)] Loss: 2.304877
Train Epoch: 1 [3200/15000 (21%)] Loss: 2.232794
Train Epoch: 1 [3840/15000 (26%)] Loss: 2.221907
Train Epoch: 1 [4480/15000 (30%)] Loss: 2.163441
Train Epoch: 1 [5120/15000 (34%)] Loss: 2.157472
Train Epoch: 1 [5760/15000 (38%)] Loss: 2.062167
Train Epoch: 1 [6400/15000 (43%)] Loss: 2.074321
Train Epoch: 1 [7040/15000 (47%)] Loss: 2.086485
Train Epoch: 1 [7680/15000 (51%)] Loss: 1.760424
Train Epoch: 1 [8320/15000 (55%)] Loss: 1.859421
Train Epoch: 1 [8960/15000 (60%)] Loss: 1.761246
Train Epoch: 1 [9600/15000 (64%)] Loss: 1.723659
Train Epoch: 1 [10240/15000 (68%)] Loss: 1.343333
Train Epoch: 1 [10880/15000 (72%)] Loss: 1.431239
Train Epoch: 1 [11520/15000 (77%)] Loss: 1.217595
Train Epoch: 1 [12160/15000 (81%)] Loss: 1.334101
Train Epoch: 1 [12800/15000 (85%)] Loss: 1.210616
Train Epoch: 1 [13440/15000 (89%)] Loss: 1.095489
Train Epoch: 1 [14080/15000 (94%)] Loss: 1.244167
Train Epoch: 1 [14720/15000 (98%)] Loss: 1.456573
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.7379, Accuracy: 1977/2500 (79%)
Doing local model validation for collaborator Chandler: 0.7907999753952026
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Bangalore
Test set: Avg. loss: 2.3345, Accuracy: 272/2500 (11%)
Bangalore value of 0.1088000014424324
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 2.357561
Train Epoch: 1 [640/15000 (4%)] Loss: 2.313371
Train Epoch: 1 [1280/15000 (9%)] Loss: 2.287950
Train Epoch: 1 [1920/15000 (13%)] Loss: 2.250711
Train Epoch: 1 [2560/15000 (17%)] Loss: 2.268091
Train Epoch: 1 [3200/15000 (21%)] Loss: 2.169279
Train Epoch: 1 [3840/15000 (26%)] Loss: 2.222546
Train Epoch: 1 [4480/15000 (30%)] Loss: 2.007314
Train Epoch: 1 [5120/15000 (34%)] Loss: 1.917653
Train Epoch: 1 [5760/15000 (38%)] Loss: 1.837887
Train Epoch: 1 [6400/15000 (43%)] Loss: 1.878475
Train Epoch: 1 [7040/15000 (47%)] Loss: 1.594017
Train Epoch: 1 [7680/15000 (51%)] Loss: 1.511708
Train Epoch: 1 [8320/15000 (55%)] Loss: 1.271856
Train Epoch: 1 [8960/15000 (60%)] Loss: 1.558927
Train Epoch: 1 [9600/15000 (64%)] Loss: 1.347723
Train Epoch: 1 [10240/15000 (68%)] Loss: 1.140704
Train Epoch: 1 [10880/15000 (72%)] Loss: 1.230179
Train Epoch: 1 [11520/15000 (77%)] Loss: 1.153878
Train Epoch: 1 [12160/15000 (81%)] Loss: 1.055537
Train Epoch: 1 [12800/15000 (85%)] Loss: 1.085349
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.762103
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.928343
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.936020
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.5911, Accuracy: 2113/2500 (85%)
Doing local model validation for collaborator Bangalore: 0.8452000021934509
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling join
Average aggregated model validation values = 0.11370000056922436
Average training loss = 1.084312453866005
Average local model validation values = 0.8163999915122986
Saving data artifacts for join
Saved data artifacts for join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Portland
Test set: Avg. loss: 0.6913, Accuracy: 2113/2500 (85%)
Portland value of 0.8452000021934509
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 1.038085
Train Epoch: 1 [640/15000 (4%)] Loss: 0.984846
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.956596
Train Epoch: 1 [1920/15000 (13%)] Loss: 1.218905
Train Epoch: 1 [2560/15000 (17%)] Loss: 1.034170
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.982977
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.791037
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.817634
Train Epoch: 1 [5120/15000 (34%)] Loss: 1.143449
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.992079
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.864237
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.905026
Train Epoch: 1 [7680/15000 (51%)] Loss: 1.082687
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.984108
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.872094
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.677046
Train Epoch: 1 [10240/15000 (68%)] Loss: 1.044158
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.805063
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.586559
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.802089
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.601361
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.684089
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.674800
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.822161
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.3770, Accuracy: 2212/2500 (88%)
Doing local model validation for collaborator Portland: 0.8848000168800354
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Seattle
Test set: Avg. loss: 0.6930, Accuracy: 2134/2500 (85%)
Seattle value of 0.853600025177002
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 1.074149
Train Epoch: 1 [640/15000 (4%)] Loss: 0.788044
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.824622
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.708563
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.743329
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.991388
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.698764
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.800052
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.866619
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.776506
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.761863
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.635450
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.523824
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.870733
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.598420
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.530209
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.842757
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.635391
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.490621
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.576472
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.357680
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.738054
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.490220
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.548587
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.3586, Accuracy: 2230/2500 (89%)
Doing local model validation for collaborator Seattle: 0.8920000195503235
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Chandler
Test set: Avg. loss: 0.7033, Accuracy: 2094/2500 (84%)
Chandler value of 0.8375999927520752
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.974749
Train Epoch: 1 [640/15000 (4%)] Loss: 1.142256
Train Epoch: 1 [1280/15000 (9%)] Loss: 1.060130
Train Epoch: 1 [1920/15000 (13%)] Loss: 1.345984
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.989349
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.891025
Train Epoch: 1 [3840/15000 (26%)] Loss: 1.026930
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.817803
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.893464
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.902959
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.776052
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.798137
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.700132
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.609538
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.676106
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.885856
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.794635
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.946624
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.588031
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.673586
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.605498
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.692368
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.727418
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.541666
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.3773, Accuracy: 2221/2500 (89%)
Doing local model validation for collaborator Chandler: 0.8884000182151794
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Bangalore
Test set: Avg. loss: 0.6856, Accuracy: 2127/2500 (85%)
Bangalore value of 0.8507999777793884
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 1.024403
Train Epoch: 1 [640/15000 (4%)] Loss: 0.831721
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.877109
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.689435
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.774114
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.671120
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.744448
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.772162
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.916608
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.591479
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.623087
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.545670
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.513708
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.736596
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.504368
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.795776
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.772787
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.594993
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.508895
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.499484
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.520032
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.492095
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.467968
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.747039
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.3471, Accuracy: 2233/2500 (89%)
Doing local model validation for collaborator Bangalore: 0.8931999802589417
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling join
Average aggregated model validation values = 0.8467999994754791
Average training loss = 0.6648633033037186
Average local model validation values = 0.88960000872612
Saving data artifacts for join
Saved data artifacts for join
Calling end
This is the end of the flow
Saving data artifacts for end
Saved data artifacts for end
/tmp/ipykernel_252106/3655034279.py:59: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
return F.log_softmax(x)
/home/karan/playground/openfl/venv/lib/python3.10/site-packages/torch/nn/_reduction.py:51: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
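The log above follows the transitions declared in the flow. Each collaborator runs `aggregated_model_validation`, `train`, and `local_model_validation` in turn, and the `single_process` backend executes the collaborators one after another. After that, `join` increments `current_round` and either loops back or moves on to `end`. The small simulation below assumes exactly that step graph and reproduces the order of the `Calling ...` lines for `rounds=2` and four collaborators. `simulate_steps` is a hypothetical helper written for illustration, not part of OpenFL.

```python
from typing import List


def simulate_steps(collaborators: List[str], rounds: int) -> List[str]:
    """Reproduce the order of 'Calling <step>' lines for the FederatedFlow
    graph, assuming collaborators run sequentially (single_process backend)."""
    calls = ["start"]
    current_round = 0
    while True:
        # foreach='collaborators': each collaborator runs its three steps in turn.
        for _ in collaborators:
            calls += ["aggregated_model_validation", "train", "local_model_validation"]
        calls.append("join")
        current_round += 1  # mirrors self.current_round += 1 in join
        if current_round >= rounds:
            calls.append("end")
            return calls


if __name__ == "__main__":
    trace = simulate_steps(['Portland', 'Seattle', 'Chandler', 'Bangalore'], rounds=2)
    print(trace.count("train"), trace.count("join"), trace[-1])  # 8 2 end
```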
Now that the flow has completed, let’s get the final model and accuracy:
print(f'Sample of the final model weights: {flflow.model.state_dict()["conv1.weight"][0]}')
print(f'\nFinal aggregated model accuracy for {flflow.rounds} rounds of training: {flflow.aggregated_model_accuracy}')
Sample of the final model weights: tensor([[[ 0.1219, -0.0850, -0.0638, 0.0587, -0.2061],
[ 0.1559, -0.0204, 0.1003, 0.0273, -0.0150],
[ 0.1037, 0.0561, 0.1091, -0.0362, 0.0187],
[ 0.0092, 0.0607, 0.0319, 0.2063, 0.0913],
[-0.0773, -0.1235, -0.0412, -0.0902, -0.0545]]])
Final aggregated model accuracy for 2 rounds of training: 0.8467999994754791
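The final accuracy is the unweighted mean that `join` computes over the collaborators' `agg_validation_score` values in the last round. Plugging in the four per-collaborator values from the second round of the log reproduces the reported number, up to floating-point rounding. The variable names here are illustrative.

```python
# Per-collaborator aggregated-model validation scores from round 2 of the log above.
round2_agg_scores = {
    "Portland": 0.8452000021934509,
    "Seattle": 0.853600025177002,
    "Chandler": 0.8375999927520752,
    "Bangalore": 0.8507999777793884,
}

# Same arithmetic as join: sum of the inputs divided by the number of inputs.
aggregated_model_accuracy = sum(round2_agg_scores.values()) / len(round2_agg_scores)
print(f"{aggregated_model_accuracy:.4f}")  # 0.8468

# Each score is correct predictions / 2500 test examples on that collaborator's shard,
# e.g. Portland in round 2: 2113 / 2500.
print(2113 / 2500)  # 0.8452
```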
We can access the final model and all other aggregator attributes after the flow completes. But what if we want to examine an intermediate task and its specific output in detail? This is where checkpointing and reuse of Metaflow tooling come in handy.
Let’s create a new flow object and run the experiment one more time, this time using our previous model and optimizer as the starting point:
flflow2 = FederatedFlow(model=flflow.model, optimizer=flflow.optimizer, rounds=2, checkpoint=True)
flflow2.runtime = local_runtime
flflow2.run()
Created flow FederatedFlow
Calling start
Performing initialization for model
Saving data artifacts for start
Saved data artifacts for start
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Portland
Test set: Avg. loss: 0.3249, Accuracy: 2275/2500 (91%)
Portland value of 0.9100000262260437
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.705955
Train Epoch: 1 [640/15000 (4%)] Loss: 0.617308
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.623395
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.713938
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.714206
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.563812
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.717757
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.394908
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.826978
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.462670
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.698488
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.846376
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.619333
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.575636
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.622939
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.886747
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.665729
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.516920
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.859567
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.466999
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.533711
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.521279
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.658550
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.926817
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2697, Accuracy: 2313/2500 (93%)
Doing local model validation for collaborator Portland: 0.9251999855041504
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Seattle
Test set: Avg. loss: 0.3345, Accuracy: 2265/2500 (91%)
Seattle value of 0.906000018119812
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.390258
Train Epoch: 1 [640/15000 (4%)] Loss: 0.465562
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.622512
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.569061
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.534309
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.609027
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.656029
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.454295
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.427925
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.572590
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.434475
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.433428
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.583645
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.375552
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.558989
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.561380
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.480449
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.548253
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.320670
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.515821
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.382779
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.295870
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.286087
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.469384
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2900, Accuracy: 2293/2500 (92%)
Doing local model validation for collaborator Seattle: 0.9172000288963318
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Chandler
Test set: Avg. loss: 0.3359, Accuracy: 2256/2500 (90%)
Chandler value of 0.902400016784668
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.986641
Train Epoch: 1 [640/15000 (4%)] Loss: 0.487543
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.999929
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.838406
Train Epoch: 1 [2560/15000 (17%)] Loss: 1.006288
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.875594
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.684269
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.751433
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.948535
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.701165
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.605181
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.572135
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.584587
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.677527
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.775974
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.579040
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.779329
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.705849
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.367069
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.578917
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.574488
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.290978
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.499532
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.531283
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2879, Accuracy: 2271/2500 (91%)
Doing local model validation for collaborator Chandler: 0.9083999991416931
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Bangalore
Test set: Avg. loss: 0.3211, Accuracy: 2278/2500 (91%)
Bangalore value of 0.9111999869346619
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.606566
Train Epoch: 1 [640/15000 (4%)] Loss: 0.423930
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.582356
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.404679
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.733127
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.458385
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.461127
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.653782
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.411580
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.520569
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.535583
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.577438
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.449876
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.511897
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.581871
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.644637
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.567783
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.576815
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.605296
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.441371
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.388708
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.354411
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.531725
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.479206
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2704, Accuracy: 2306/2500 (92%)
Doing local model validation for collaborator Bangalore: 0.9223999977111816
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling join
Average aggregated model validation values = 0.9074000120162964
Average training loss = 0.6016728430986404
Average local model validation values = 0.9183000028133392
Saving data artifacts for join
Saved data artifacts for join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Portland
Test set: Avg. loss: 0.2481, Accuracy: 2323/2500 (93%)
Portland value of 0.9291999936103821
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.630017
Train Epoch: 1 [640/15000 (4%)] Loss: 0.503889
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.640408
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.458764
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.736370
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.593804
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.686028
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.798474
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.703022
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.405487
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.453337
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.719088
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.863970
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.461701
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.630442
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.621713
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.644361
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.644697
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.399276
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.469072
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.474146
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.398161
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.599437
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.439608
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2183, Accuracy: 2336/2500 (93%)
Doing local model validation for collaborator Portland: 0.9344000220298767
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Seattle
Test set: Avg. loss: 0.2562, Accuracy: 2317/2500 (93%)
Seattle value of 0.926800012588501
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.385337
Train Epoch: 1 [640/15000 (4%)] Loss: 0.470521
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.459677
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.301743
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.486080
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.476714
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.440658
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.299032
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.578410
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.259214
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.277751
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.336378
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.357706
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.323220
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.347599
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.369618
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.364295
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.492413
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.271388
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.399994
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.334124
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.313602
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.413516
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.430733
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2245, Accuracy: 2336/2500 (93%)
Doing local model validation for collaborator Seattle: 0.9344000220298767
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Chandler
Test set: Avg. loss: 0.2511, Accuracy: 2312/2500 (92%)
Chandler value of 0.9247999787330627
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.502659
Train Epoch: 1 [640/15000 (4%)] Loss: 0.646937
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.521962
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.603257
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.650282
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.574407
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.635170
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.475845
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.528372
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.500761
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.505273
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.738660
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.355279
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.360918
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.712853
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.650161
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.505021
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.459242
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.595233
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.449048
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.530338
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.592250
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.518594
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.716185
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2376, Accuracy: 2323/2500 (93%)
Doing local model validation for collaborator Chandler: 0.9291999936103821
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling aggregated_model_validation
Performing aggregated model validation for collaborator Bangalore
Test set: Avg. loss: 0.2437, Accuracy: 2327/2500 (93%)
Bangalore value of 0.9308000206947327
Saving data artifacts for aggregated_model_validation
Saved data artifacts for aggregated_model_validation
Calling train
Train Epoch: 1 [0/15000 (0%)] Loss: 0.373329
Train Epoch: 1 [640/15000 (4%)] Loss: 0.367368
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.246474
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.561947
Train Epoch: 1 [2560/15000 (17%)] Loss: 0.213358
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.347174
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.427229
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.467920
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.509551
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.502692
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.362033
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.366702
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.621961
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.473972
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.648961
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.290578
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.334747
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.323814
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.343845
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.341860
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.212641
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.160580
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.487189
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.377483
Saving data artifacts for train
Saved data artifacts for train
Calling local_model_validation
Test set: Avg. loss: 0.2261, Accuracy: 2335/2500 (93%)
Doing local model validation for collaborator Bangalore: 0.9340000152587891
Saving data artifacts for local_model_validation
Saved data artifacts for local_model_validation
Should transfer from local_model_validation to join
Calling join
Average aggregated model validation values = 0.9279000014066696
Average training loss = 0.491002157330513
Average local model validation values = 0.9330000132322311
Saving data artifacts for join
Saved data artifacts for join
Calling end
This is the end of the flow
Saving data artifacts for end
Saved data artifacts for end
/tmp/ipykernel_252106/3655034279.py:59: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
return F.log_softmax(x)
Now that the flow is complete, let’s dig into some of the information captured along the way.
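# Keep the ID of the run executed above (flflow2 is the flow object from the earlier cell)
run_id = flflow2._run_id
from metaflow import Metaflow, Flow, Task, Step
# Metaflow() is the entry point to the Metaflow client API; listing it shows the stored flows
m = Metaflow()
list(m)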
[Flow('FederatedFlow')]
If you already use Metaflow, you’ll notice that this is the same way you would examine a flow after completion. Let’s look at the latest run, which generated the results above:
f = Flow('FederatedFlow').latest_run
f
Run('FederatedFlow/1734542077639071')
And here is its list of steps:
list(f)
[Step('FederatedFlow/1734542077639071/end'),
Step('FederatedFlow/1734542077639071/join'),
Step('FederatedFlow/1734542077639071/local_model_validation'),
Step('FederatedFlow/1734542077639071/train'),
Step('FederatedFlow/1734542077639071/aggregated_model_validation'),
Step('FederatedFlow/1734542077639071/start')]
This matches the list of steps executed in the flow (Metaflow lists them from last to first), so far so good. Now let’s look at the train step:
s = Step(f'FederatedFlow/{run_id}/train')
s
Step('FederatedFlow/1734542077639071/train')
list(s)
[Task('FederatedFlow/1734542077639071/train/25'),
Task('FederatedFlow/1734542077639071/train/22'),
Task('FederatedFlow/1734542077639071/train/19'),
Task('FederatedFlow/1734542077639071/train/16'),
Task('FederatedFlow/1734542077639071/train/12'),
Task('FederatedFlow/1734542077639071/train/9'),
Task('FederatedFlow/1734542077639071/train/6'),
Task('FederatedFlow/1734542077639071/train/3')]
Now we see 8 tasks: 4 collaborators each performed 2 rounds of model training.
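Each entry in the task list is a Metaflow pathspec of the form flow/run_id/step/task_id. As a minimal sketch, the eight train pathspecs can be parsed and grouped into rounds. The helper names (parse_task_pathspec, group_into_rounds) are hypothetical, and the grouping assumes 4 collaborators and that task IDs increase in execution order within this local run; neither is a documented guarantee of OpenFL or Metaflow.

```python
from typing import List, NamedTuple


class TaskPathspec(NamedTuple):
    """The four components of a Metaflow task pathspec."""
    flow: str
    run_id: str
    step: str
    task_id: int


def parse_task_pathspec(pathspec: str) -> TaskPathspec:
    """Split 'flow/run_id/step/task_id' into its components."""
    flow, run_id, step, task_id = pathspec.split("/")
    return TaskPathspec(flow, run_id, step, int(task_id))


def group_into_rounds(task_ids: List[int], n_collaborators: int) -> List[List[int]]:
    """Chunk sorted task IDs into rounds of n_collaborators tasks each.

    Assumption (hypothetical): IDs increase in execution order and every
    collaborator runs exactly one train task per round.
    """
    ordered = sorted(task_ids)
    return [ordered[i:i + n_collaborators] for i in range(0, len(ordered), n_collaborators)]


# Pathspecs copied from the list(s) output above
train_pathspecs = [
    "FederatedFlow/1734542077639071/train/25",
    "FederatedFlow/1734542077639071/train/22",
    "FederatedFlow/1734542077639071/train/19",
    "FederatedFlow/1734542077639071/train/16",
    "FederatedFlow/1734542077639071/train/12",
    "FederatedFlow/1734542077639071/train/9",
    "FederatedFlow/1734542077639071/train/6",
    "FederatedFlow/1734542077639071/train/3",
]

parsed = [parse_task_pathspec(p) for p in train_pathspecs]
rounds = group_into_rounds([p.task_id for p in parsed], n_collaborators=4)
print(len(parsed), len(rounds))  # 8 2
print(rounds)  # [[3, 6, 9, 12], [16, 19, 22, 25]]
```

Against the live run, the same list of strings can be built with [task.pathspec for task in s], since each Metaflow Task exposes its pathspec.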
t = Task(f'FederatedFlow/{run_id}/train/9')
t
Task('FederatedFlow/1734542077639071/train/9')
Now let’s look at the data artifacts this task generated:
t.data
<MetaflowData: loss, collaborators, execute_task_args, test_loader, checkpoint, model, training_completed, input, train_loader, rounds, agg_validation_score, optimizer, current_round>
t.data.input
'Chandler'
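Because every train task stores the same artifacts (input, current_round, loss, and so on), they can be tabulated across collaborators and rounds. The sketch below is a hypothetical helper, artifacts_by_collaborator, that accepts any iterable of task-like objects exposing .data.input, .data.current_round and the requested artifact; it assumes current_round distinguishes rounds and that loss holds a scalar. The offline demo uses stand-in objects with placeholder values, not results from this run.

```python
from collections import defaultdict
from types import SimpleNamespace


def artifacts_by_collaborator(tasks, artifact="loss"):
    """Map collaborator name -> {current_round: artifact value}.

    Works on any iterable of task-like objects exposing .data.input,
    .data.current_round and .data.<artifact>, such as the Metaflow
    Task objects produced by iterating over a Step.
    """
    table = defaultdict(dict)
    for task in tasks:
        data = task.data
        table[data.input][data.current_round] = getattr(data, artifact)
    return dict(table)


# Offline demo with stand-in objects; the values are placeholders, not results.
fake_tasks = [
    SimpleNamespace(data=SimpleNamespace(input="Chandler", current_round=0, loss=0.5)),
    SimpleNamespace(data=SimpleNamespace(input="Chandler", current_round=1, loss=0.4)),
    SimpleNamespace(data=SimpleNamespace(input="Seattle", current_round=0, loss=0.6)),
]
summary = artifacts_by_collaborator(fake_tasks)
print(summary)  # {'Chandler': {0: 0.5, 1: 0.4}, 'Seattle': {0: 0.6}}
```

With the real run, artifacts_by_collaborator(Step(f'FederatedFlow/{run_id}/train')) would read each task’s artifacts through the Metaflow client instead of the stand-ins.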
Now let’s look at its log output (stdout):
print(t.stdout)
Train Epoch: 1 [0/15000 (0%)] Loss: 0.986641
Train Epoch: 1 [640/15000 (4%)] Loss: 0.487543
Train Epoch: 1 [1280/15000 (9%)] Loss: 0.999929
Train Epoch: 1 [1920/15000 (13%)] Loss: 0.838406
Train Epoch: 1 [2560/15000 (17%)] Loss: 1.006288
Train Epoch: 1 [3200/15000 (21%)] Loss: 0.875594
Train Epoch: 1 [3840/15000 (26%)] Loss: 0.684269
Train Epoch: 1 [4480/15000 (30%)] Loss: 0.751433
Train Epoch: 1 [5120/15000 (34%)] Loss: 0.948535
Train Epoch: 1 [5760/15000 (38%)] Loss: 0.701165
Train Epoch: 1 [6400/15000 (43%)] Loss: 0.605181
Train Epoch: 1 [7040/15000 (47%)] Loss: 0.572135
Train Epoch: 1 [7680/15000 (51%)] Loss: 0.584587
Train Epoch: 1 [8320/15000 (55%)] Loss: 0.677527
Train Epoch: 1 [8960/15000 (60%)] Loss: 0.775974
Train Epoch: 1 [9600/15000 (64%)] Loss: 0.579040
Train Epoch: 1 [10240/15000 (68%)] Loss: 0.779329
Train Epoch: 1 [10880/15000 (72%)] Loss: 0.705849
Train Epoch: 1 [11520/15000 (77%)] Loss: 0.367069
Train Epoch: 1 [12160/15000 (81%)] Loss: 0.578917
Train Epoch: 1 [12800/15000 (85%)] Loss: 0.574488
Train Epoch: 1 [13440/15000 (89%)] Loss: 0.290978
Train Epoch: 1 [14080/15000 (94%)] Loss: 0.499532
Train Epoch: 1 [14720/15000 (98%)] Loss: 0.531283
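The captured stdout is plain text, so the loss trajectory can be pulled out with a regular expression. This is a minimal sketch keyed to the "Train Epoch: … Loss: …" line format shown above; the names TRAIN_LINE and parse_train_log are hypothetical, and the \s+ before "Loss" accepts whatever whitespace separates the bracket from the loss value.

```python
import re

# Matches progress lines such as "Train Epoch: 1 [640/15000 (4%)] Loss: 0.487543"
TRAIN_LINE = re.compile(
    r"Train Epoch: (?P<epoch>\d+) \[(?P<seen>\d+)/(?P<total>\d+) \(\d+%\)\]\s+Loss: (?P<loss>\d+\.\d+)"
)


def parse_train_log(stdout):
    """Return a list of (samples_seen, loss) pairs from a train task's stdout."""
    return [(int(m["seen"]), float(m["loss"])) for m in TRAIN_LINE.finditer(stdout)]


sample = (
    "Train Epoch: 1 [0/15000 (0%)]\tLoss: 0.986641\n"
    "Train Epoch: 1 [640/15000 (4%)]\tLoss: 0.487543\n"
)
points = parse_train_log(sample)
print(points)  # [(0, 0.986641), (640, 0.487543)]
```

Applied to the task above, parse_train_log(t.stdout) would yield one (samples_seen, loss) pair per logged batch, ready for plotting or comparison across collaborators.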
Are there any error logs (stderr)?
print(t.stderr)
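The printed stderr for this task appears empty. Checking tasks one at a time gets tedious as the number of rounds and collaborators grows, so here is a minimal sketch that scans any iterable of task-like objects exposing .pathspec and .stderr. The helper name tasks_with_stderr is hypothetical, and the demo uses stand-in objects with made-up pathspecs and messages.

```python
from types import SimpleNamespace


def tasks_with_stderr(tasks):
    """Return the pathspecs of tasks whose captured stderr is not blank."""
    # Treat a missing stderr (None) the same as an empty one
    return [task.pathspec for task in tasks if (task.stderr or "").strip()]


# Offline demo with stand-in objects (hypothetical pathspecs and messages)
demo = [
    SimpleNamespace(pathspec="FederatedFlow/123/train/3", stderr=""),
    SimpleNamespace(pathspec="FederatedFlow/123/train/6", stderr="example error message\n"),
    SimpleNamespace(pathspec="FederatedFlow/123/join/7", stderr=None),
]
flagged = tasks_with_stderr(demo)
print(flagged)  # ['FederatedFlow/123/train/6']
```

For the whole run, iterating over the Run yields its Steps and iterating over each Step yields its Tasks, so tasks_with_stderr(task for step in f for task in step) would check every task in one pass.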
Congratulations!#
Now that you’ve completed your first workflow interface quickstart notebook, see some of the more advanced things you can do in our other tutorials, including:
Using the LocalRuntime Ray Backend for dedicated GPU access
Vertical Federated Learning
Model Watermarking
Differential Privacy
And More!