import ast
import inspect
import queue
import random
import time
import uuid
from typing import TYPE_CHECKING, TypedDict
from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator
from fedn.network.combiner.hooks.grpc_wrappers import call_with_fallback
from fedn.network.combiner.hooks.hook_client import CombinerHookInterface
from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctions
from fedn.network.combiner.modelservice import ModelService, serialize_model_to_BytesIO
from fedn.network.combiner.updatehandler import UpdateHandler
from fedn.network.storage.s3.repository import Repository
from fedn.utils.helpers.helpers import get_helper
from fedn.utils.parameters import Parameters
# Import Combiner only while type checking: a runtime import would be circular, but the name is still needed for type hints.
if TYPE_CHECKING:
from fedn.network.combiner.combiner import Combiner # not-floating-import
[docs]
class RoundConfig(TypedDict):
    """Round configuration.

    Typed view of the job dictionary that is pushed onto the round handler's
    inbox and later read by the round execution methods in this module.

    :param _job_id: A universally unique identifier for the round. Set by Combiner.
    :type _job_id: str
    :param committed_at: The time the round was committed. Set by Controller.
    :type committed_at: str
    :param task: The task to perform in the round. Set by Controller. Supported tasks are "training", "validation", and "prediction".
    :type task: str
    :param round_id: The round identifier as str(int)
    :type round_id: str
    :param round_timeout: The round timeout in seconds. Set by user interfaces or Controller.
    :type round_timeout: str
    :param rounds: The number of rounds. Set by user interfaces.
    :type rounds: int
    :param model_id: The model identifier. Set by user interfaces or Controller.
    :type model_id: str
    :param model_version: The model version. Currently not used.
    :type model_version: str
    :param model_type: The model type. Currently not used.
    :type model_type: str
    :param model_size: The size of the model. Currently not used.
    :type model_size: int
    :param model_parameters: The model parameters. Currently not used.
    :type model_parameters: dict
    :param model_metadata: The model metadata. Currently not used.
    :type model_metadata: dict
    :param session_id: The session identifier. Set by (Controller?).
    :type session_id: str
    :param prediction_id: The prediction identifier. Only used for prediction tasks.
    :type prediction_id: str
    :param helper_type: The helper type.
    :type helper_type: str
    :param aggregator: The aggregator type.
    :type aggregator: str
    :param client_settings: Settings that are distributed to clients.
    :type client_settings: dict
    :param selected_clients: List of client ids to participate in the round
    :type selected_clients: list[str]
    """

    _job_id: str
    committed_at: str
    task: str
    round_id: str
    round_timeout: str
    rounds: int
    model_id: str
    model_version: str
    model_type: str
    model_size: int
    model_parameters: dict
    model_metadata: dict
    session_id: str
    # NOTE(review): ``prediction_id`` is documented above and read from the
    # round config by the prediction branch of the handler's main loop, but it
    # is not declared as a key here -- confirm whether it should be added.
    helper_type: str
    aggregator: str
    client_settings: dict
    selected_clients: list[str]
[docs]
class RoundHandler:
"""Round handler.
The round handler processes requests from the global controller
to produce model updates and perform model validations.
The aggregator plugin is not a constructor argument; it is selected afterwards with :meth:`set_aggregator`.
:param server: A handle to the Combiner class :class:`fedn.network.combiner.Combiner`
:type server: class: `fedn.network.combiner.Combiner`
:param repository: Model repository for :class:`fedn.network.combiner.Combiner` (kept on the handler as ``storage``)
:type repository: class: `fedn.network.storage.s3.repository.Repository`
:param modelservice: A handle to the model service :class:`fedn.network.combiner.modelservice.ModelService`
:type modelservice: class: `fedn.network.combiner.modelservice.ModelService`
"""
[docs]
def __init__(self, server: "Combiner", repository: Repository, modelservice: ModelService):
    """Wire up the round handler with its collaborators.

    :param server: Handle to the owning combiner; used to send requests to clients.
    :type server: Combiner
    :param repository: Model repository, stored on the handler as ``storage``.
    :type repository: Repository
    :param modelservice: Model service holding the temporary model storage.
    :type modelservice: ModelService
    """
    # Plain references handed in by the caller.
    self.server = server
    self.storage = repository
    self.modelservice = modelservice

    # Inbox of pending round configs; filled by push_round_config().
    self.round_configs = queue.Queue()

    # Default server functions are the source text of the ServerFunctions
    # class; set_server_functions() can replace them later.
    default_functions_source = inspect.getsource(ServerFunctions)
    self.server_functions = default_functions_source

    self.update_handler = UpdateHandler(modelservice=modelservice)
    self.hook_interface = CombinerHookInterface()
[docs]
def set_aggregator(self, aggregator):
    """Select the aggregator used to combine model updates.

    :param aggregator: Identifier passed to ``get_aggregator`` to look up the aggregator.
    """
    # The aggregator is bound to this handler's update handler.
    selected = get_aggregator(aggregator, self.update_handler)
    self.aggregator = selected
[docs]
def set_server_functions(self, server_functions: str):
    """Replace the server functions source used by the hook interface.

    :param server_functions: Source code text of the server functions.
    :type server_functions: str
    """
    setattr(self, "server_functions", server_functions)
[docs]
def push_round_config(self, round_config: RoundConfig) -> str:
    """Queue a round config (job description) for execution.

    A fresh job id is written into ``round_config["_job_id"]`` before the
    config is put on the inbox queue.

    :param round_config: The round configuration (from the global controller).
    :type round_config: dict
    :return: The job id (universally unique identifier) assigned to the round.
    :rtype: str
    :raises Exception: Any error raised while tagging or queueing is logged and re-raised.
    """
    try:
        job_id = str(uuid.uuid4())
        round_config["_job_id"] = job_id
        self.round_configs.put(round_config)
    except Exception:
        logger.error("Failed to push round config.")
        raise
    else:
        return round_config["_job_id"]
def _training_round(self, config: dict, clients: list, provided_functions: dict):
    """Request model updates from clients and aggregate the results.

    :param config: The round config object (passed to the client). Its
        ``client_settings`` entry is updated in place when the
        ``client_settings`` hook is provided.
    :type config: dict
    :param clients: clients to participate in the training round
    :type clients: list
    :param provided_functions: Flags telling which server hook functions are
        available ("client_settings", "aggregate", "incremental_aggregate").
    :type provided_functions: dict
    :return: an aggregated model and associated metadata
    :rtype: model, dict
    :raises Exception: Aggregation errors are logged and re-raised.
    """
    logger.info("ROUNDHANDLER: Initiating training round, participating clients: {}".format(clients))

    meta = {
        "nr_expected_updates": len(clients),
        "nr_required_updates": int(config["clients_required"]),
        "timeout": float(config["round_timeout"]),
    }
    session_id = config["session_id"]
    model_id = config["model_id"]

    if provided_functions.get("client_settings", False):
        global_model_bytes = self.modelservice.temp_model_storage.get(model_id)

        def _rpc():
            return self.hook_interface.client_settings(global_model_bytes)

        def _fallback():
            return {}

        hook_settings = call_with_fallback("client_settings", _rpc, fallback_fn=_fallback) or {}
        # Hook-provided settings win over the ones already in the config.
        existing_settings = config.get("client_settings", {})
        config["client_settings"] = {**existing_settings, **hook_settings}

    # Ask every selected client for a model update.
    self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients)

    # A buffer_size of -1 (default) means: wait for all selected clients.
    requested_buffer = int(config["buffer_size"])
    buffer_size = len(clients) if requested_buffer == -1 else requested_buffer

    # Block until the round termination policy has been met.
    self.update_handler.waitforit(config, buffer_size=buffer_size)

    tic = time.time()
    model, data = None, None
    try:
        helper = get_helper(config["helper_type"])
        logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
        # The flag arrives as a string; only the exact text "True" enables deletion.
        delete_models = config["delete_models_storage"] == "True"

        parameters = None
        if "aggregator_kwargs" in config:
            parameters = Parameters(ast.literal_eval(config["aggregator_kwargs"]))

        def _fallback():
            return self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters)

        hook_aggregation = provided_functions.get("aggregate", False) or provided_functions.get("incremental_aggregate", False)
        if hook_aggregation:
            previous_model_bytes = self.modelservice.temp_model_storage.get(model_id)

            def _rpc():
                return self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models)

            model, data = call_with_fallback("aggregate", _rpc, fallback_fn=_fallback)
        else:
            # No aggregation hook: use the configured aggregator directly.
            model, data = _fallback()
    except Exception as e:
        logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e))
        raise

    meta["time_combination"] = time.time() - tic
    meta["aggregation_time"] = data
    return model, meta
def _validation_round(self, session_id, model_id, clients):
    """Send model validation requests to clients.

    :param session_id: The session the validation belongs to.
    :param model_id: The ID of the model to validate
    :type model_id: str
    :param clients: clients to send validation requests to
    :type clients: list
    """
    combiner = self.server
    combiner.request_model_validation(session_id, model_id, clients=clients)
def _prediction_round(self, prediction_id: str, model_id: str, clients: list):
    """Send model prediction requests to clients.

    :param prediction_id: The identifier of the prediction job.
    :type prediction_id: str
    :param model_id: The ID of the model to use for prediction
    :type model_id: str
    :param clients: clients to send prediction requests to
    :type clients: list
    """
    combiner = self.server
    combiner.request_model_prediction(prediction_id, model_id, clients=clients)
def _forward_pass(self, config: dict, clients: list):
    """Send model forward pass requests to clients and combine the results.

    :param config: The round config object (passed to the client).
    :type config: dict
    :param clients: clients to participate in the training round
    :type clients: list
    :return: aggregated embeddings and associated metadata
    :rtype: model, dict
    :raises Exception: Errors raised while combining the client results are
        logged and re-raised.
    """
    logger.info("ROUNDHANDLER: Initiating forward pass, participating clients: {}".format(clients))

    meta = {}
    meta["nr_expected_updates"] = len(clients)
    meta["nr_required_updates"] = int(config["clients_required"])
    meta["timeout"] = float(config["round_timeout"])

    session_id = config["session_id"]
    model_id = config["model_id"]
    # Determines whether the forward pass calculates gradients ("training"),
    # or is used for inference (e.g., for validation).
    is_sl_inference = config["is_sl_inference"]

    # Request forward pass from all active clients.
    self.server.request_forward_pass(session_id=session_id, model_id=model_id, config=config, clients=clients)

    # The round should terminate when all clients have completed.
    buffer_size = len(clients)

    # Wait / block until the round termination policy has been met.
    self.update_handler.waitforit(config, buffer_size=buffer_size)

    tic = time.time()
    try:
        helper = get_helper(config["helper_type"])
        logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"]))
        # The flag arrives as a string; only the exact text "True" enables deletion.
        delete_models = config["delete_models_storage"] == "True"
        output = self.aggregator.combine_models(helper=helper, delete_models=delete_models, is_sl_inference=is_sl_inference)
    except Exception as e:
        logger.warning("EMBEDDING CONCATENATION in FORWARD PASS FAILED AT COMBINER! {}".format(e))
        # Propagate the real failure. Continuing would leave ``output`` as
        # None and fail below with an unrelated TypeError on output["data"].
        raise

    meta["time_combination"] = time.time() - tic
    meta["aggregation_time"] = output["data"]
    return output, meta
def _backward_pass(self, config: dict, clients: list):
    """Send backward pass requests to clients and wait for their completion.

    :param config: The round config object (passed to the client).
    :type config: dict
    :param clients: clients to participate in the training round
    :type clients: list
    :return: associated metadata
    :rtype: dict
    """
    logger.info("ROUNDHANDLER: Initiating backward pass, participating clients: {}".format(clients))

    meta = {
        "nr_expected_updates": len(clients),
        "nr_required_updates": int(config["clients_required"]),
        "timeout": float(config["round_timeout"]),
    }

    # Start from an empty backward-completions queue.
    self.update_handler.clear_backward_completions()

    # The config's model_id is sent to the clients as the gradient id.
    logger.info("ROUNDHANDLER: Requesting backward pass, gradient_id: {}".format(config["model_id"]))
    self.server.request_backward_pass(
        session_id=config["session_id"],
        gradient_id=config["model_id"],
        config=config,
        clients=clients,
    )

    # The round terminates once every selected client has completed.
    self.update_handler.waitforbackwardcompletion(config, required_backward_completions=len(clients))
    return meta
[docs]
def stage_model(self, model_id, timeout_retry=3, retry=2):
    """Download a model from persistent storage and set in modelservice.

    Nothing is downloaded when the model is already present in the temporary
    model storage. Otherwise the download is attempted up to ``retry + 1``
    times; an attempt fails when the storage backend raises or returns an
    empty result.

    :param model_id: ID of the model update object to stage.
    :type model_id: str
    :param timeout_retry: Sleep before retrying download again(sec), defaults to 3
    :type timeout_retry: int, optional
    :param retry: Number of retries, defaults to 2
    :type retry: int, optional
    :raises Exception: The storage backend's own error, when the last attempt raised.
    :raises RuntimeError: When the last attempt returned no model data without raising.
    """
    # If the model is already in memory at the server we do not need to do anything.
    if self.modelservice.temp_model_storage.exist(model_id):
        logger.info("Model already exists in memory, skipping model staging.")
        return

    logger.info("Model Staging, fetching model from storage...")

    # If not, download it and stage it in memory at the combiner.
    attempts = 0
    while True:
        model = None
        fetch_error = None
        try:
            model = self.storage.get_model_stream(model_id)
        except Exception as err:
            fetch_error = err

        if model:
            break

        # Both an exception and an empty result count as a failed attempt, so
        # the loop can never spin forever on a backend that returns nothing.
        attempts += 1
        if attempts > retry:
            logger.error("Failed to stage model {} from storage backend!".format(model_id))
            if fetch_error is not None:
                raise fetch_error
            raise RuntimeError("Storage backend returned no data for model {}".format(model_id))

        logger.warning("Could not fetch model from storage backend, retrying.")
        time.sleep(timeout_retry)

    self.modelservice.set_model(model, model_id)
def _assign_round_clients(self, n: int, type: str = "trainers", selected_clients: list = None):
    """Obtain a list of clients(trainers or validators) to ask for updates in this round.

    :param n: Size of a random set taken from active trainers(clients), if n > "active trainers" all is used
    :type n: int
    :param type: type of clients, either "trainers" or "validators", defaults to "trainers"
    :type type: str, optional
    :param selected_clients: When given and non-empty, only active clients
        that also appear in this list are eligible.
    :type selected_clients: list, optional
    :return: Set of clients
    :rtype: list
    :raises ValueError: If ``type`` is neither "trainers" nor "validators".
    """
    if type == "validators":
        clients = self.server.get_active_validators()
    elif type == "trainers":
        clients = self.server.get_active_trainers()
    else:
        # Fail explicitly; falling through would hit an unbound ``clients``.
        message = "(ERROR): {} is not a supported type of client".format(type)
        logger.error(message)
        raise ValueError(message)

    if selected_clients:
        clients = [client for client in clients if client in selected_clients]

    # random.sample needs a sequence; materialize in case the server returned
    # another kind of collection.
    clients = list(clients)

    # If the number of requested clients exceeds the number available, use all
    # available; otherwise pick a random subsample.
    n = min(n, len(clients))
    return random.sample(clients, n)
def _check_nr_round_clients(self, config):
    """Check that the minimal number of clients required to start a round are available.

    :param config: The round config object; ``clients_required`` is read from it.
    :type config: dict
    :return: True if the required number of clients are available, False otherwise.
    :rtype: bool
    """
    active = self.server.nr_active_trainers()

    # Guard clause: not enough active trainers for this round.
    if active < int(config["clients_required"]):
        logger.info("Too few clients to start round.")
        return False

    logger.info("Number of clients required ({0}) to start round met {1}.".format(config["clients_required"], active))
    return True
[docs]
def execute_validation_round(self, session_id, model_id):
    """Coordinate a validation round for one model.

    Stages the model, picks validators (at most ``self.server.max_clients``)
    and sends them validation requests.

    :param session_id: The session the validation belongs to.
    :param model_id: The ID of the model to validate.
    """
    logger.info("COMBINER orchestrating validation of model {}".format(model_id))
    self.stage_model(model_id)
    max_clients = self.server.max_clients
    validators = self._assign_round_clients(max_clients, type="validators")
    self._validation_round(session_id, model_id, validators)
[docs]
def execute_prediction_round(self, prediction_id: str, model_id: str) -> None:
    """Coordinate a prediction round for one model.

    Stages the model, picks clients (at most ``self.server.max_clients``) and
    sends them prediction requests.

    :param prediction_id: The identifier of the prediction job.
    :type prediction_id: str
    :param model_id: The ID of the model to use for prediction.
    :type model_id: str
    """
    logger.info("COMBINER orchestrating prediction using model {}".format(model_id))
    self.stage_model(model_id)
    # TODO: Implement prediction client type
    # Until then the clients are drawn from the active validators.
    max_clients = self.server.max_clients
    clients = self._assign_round_clients(max_clients, type="validators")
    self._prediction_round(prediction_id, model_id, clients)
[docs]
def execute_training_round(self, config):
    """Coordinates clients to execute training tasks.

    Stages the model to update, selects clients (through the
    ``client_selection`` hook when provided, otherwise by random assignment),
    runs the training round and uploads the aggregated model to storage.

    :param config: The round config object.
    :type config: dict
    :return: metadata about the training round. Contains ``config``,
        ``round_id``, ``data`` and, when aggregation produced a model, ``model_id``.
    :rtype: dict
    """
    logger.info("Processing training round, job_id {}".format(config["_job_id"]))

    data = {}
    data["config"] = config
    data["round_id"] = config["round_id"]

    # Download model to update and set in temp storage.
    self.stage_model(config["model_id"])

    # Dictionary telling which server hook functions are provided.
    try:
        provided_functions = self.hook_interface.provided_functions(self.server_functions)
    except Exception:
        # Best effort: if the hook interface cannot be queried, run the round
        # with the built-in behaviour only.
        provided_functions = {"client_selection": False, "client_settings": False, "aggregate": False, "incremental_aggregate": False}

    def _fallback():
        # Built-in selection: random assignment among active trainers,
        # optionally restricted to the clients named in the config. A missing,
        # None or empty ``selected_clients`` all mean "no restriction".
        selected_clients = config.get("selected_clients") or None
        return self._assign_round_clients(n=self.server.max_clients, type="trainers", selected_clients=selected_clients)

    if provided_functions.get("client_selection", False):

        def _rpc():
            return self.hook_interface.client_selection(clients=self.server.get_active_trainers())

        clients = call_with_fallback("client_selection", _rpc, fallback_fn=_fallback)
        if not clients:
            # Empty selection => fallback immediately (don't spin forever)
            clients = _fallback()
    else:
        clients = _fallback()

    model, meta = self._training_round(config, clients, provided_functions)
    data["data"] = meta

    if model is None:
        logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"]))
    else:
        helper = get_helper(config["helper_type"])
        stream = serialize_model_to_BytesIO(model, helper)
        try:
            model_id = self.storage.set_model(stream.read(), is_file=False)
        finally:
            # Release the serialized buffer even when the upload fails.
            stream.close()
        data["model_id"] = model_id

        logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"]))

    # Delete temp model
    self.modelservice.temp_model_storage.delete(config["model_id"])
    return data
[docs]
def execute_forward_pass(self, config):
    """Coordinates clients to execute forward pass.

    When the forward pass yields gradients they are uploaded to storage and
    their storage id is reported as ``model_id``; a validation forward pass
    uploads nothing.

    :param config: The round config object.
    :type config: dict
    :return: metadata about the training round.
    :rtype: dict
    """
    logger.info("Processing forward pass, job_id {}".format(config["_job_id"]))

    data = {"config": config, "round_id": config["round_id"], "model_id": None}

    clients = self._assign_round_clients(self.server.max_clients)
    output, meta = self._forward_pass(config, clients)
    data["data"] = meta

    gradients = output["gradients"]
    validation_data = output["validation_data"]

    if validation_data is not None:
        # In a forward validation pass no gradients are calculated; nothing to upload.
        logger.info("FORWARD VALIDATION PASS COMPLETED. Job id: {}".format(config["_job_id"]))
    elif gradients is None:
        logger.warning("\t Forward pass failed in round {0}!".format(config["round_id"]))
    else:
        helper = get_helper(config["helper_type"])
        serialized = serialize_model_to_BytesIO(gradients, helper)
        # Uploads gradients to storage.
        gradient_id = self.storage.set_model(serialized.read(), is_file=False)
        serialized.close()
        # Intended: the gradient id is reported under the "model_id" key.
        data["model_id"] = gradient_id
        logger.info("FORWARD PASS COMPLETED. Aggregated model id: {}, Job id: {}".format(gradient_id, config["_job_id"]))

    return data
[docs]
def execute_backward_pass(self, config):
    """Coordinates clients to execute backward pass.

    The gradients identified by ``config["model_id"]`` are staged in the
    temporary model storage for the duration of the pass and removed afterwards.

    :param config: The round config object.
    :type config: dict
    :return: metadata about the training round.
    :rtype: dict
    """
    logger.info("Processing backward pass, job_id {}".format(config["_job_id"]))

    data = {"config": config, "round_id": config["round_id"]}

    logger.info("roundhandler execute_backward_pass: downloading gradients with id: {}".format(config["model_id"]))

    # Download gradients from persistent storage and set in temp storage.
    self.stage_model(config["model_id"])

    clients = self._assign_round_clients(self.server.max_clients)
    meta = self._backward_pass(config, clients)
    data["data"] = meta

    backward_failed = meta is None
    if backward_failed:
        logger.warning("\t Failed to run backward pass in round {0}!".format(config["round_id"]))

    # Delete temp model
    self.modelservice.temp_model_storage.delete(config["model_id"])
    return data
[docs]
def run(self, polling_interval=1.0):
    """Main control loop. Execute rounds based on round config on the queue.

    Runs until interrupted by ``KeyboardInterrupt`` or ``SystemExit``. An
    error while committing round metadata to the round store propagates out
    of the loop.

    :param polling_interval: The polling interval in seconds for checking if a new job/config is available.
    :type polling_interval: float
    """
    try:
        while True:
            try:
                round_config = self.round_configs.get(block=False)

                # Check that the minimum allowed number of clients are connected
                if self._check_nr_round_clients(round_config):
                    self._dispatch_round_config(round_config)
                else:
                    reason = "Failed to meet client allocation requirements for this round config."
                    logger.warning("{0}".format(reason))

                self.round_configs.task_done()
            except queue.Empty:
                # Nothing queued: wait before polling again.
                time.sleep(polling_interval)
    except (KeyboardInterrupt, SystemExit):
        # Deliberate: an interrupt or interpreter exit ends the loop quietly.
        pass

def _dispatch_round_config(self, round_config):
    """Execute the round described by ``round_config`` according to its ``task``."""
    task = round_config["task"]
    if task == "training":
        self._execute_and_record_round(self.execute_training_round, round_config)
    elif task == "validation":
        self.execute_validation_round(round_config["session_id"], round_config["model_id"])
    elif task == "prediction":
        self.execute_prediction_round(round_config["prediction_id"], round_config["model_id"])
    elif task == "forward":
        self._execute_and_record_round(self.execute_forward_pass, round_config, label="Forward pass: ")
    elif task == "backward":
        self._execute_and_record_round(self.execute_backward_pass, round_config, label="Backward pass: ")
    else:
        logger.warning("config contains unknown task type.")

def _execute_and_record_round(self, execute, round_config, label=""):
    """Run ``execute(round_config)``, time it and append its metadata to the round store entry."""
    tic = time.time()
    round_meta = execute(round_config)
    round_meta["time_exec_training"] = time.time() - tic
    round_meta["status"] = "Success"
    round_meta["name"] = self.server.id

    active_round = self.server.db.round_store.get(round_meta["round_id"])
    active_round.combiners.append(round_meta)
    try:
        self.server.db.round_store.update(active_round)
    except Exception as e:
        logger.error("{}Failed to update round data in round store. {}".format(label, e))
        # Chain the original error so the root cause stays in the traceback.
        raise Exception("{}Failed to update round data in round store.".format(label)) from e
    return round_meta