Source code for fedn.network.combiner.updatehandler

import json
import queue
import sys
import time
import traceback

from fedn.common.log_config import logger
from fedn.network.combiner.modelservice import ModelService, load_model_from_bytes


[docs] class ModelUpdateError(Exception): pass
[docs] class UpdateHandler: """Update handler. Responsible for receiving, loading and supplying client model updates. :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` """
[docs] def __init__(self, modelservice: ModelService) -> None: self.model_updates = queue.Queue() self.backward_completions = queue.Queue() self.modelservice = modelservice self.model_id_to_model_data = {}
[docs] def delete_model(self, model_update): self.modelservice.temp_model_storage.delete(model_update.model_update_id) logger.info("UPDATE HANDLER: Deleted model update {} from storage.".format(model_update.model_update_id))
[docs] def next_model_update(self): """Get the next model update from the queue. :param helper: A helper object. :type helper: object :return: The model update. :rtype: fedn.network.grpc.fedn.proto.ModelUpdate """ model_update = self.model_updates.get(block=False) return model_update
[docs] def on_model_update(self, model_update): """Callback when a new client model update is recieved. Performs (optional) validation and pre-processing, and then puts the update id on the aggregation queue. Override in subclass as needed. :param model_update: fedn.network.grpc.fedn.proto.ModelUpdate :type model_id: str """ try: logger.info("UPDATE HANDLER: callback received model update {}".format(model_update.model_update_id)) # Validate the update and metadata valid_update = self._validate_model_update(model_update) if valid_update: # Push the model update to the processing queue self.model_updates.put(model_update) else: logger.warning("UPDATE HANDLER: Invalid model update, skipping.") except Exception as e: tb = traceback.format_exc() logger.error("UPDATE HANDLER: failed to receive model update: {}".format(e)) logger.error(tb) pass
def _validate_model_update(self, model_update): """Validate the model update. :param model_update: A ModelUpdate message. :type model_update: object :return: True if the model update is valid, False otherwise. :rtype: bool """ try: data = json.loads(model_update.meta)["training_metadata"] _ = data["num_examples"] except KeyError: tb = traceback.format_exc() logger.error("UPDATE HANDLER: Invalid model update, missing metadata.") logger.error(tb) return False return True
[docs] def load_model_update(self, model_update, helper): """Load the memory representation of the model update. Load the model update paramters and the associate metadata into memory. :param model_update: The model update. :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate :param helper: A helper object. :type helper: fedn.utils.helpers.helperbase.Helper :return: A tuple of (parameters, metadata) :rtype: tuple """ model_id = model_update.model_update_id model = self.load_model(helper, model_id) # Get relevant metadata metadata = json.loads(model_update.meta) if "config" in metadata.keys(): # Used in Python client config = json.loads(metadata["config"]) else: # Used in C++ client config = json.loads(model_update.config) training_metadata = metadata["training_metadata"] if "round_id" in config: training_metadata["round_id"] = config["round_id"] return model, training_metadata
[docs] def load_model_update_byte(self, model_update): """Load the memory representation of the model update. Load the model update paramters and the associate metadata into memory. :param model_update: The model update. :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate :return: A tuple of parameters(bytes), metadata :rtype: tuple """ model_id = model_update.model_update_id model = self.load_model_update_bytesIO(model_id).getbuffer() # Get relevant metadata metadata = json.loads(model_update.meta) if "config" in metadata.keys(): # Used in Python client config = json.loads(metadata["config"]) else: # Used in C++ client config = json.loads(model_update.config) training_metadata = metadata["training_metadata"] if "round_id" in config: training_metadata["round_id"] = config["round_id"] return model, training_metadata
[docs] def load_model(self, helper, model_id): """Load model update with id model_id into its memory representation. :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase` :type helper: class: `fedn.utils.helpers.helpers.HelperBase` :param model_id: The ID of the model update, UUID in str format :type model_id: str """ model_bytesIO = self.load_model_update_bytesIO(model_id) if model_bytesIO: try: model = load_model_from_bytes(model_bytesIO.getbuffer(), helper) except IOError: logger.warning("UPDATE HANDLER: Failed to load model!") else: raise ModelUpdateError("Failed to load model.") return model
[docs] def load_model_update_bytesIO(self, model_id, retry=3): """Load model update object and return it as BytesIO. :param model_id: The ID of the model :type model_id: str :param retry: number of times retrying load model update, defaults to 3 :type retry: int, optional :return: Updated model :rtype: class: `io.BytesIO` """ # Try reading model update from local disk/combiner memory model_str = self.modelservice.temp_model_storage.get(model_id) # And if we cannot access that, try downloading from the server if model_str is None: model_str = self.modelservice.get_model(model_id) # TODO: use retrying library tries = 0 while tries < retry: tries += 1 if not model_str or sys.getsizeof(model_str) == 80: logger.warning("Model download failed. retrying") time.sleep(3) # sleep longer model_str = self.modelservice.get_model(model_id) return model_str
[docs] def waitforit(self, config, buffer_size=100, polling_interval=0.1): """Defines the policy for how long the server should wait before starting to aggregate models. The policy is as follows: 1. Wait a maximum of time_window time until the round times out. 2. Terminate if a preset number of model updates (buffer_size) are in the queue. :param config: The round config object :type config: dict :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100 :type buffer_size: int, optional :param polling_interval: The polling interval, defaults to 0.1 :type polling_interval: float, optional """ time_window = float(config["round_timeout"]) tt = 0.0 while tt < time_window: if self.model_updates.qsize() >= buffer_size: break time.sleep(polling_interval) tt += polling_interval
[docs] def waitforbackwardcompletion(self, config, required_backward_completions=-1, polling_interval=0.1): """Wait for backward completion messages. :param config: The round config object :param required_backward_completions: Number of required backward completions """ time_window = float(config["round_timeout"]) tt = 0.0 while tt < time_window: if self.backward_completions.qsize() >= required_backward_completions: break time.sleep(polling_interval) tt += polling_interval
[docs] def clear_backward_completions(self): """Clear the backward completions queue.""" while not self.backward_completions.empty(): try: self.backward_completions.get_nowait() except queue.Empty: break