Source code for fedn.utils.helpers.plugins.splitlearninghelper

import os
import tempfile

import numpy as np

# import torch
from fedn.common.log_config import logger
from fedn.utils.helpers.helperbase import HelperBase


class Helper(HelperBase):
    """FEDn helper class for models weights/parameters that can be transformed to numpy ndarrays."""

    def __init__(self):
        """Initialize helper."""
        super().__init__()
        self.name = "splitlearninghelper"

    def increment_average(self, embedding1, embedding2):
        """Concatenate two embeddings of format {client_id: embedding} into a new dictionary.

        Neither input dictionary is modified. If the same key occurs in both
        inputs, the value from ``embedding2`` is the one kept.

        :param embedding1: First embedding dictionary
        :param embedding2: Second embedding dictionary
        :return: Concatenated embedding dictionary
        """
        return {**embedding1, **embedding2}

    def save(self, data_dict, path=None, file_type="npz"):
        """Save a dictionary of array-like values as a compressed ``.npz`` archive.

        Keys are converted with ``str()`` and values with ``np.array()`` before
        they are written.

        :param data_dict: Dict whose values can be converted to numpy arrays.
        :param path: Destination path. When falsy, a temporary path obtained
            from :meth:`get_tmp_path` is used.
        :param file_type: Not consulted by this method; the output is always
            written with ``np.savez_compressed``.
        :return: Path the data was written to.
        """
        if not path:
            path = self.get_tmp_path()

        logger.info("SPLIT LEARNING HELPER: Saving data to {}".format(path))

        # Ensure all values are numpy arrays
        processed_dict = {str(k): np.array(v) for k, v in data_dict.items()}

        with open(path, "wb") as f:
            np.savez_compressed(f, **processed_dict)

        return path

    def load(self, path):
        """Load embeddings/gradients.

        :param path: Path to file
        :return: Dict mapping client IDs to numpy arrays (either embeddings or gradients)
        :raises Exception: Any error raised while reading the file is logged
            and then re-raised unchanged.
        """
        try:
            # np.load returns an NpzFile that keeps the archive open; use it as
            # a context manager so the file handle is released once every
            # array has been read into memory.
            with np.load(path) as data:
                logger.info("SPLIT LEARNING HELPER: loaded data from {}".format(path))
                result_dict = {k: data[k] for k in data.files}
            return result_dict
        except Exception as e:
            logger.error(f"Error in splitlearninghelper: loading data from {path}: {str(e)}")
            raise

    def get_tmp_path(self, suffix=".npz"):
        """Return a temporary output path compatible with :meth:`save` and :meth:`load`.

        The file is created on disk by ``tempfile.mkstemp`` and its descriptor
        is closed immediately; it is not deleted by this method.

        :param suffix: File suffix.
        :return: Path to file.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    def check_supported_file_type(self, file_type):
        """Check if the file type is supported.

        :param file_type: File type to check.
        :type file_type: str
        :return: True if supported.
        :rtype: bool
        :raises ValueError: If ``file_type`` is not one of the supported types.
        """
        supported_file_types = ["npz", "raw_binary"]
        if file_type not in supported_file_types:
            raise ValueError("File type not supported. Supported types are: {}".format(supported_file_types))
        return True