Source code for fedn.network.combiner.hooks.hook_client

import json
import os

import grpc

import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger
from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model
from fedn.network.combiner.updatehandler import UpdateHandler

# 1 MiB (1024 * 1024). Presumably the chunk size in bytes for streamed model
# payloads -- it is not referenced in the visible part of this module; confirm.
CHUNK_SIZE = 1024**2
# Deadline in seconds (2 minutes) passed as ``timeout=`` to quick RPCs.
TIMEOUT_SHORT = 2 * 60
# Deadline in seconds (10 minutes) for RPCs that may take longer, such as
# aggregation.
TIMEOUT_LONG = 10 * 60


class CombinerHookInterface:
    """Combiner to server function hooks client.

    Wraps a gRPC ``FunctionServiceStub`` connected to the hook
    (server-functions) container and exposes one method per hook call.
    """

    def __init__(self):
        """Initialize CombinerHookInterface client.

        The target address is read from the ``HOOK_SERVICE_HOST`` environment
        variable and defaults to ``"hook:12081"``. The channel is insecure
        (no TLS).
        """
        self.hook_service_host = os.getenv("HOOK_SERVICE_HOST", "hook:12081")
        self.channel = grpc.insecure_channel(
            self.hook_service_host,
            options=[
                ("grpc.keepalive_time_ms", 300_000),  # 5 min ping interval
                ("grpc.keepalive_timeout_ms", 5000),  # 5 seconds timeout for a response
                ("grpc.keepalive_permit_without_calls", 1),  # allow keepalives even with no active calls
                ("grpc.enable_retries", 1),  # automatic retries
                ("grpc.initial_reconnect_backoff_ms", 1000),  # initial delay before retrying
                ("grpc.max_reconnect_backoff_ms", 5000),  # maximum delay before retrying
            ],
        )
        self.stub = rpc.FunctionServiceStub(self.channel)

    def provided_functions(self, server_functions: str):
        """Communicates to hook container and asks which functions are available.

        Any failure is logged and reported as "nothing available" (an empty
        dict) instead of being raised.

        :param server_functions: String version of an implementation of the ServerFunctionsBase interface.
        :type server_functions: :str:
        :return: dictionary specifying which functions are implemented;
            empty if the hook container could not be queried.
        :rtype: dict
        """
        try:
            request = fedn.ProvidedFunctionsRequest(function_code=server_functions)
            response = self.stub.HandleProvidedFunctions(request, timeout=TIMEOUT_SHORT)
            return response.available_functions
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.UNAVAILABLE:
                # An unreachable hook container is only a warning here.
                logger.warning(f"Server-functions container is unavailable; using default implementations: {rpc_error}")
            else:
                logger.error(f"gRPC error: {rpc_error.code().name} - {rpc_error.details()}")
            return {}
        except Exception as e:
            logger.error(f"Unexpected error communicating with hooks container: {e}")
            return {}

    def client_settings(self, global_model) -> dict:
        """Communicates to hook container to get a client config.

        The model is converted with ``model_as_bytesIO`` and streamed to the
        hook container as ``ClientConfigRequest`` messages.

        :param global_model: The global model that will be distributed to clients.
        :type global_model: :bytes:
        :return: config that will be distributed to clients (the JSON-decoded
            ``client_settings`` field of the response).
        :rtype: dict
        """
        request_function = fedn.ClientConfigRequest
        args = {}
        model = model_as_bytesIO(global_model)
        response = self.stub.HandleClientConfig(
            bytesIO_request_generator(mdl=model, request_function=request_function, args=args),
            timeout=TIMEOUT_SHORT,
        )
        return json.loads(response.client_settings)

    def client_selection(self, clients: list) -> list:
        """Communicates to hook container to select clients.

        The candidate list is JSON-encoded into a ``ClientSelectionRequest``
        and the ``client_ids`` field of the response is JSON-decoded.

        :param clients: JSON-serializable list of candidate clients
            (presumably client ids -- confirm against the caller).
        :type clients: list
        :return: the JSON-decoded ``client_ids`` returned by the hook container.
        :rtype: list
        :raises grpc.RpcError: if the call fails or does not complete within
            ``TIMEOUT_SHORT`` seconds.
        """
        request = fedn.ClientSelectionRequest(client_ids=json.dumps(clients))
        # Use the same deadline as the other short RPCs so that a hung hook
        # container cannot block the caller indefinitely.
        response = self.stub.HandleClientSelection(request, timeout=TIMEOUT_SHORT)
        return json.loads(response.client_ids)

    def aggregate(self, previous_global, update_handler: UpdateHandler, helper, delete_models: bool):
        """Aggregation call to the hook functions.

        Sends the previous global model, then drains the update queue sending
        each client's metadata and model in chunks, and finally asks the hook
        container for the aggregated model.

        :param previous_global: The previous global model, streamed to the
            hook container under the id ``"global_model"``. It is passed as
            ``mdl`` to ``bytesIO_request_generator`` (presumably a BytesIO
            object -- confirm against the caller).
        :param update_handler: Handler whose ``model_updates`` queue is
            consumed until it is empty.
        :type update_handler: UpdateHandler
        :param helper: Helper handed to ``unpack_model`` to decode the
            aggregated model returned by the hook container.
        :param delete_models: If True, each model update is deleted through
            ``update_handler.delete_model`` after it has been sent.
        :type delete_models: bool
        :return: tuple ``(model, data)``, where ``model`` is the unpacked
            aggregated model and ``data`` holds ``"time_model_load"`` and
            ``"time_model_aggregation"`` (both left at 0.0 by this method)
            and ``"nr_aggregated_models"`` (number of updates sent).
        :rtype: tuple
        """
        data = {}
        data["time_model_load"] = 0.0
        data["time_model_aggregation"] = 0.0

        # Every model, global or client, is streamed as StoreModelRequest messages.
        request_function = fedn.StoreModelRequest

        # send previous global
        args = {"id": "global_model"}
        response = self.stub.HandleStoreModel(
            bytesIO_request_generator(mdl=previous_global, request_function=request_function, args=args),
            timeout=TIMEOUT_SHORT,
        )
        logger.info(f"Store model response: {response.status}")

        # send client models and metadata
        nr_updates = 0
        while not update_handler.model_updates.empty():
            logger.info("Getting next model update from queue.")
            update = update_handler.next_model_update()
            metadata = json.loads(update.meta)["training_metadata"]
            model = update_handler.load_model_update_bytesIO(update.model_update_id)

            # send metadata
            client_id = update.sender.client_id
            request = fedn.ClientMetaRequest(metadata=json.dumps(metadata), client_id=client_id)
            self.stub.HandleMetadata(request, timeout=TIMEOUT_SHORT)

            # send client model
            args = {"id": client_id}
            response = self.stub.HandleStoreModel(
                bytesIO_request_generator(mdl=model, request_function=request_function, args=args),
                timeout=TIMEOUT_SHORT,
            )
            logger.info(f"Store model response: {response.status}")
            nr_updates += 1

            if delete_models:
                # delete model from disk
                update_handler.delete_model(model_update=update)

        # ask for aggregation
        request = fedn.AggregationRequest(aggregate="aggregate")
        response_generator = self.stub.HandleAggregation(request, timeout=TIMEOUT_LONG)
        data["nr_aggregated_models"] = nr_updates
        model, _ = unpack_model(response_generator, helper)
        return model, data