# Source code for fedn.network.clients.client_v2

"""Client module for handling client operations in the FEDn network."""

import io
import json
import os
import time
import uuid
from io import BytesIO
from typing import Dict, Optional, Tuple

from fedn.common.config import FEDN_CUSTOM_URL_PREFIX
from fedn.common.log_config import logger
from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions
from fedn.network.combiner.modelservice import get_tmp_path
from fedn.utils.helpers.helpers import get_helper, save_metadata


def get_url(api_url: str, api_port: int) -> str:
    """Build the FEDn API base URL.

    The port is appended as ``:<port>`` only when ``api_port`` is truthy, so
    ``None`` or ``0`` yields ``<api_url>/<prefix>``.

    :param api_url: Scheme and host of the API, e.g. ``http://localhost``.
    :param api_port: Port number, or a falsy value to omit it.
    :return: ``<api_url>[:<api_port>]/<FEDN_CUSTOM_URL_PREFIX>``.
    """
    base = f"{api_url}:{api_port}" if api_port else api_url
    return f"{base}/{FEDN_CUSTOM_URL_PREFIX}"
class ClientOptions:
    """Identity and compute-package settings of a client, validated on construction."""

    def __init__(self, name: str, package: str, preferred_combiner: Optional[str] = None, id: Optional[str] = None) -> None:
        """Create validated client options.

        :param name: Non-empty client name.
        :param package: Either ``"local"`` or ``"remote"``.
        :param preferred_combiner: Optional combiner the client would like to be assigned to.
        :param id: Client id; when falsy a random UUID4 string is generated.
        :raises ValueError: If ``name`` or ``package`` is invalid.
        """
        self._validate(name, package)
        self.name = name
        self.package = package
        self.preferred_combiner = preferred_combiner
        # Any falsy id (None, "") falls back to a freshly generated UUID4.
        self.client_id = id or str(uuid.uuid4())

    def _validate(self, name: str, package: str) -> None:
        """Raise ``ValueError`` unless ``name`` is a non-empty str and ``package`` is a known mode."""
        name_ok = isinstance(name, str) and bool(name)
        if not name_ok:
            raise ValueError("Name must be a string")
        package_ok = isinstance(package, str) and package in ("local", "remote")
        if not package_ok:
            raise ValueError("Package must be either 'local' or 'remote'")

    def to_json(self) -> Dict[str, Optional[str]]:
        """Return the options as a plain, JSON-serialisable dict."""
        return dict(
            name=self.name,
            client_id=self.client_id,
            preferred_combiner=self.preferred_combiner,
            package=self.package,
        )
class Client:
    """Client for interacting with the FEDn network.

    Obtains a combiner (explicitly configured or assigned by the API), prepares
    the compute package, opens the gRPC channel and then serves train / validate /
    forward / backward / predict requests by running the corresponding commands
    through ``self.fedn_client.dispatcher``.
    """

    def __init__(
        self,
        api_url: str,
        api_port: int,
        client_obj: ClientOptions,
        combiner_host: Optional[str] = None,
        combiner_port: Optional[int] = None,
        token: Optional[str] = None,
        package_checksum: Optional[str] = None,
        helper_type: Optional[str] = None,
    ) -> None:
        """Initialize the Client.

        :param api_url: Base URL of the FEDn API.
        :param api_port: API port; a falsy value omits the port from the URL.
        :param client_obj: Validated client identity/package options.
        :param combiner_host: Combiner host; together with ``combiner_port`` it skips API assignment.
        :param combiner_port: Combiner port.
        :param token: Authentication token passed to API and gRPC calls.
        :param package_checksum: Expected checksum of a remote compute package.
        :param helper_type: Helper name overriding the one reported by the combiner.
        """
        self.api_url = api_url
        self.api_port = api_port
        self.combiner_host = combiner_host
        self.combiner_port = combiner_port
        self.token = token
        self.client_obj = client_obj
        self.package_checksum = package_checksum
        self.helper_type = helper_type
        self.fedn_api_url = get_url(self.api_url, self.api_port)
        self.fedn_client: FednClient = FednClient()
        self.helper = None

    def _connect_to_api(self) -> Tuple[bool, Optional[dict]]:
        """Request a combiner assignment from the API.

        Retries every 3 seconds while the API reports ``ComputePackageMissing``.
        Any other result ends the loop.

        :return: ``(True, response)`` when the result is ``Assigned``, otherwise ``(False, None)``.
        """
        # NOTE(review): ``response`` is later used as gRPC config and read via
        # ``response.helper_type``; the ``dict`` annotation may be inaccurate — confirm.
        while True:
            result, response = self.fedn_client.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json())
            # Compare against the enum member explicitly rather than relying on the
            # truthiness of ``result``: a falsy result must not cause an immediate,
            # sleep-less re-call of the API.
            if result != ConnectToApiResult.ComputePackageMissing:
                break
            logger.info("Retrying in 3 seconds")
            time.sleep(3)
        if result == ConnectToApiResult.Assigned:
            return True, response
        return False, None

    def start(self) -> None:
        """Start the client.

        Each step below returns early (without running the client) on failure:

        1. Use the configured combiner, or obtain one via ``_connect_to_api``.
        2. Initialise the compute package, remote (from the API) or local.
        3. Select the model helper and initialise the gRPC handler.

        Then it registers the request callbacks and calls ``FednClient.run``.
        """
        if self.combiner_host and self.combiner_port:
            combiner_config = GrpcConnectionOptions(host=self.combiner_host, port=self.combiner_port)
        else:
            result, combiner_config = self._connect_to_api()
            if not result:
                return

        if self.client_obj.package == "remote":
            result = self.fedn_client.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum)
            if not result:
                return
        else:
            result = self.fedn_client.init_local_compute_package()
            if not result:
                return

        self.set_helper(combiner_config)

        # NOTE(review): the client *id* is passed as ``client_name`` — confirm this is intended.
        result = self.fedn_client.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token)
        if not result:
            return

        logger.info("-----------------------------")

        self.fedn_client.set_train_callback(self.on_train)
        self.fedn_client.set_validate_callback(self.on_validation)
        self.fedn_client.set_forward_callback(self.on_forward)
        self.fedn_client.set_backward_callback(self.on_backward)
        self.fedn_client.set_predict_callback(self._process_prediction_request)

        self.fedn_client.set_name(self.client_obj.name)
        self.fedn_client.set_client_id(self.client_obj.client_id)

        self.fedn_client.run()

    def set_helper(self, response: Optional[GrpcConnectionOptions] = None) -> None:
        """Select the model helper.

        Precedence: the ``helper_type`` given to the constructor, then
        ``response.helper_type``, then ``"numpyhelper"``.
        """
        helper_type = response.helper_type if response else None
        helper_type_to_use = self.helper_type or helper_type or "numpyhelper"
        logger.info(f"Setting helper to: {helper_type_to_use}")
        self.helper = get_helper(helper_type_to_use)

    def on_train(self, in_model: BytesIO, client_settings: dict) -> Tuple[Optional[BytesIO], dict]:
        """Handle the training callback; delegates to ``_process_training_request``."""
        return self._process_training_request(in_model, client_settings)

    def on_validation(self, in_model: BytesIO) -> Optional[dict]:
        """Handle the validation callback; delegates to ``_process_validation_request``."""
        return self._process_validation_request(in_model)

    def on_forward(self, client_id, is_sl_inference):
        """Handle the split-learning forward callback.

        :return: ``(embeddings, meta)`` as produced by ``_process_forward_request``.
        """
        out_embeddings, meta = self._process_forward_request(client_id, is_sl_inference)
        return out_embeddings, meta

    def on_backward(self, in_gradients, client_id):
        """Handle the split-learning backward callback.

        :return: Metadata dict as produced by ``_process_backward_request``.
        """
        meta = self._process_backward_request(in_gradients, client_id)
        return meta

    @staticmethod
    def _remove_tmp_files(*paths: Optional[str], with_metadata: bool = False) -> None:
        """Delete temporary files on a best-effort basis.

        ``None`` entries are skipped; these are paths that were never allocated.
        With ``with_metadata`` the sibling ``<path>-metadata`` file is removed too.
        Missing files are ignored. Other OS errors are logged rather than raised,
        so cleanup cannot turn a successful request into a failed one.
        """
        for path in paths:
            if path is None:
                continue
            candidates = [path, path + "-metadata"] if with_metadata else [path]
            for candidate in candidates:
                try:
                    os.unlink(candidate)
                except FileNotFoundError:
                    pass  # never created, e.g. the command failed before writing it
                except OSError as e:
                    logger.warning(f"Could not remove temporary file {candidate}: {e}")

    def _process_training_request(self, in_model: BytesIO, client_settings: dict) -> Tuple[Optional[BytesIO], dict]:
        """Process a training (model update) request.

        Writes the incoming model to a temp file and stores ``client_settings``
        via ``save_metadata``. It then runs ``train <inpath> <outpath>`` and reads
        back the updated model and ``<outpath>-metadata``. Temp files are removed
        whether or not the request succeeds.

        :return: ``(out_model, meta)``; on failure ``(None, {"status": "failed", "error": ...})``.
        """
        inpath = None
        outpath = None
        try:
            meta = {}
            inpath = self.helper.get_tmp_path()
            with open(inpath, "wb") as fh:
                fh.write(in_model.getbuffer())

            save_metadata(metadata=client_settings, filename=inpath)

            outpath = self.helper.get_tmp_path()
            tic = time.time()
            self.fedn_client.dispatcher.run_cmd(f"train {inpath} {outpath}")
            meta["exec_training"] = time.time() - tic

            with open(outpath, "rb") as fr:
                out_model = io.BytesIO(fr.read())

            with open(outpath + "-metadata", "r") as fh:
                training_metadata = json.loads(fh.read())

            logger.info(f"SETTING Training metadata: {training_metadata}")
            meta["training_metadata"] = training_metadata
        except Exception as e:
            logger.error(f"Could not process training request due to error: {e}")
            out_model = None
            meta = {"status": "failed", "error": str(e)}
        finally:
            # NOTE(review): save_metadata presumably writes "<inpath>-metadata",
            # mirroring the "<outpath>-metadata" file read above; it is removed
            # best-effort here (a missing file is ignored).
            self._remove_tmp_files(inpath, outpath, with_metadata=True)
        return out_model, meta

    def _process_validation_request(self, in_model: BytesIO) -> Optional[dict]:
        """Process a validation request.

        Runs ``validate <inpath> <outpath>`` and parses ``outpath`` as JSON metrics.
        Temp files are removed whether or not the request succeeds.

        :return: The metrics dict, or ``None`` if validation failed.
        """
        inpath = None
        outpath = None
        try:
            inpath = self.helper.get_tmp_path()
            with open(inpath, "wb") as fh:
                fh.write(in_model.getbuffer())

            outpath = get_tmp_path()
            self.fedn_client.dispatcher.run_cmd(f"validate {inpath} {outpath}")

            with open(outpath, "r") as fh:
                metrics = json.loads(fh.read())
        except Exception as e:
            logger.warning(f"Validation failed with exception {e}")
            metrics = None
        finally:
            self._remove_tmp_files(inpath, outpath)
        return metrics

    def _process_prediction_request(self, in_model: BytesIO) -> Optional[dict]:
        """Process a prediction request.

        Runs ``predict <inpath> <outpath>`` and parses ``outpath`` as JSON.
        Temp files are removed whether or not the request succeeds.

        :return: The parsed prediction output, or ``None`` if prediction failed.
        """
        inpath = None
        outpath = None
        try:
            inpath = self.helper.get_tmp_path()
            with open(inpath, "wb") as fh:
                fh.write(in_model.getbuffer())

            outpath = get_tmp_path()
            self.fedn_client.dispatcher.run_cmd(f"predict {inpath} {outpath}")

            with open(outpath, "r") as fh:
                metrics = json.load(fh)
        except Exception as e:
            logger.warning(f"Prediction failed with exception {e}")
            metrics = None
        finally:
            self._remove_tmp_files(inpath, outpath)
        return metrics

    def _process_forward_request(self, client_id, is_sl_inference) -> Tuple[BytesIO, dict]:
        """Process a forward request.

        Param is_sl_inference determines whether the forward pass is used for
        gradient calculation or validation. The embedding file and its
        ``-metadata`` sidecar are removed whether or not the request succeeds.

        :param client_id: The client ID.
        :type client_id: str
        :param is_sl_inference: Whether the request is for splitlearning inference or not.
        :type is_sl_inference: str
        :return: ``(embeddings, meta)``; embeddings is None if forward failed.
        :rtype: tuple
        """
        out_embedding_path = None
        try:
            out_embedding_path = get_tmp_path()
            tic = time.time()
            self.fedn_client.dispatcher.run_cmd(f"forward {client_id} {out_embedding_path} {is_sl_inference}")

            meta = {}
            embeddings = None

            with open(out_embedding_path, "rb") as fr:
                embeddings = io.BytesIO(fr.read())

            meta["exec_training"] = time.time() - tic

            # Read the metadata file
            with open(out_embedding_path + "-metadata", "r") as fh:
                training_metadata = json.loads(fh.read())

            logger.debug("SETTING Forward metadata: {}".format(training_metadata))
            meta["training_metadata"] = training_metadata
        except Exception as e:
            logger.warning("Forward failed with exception {}".format(e))
            embeddings = None
            meta = {"status": "failed", "error": str(e)}
        finally:
            self._remove_tmp_files(out_embedding_path, with_metadata=True)
        return embeddings, meta

    def _process_backward_request(self, in_gradients: BytesIO, client_id: str) -> dict:
        """Process a backward request.

        The temporary gradient file is removed whether or not the request succeeds.

        :param in_gradients: The gradients to be processed.
        :type in_gradients: BytesIO
        :param client_id: The client ID passed to the ``backward`` command.
        :type client_id: str
        :return: Metadata; ``{"status": "failed", "error": ...}`` if backward failed.
        :rtype: dict
        """
        inpath = None
        try:
            meta = {}
            inpath = get_tmp_path()

            # load gradients
            with open(inpath, "wb") as fh:
                fh.write(in_gradients.getbuffer())

            tic = time.time()
            self.fedn_client.dispatcher.run_cmd(f"backward {inpath} {client_id}")
            meta["exec_training"] = time.time() - tic
        except Exception as e:
            logger.error("Backward failed with exception {}".format(e))
            meta = {"status": "failed", "error": str(e)}
        finally:
            self._remove_tmp_files(inpath)
        return meta