import ast
import json
import linecache
import linecache as _lc
import traceback
from concurrent import futures
import grpc
import fedn.network.grpc.fedn_pb2 as fedn
import fedn.network.grpc.fedn_pb2_grpc as rpc
from fedn.common.log_config import logger
# imports for user defined code
from fedn.network.combiner.hooks.allowed_import import * # noqa: F403
from fedn.network.combiner.hooks.allowed_import import ServerFunctionsBase
from fedn.network.combiner.hooks.grpc_wrappers import safe_streaming, safe_unary
from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model
from fedn.utils.helpers.plugins.numpyhelper import Helper
# One mebibyte, in bytes. Not referenced in the visible part of this module;
# presumably the chunk size used when streaming model bytes -- TODO confirm.
CHUNK_SIZE = 1024**2
# Matches strings made up solely of ASCII letters, digits, "_" and "-"
# (the empty string matches as well, because of the "*" quantifier).
VALID_NAME_REGEX = r"^[a-zA-Z0-9_-]*$"
class FunctionServiceServicer(rpc.FunctionServiceServicer):
    """Function service running in an environment combined with each combiner.

    Receives requests from the combiner and forwards them to user-defined
    server functions: a ``ServerFunctions`` class compiled from the code string
    supplied through :meth:`HandleProvidedFunctions`.

    NOTE(review): ``client_updates`` is updated with non-atomic
    read-modify-write operations while ``serve()`` runs this servicer on a
    thread pool. Whether a model and its metadata for the same client can
    arrive concurrently is not visible here -- confirm before relying on it.
    """

    # Names of the user-overridable hooks whose availability is reported back.
    _HOOK_NAMES = ("client_selection", "client_settings", "aggregate", "incremental_aggregate")

    def __init__(self) -> None:
        """Initialize long-running Function server."""
        super().__init__()
        self.helper = Helper()
        self.server_functions: ServerFunctionsBase = None
        self.server_functions_code: str = None
        # Maps client id -> [model, client_metadata] (in that order once both arrived).
        self.client_updates = {}
        # Maps hook name -> bool; empty until HandleProvidedFunctions has run.
        self.implemented_functions = {}
        # Previous global model; stays None until HandleStoreModel receives the
        # "global_model" id. Initialised here so that readers never hit an
        # AttributeError when no global model has been stored yet.
        self.previous_global = None
        # Synthetic filename of the currently compiled user code (see
        # _instansiate_server_functions_code); None until code was provided.
        self._server_code_filename = None
        logger.info("Server Functions initialized.")

    @safe_unary("client_settings", lambda: fedn.ClientConfigResponse(client_settings=json.dumps({})))
    def HandleClientConfig(self, request_iterator: fedn.ClientConfigRequest, context):
        """Distribute client configs to clients from user defined code.

        :param request_iterator: the client config request
        :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ClientConfigRequest`
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: the client config response (user settings serialised as JSON)
        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientConfigResponse`
        """
        logger.info("Received client config request.")
        model, _ = unpack_model(request_iterator, self.helper)
        client_settings = self.server_functions.client_settings(global_model=model)
        logger.info(f"Client config response: {client_settings}")
        return fedn.ClientConfigResponse(client_settings=json.dumps(client_settings))

    @safe_unary("client_selection", lambda: fedn.ClientSelectionResponse(client_ids=json.dumps([])))
    def HandleClientSelection(self, request: fedn.ClientSelectionRequest, context):
        """Handle client selection from user defined code.

        :param request: the client selection request (``client_ids`` is a JSON string)
        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ClientSelectionRequest`
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: the client selection response (selected ids serialised as JSON)
        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientSelectionResponse`
        """
        logger.info("Received client selection request.")
        candidate_ids = json.loads(request.client_ids)
        selected_ids = self.server_functions.client_selection(candidate_ids)
        logger.info(f"Clients selected: {selected_ids}")
        return fedn.ClientSelectionResponse(client_ids=json.dumps(selected_ids))

    @safe_unary("store_metadata", lambda: fedn.ClientMetaResponse(status="ERROR"))
    def HandleMetadata(self, request: fedn.ClientMetaRequest, context):
        """Store client metadata from a request.

        :param request: the client meta request (``metadata`` is a JSON string)
        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ClientMetaRequest`
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: the client meta response
        :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientMetaResponse`
        """
        logger.info("Received metadata")
        client_id = request.client_id
        metadata = json.loads(request.metadata)
        # Metadata is appended so that the entry reads [model, client_metadata]
        # once the model (which is prepended) has arrived as well.
        self.client_updates[client_id] = self.client_updates.get(client_id, []) + [metadata]
        self.check_incremental_aggregate(client_id)
        return fedn.ClientMetaResponse(status="Metadata stored")

    @safe_unary("store_model", lambda: fedn.StoreModelResponse(status="ERROR"))
    def HandleStoreModel(self, request_iterator, context):
        """Store a streamed model: the previous global model or a client model.

        The ``id`` of the final request returned by ``unpack_model`` decides
        where the model goes: ``"global_model"`` replaces
        ``self.previous_global``; any other id is treated as a client id.

        :param request_iterator: the streamed model requests
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: the store model response
        :rtype: :class:`fedn.network.grpc.fedn_pb2.StoreModelResponse`
        """
        model, final_request = unpack_model(request_iterator, self.helper)
        client_id = final_request.id
        if client_id == "global_model":
            logger.info("Received previous global model")
            self.previous_global = model
        else:
            logger.info(f"Received client model from client {client_id}")
            # The model is prepended so that the entry reads
            # [model, client_metadata] regardless of arrival order.
            self.client_updates[client_id] = [model] + self.client_updates.get(client_id, [])
        self.check_incremental_aggregate(client_id)
        return fedn.StoreModelResponse(status=f"Received model originating from {client_id}")

    def check_incremental_aggregate(self, client_id):
        """Run incremental aggregation for a client once its update is complete.

        Nothing happens for the ``"global_model"`` id, while the client's entry
        does not yet hold exactly two items, or when the user code does not
        provide ``incremental_aggregate``. Otherwise the user hook is called
        with the model, the metadata and the previous global model (``None`` if
        none was stored yet), and the client's entry is dropped so the model is
        not kept in memory.

        :param client_id: id of the client whose update was just stored
        :raises KeyError: if ``client_id`` has no entry in ``client_updates``
        """
        if client_id == "global_model":
            return
        pending = self.client_updates[client_id]
        # .get(): implemented_functions is empty until HandleProvidedFunctions ran.
        if len(pending) != 2 or not self.implemented_functions.get("incremental_aggregate", False):
            return
        client_model, client_metadata = pending
        self.server_functions.incremental_aggregate(client_id, client_model, client_metadata, self.previous_global)
        del self.client_updates[client_id]

    @safe_streaming("aggregate")
    def HandleAggregation(self, request, context):
        """Aggregate with user-defined code and stream back the resulting model.

        If the user code provides ``incremental_aggregate`` the already
        accumulated model is fetched with ``get_incremental_aggregate_model``;
        otherwise ``aggregate`` is called with the previous global model and
        all stored client updates. Stored client updates are cleared afterwards.

        :param request: the aggregation request
        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.AggregationRequest`
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: generator of aggregation responses carrying the aggregated model
        :rtype: :class:`fedn.network.grpc.fedn_pb2.AggregationResponse`
        """
        logger.info(f"Receieved aggregation request: {request.aggregate}")
        # .get(): implemented_functions is empty until HandleProvidedFunctions ran.
        if self.implemented_functions.get("incremental_aggregate", False):
            aggregated_model = self.server_functions.get_incremental_aggregate_model()
        else:
            aggregated_model = self.server_functions.aggregate(self.previous_global, self.client_updates)
        model_bytesIO = model_as_bytesIO(aggregated_model, self.helper)
        self.client_updates = {}
        logger.info("Returning aggregate model.")
        yield from bytesIO_request_generator(mdl=model_bytesIO, request_function=fedn.AggregationResponse, args={})

    def HandleProvidedFunctions(self, request: "fedn.ProvidedFunctionsRequest", context):
        """Handles the 'provided_functions' request. Sends back which functions are available.

        When the submitted code equals the code already loaded, the cached
        availability map is returned unchanged. Otherwise the code is compiled
        and instantiated; on success a hook counts as available when a function
        with its name is defined anywhere in the code, on failure every hook is
        reported as unavailable.

        :param request: the provided function request
        :type request: :class:`fedn.network.grpc.fedn_pb2.fedn.ProvidedFunctionsRequest`
        :param context: the context (unused)
        :type context: :class:`grpc._server._Context`
        :return: dict with str -> bool for which functions are available
        :rtype: :class:`fedn.network.grpc.fedn_pb2.ProvidedFunctionsResponse`
        """
        logger.info("Receieved provided functions request.")
        server_functions_code = request.function_code
        # if no new code return previous
        if server_functions_code == self.server_functions_code:
            logger.info("No new server function code provided.")
            logger.info(f"Provided functions: {self.implemented_functions}")
            return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)

        self.server_functions_code = server_functions_code
        self.implemented_functions = {}
        self._instansiate_server_functions_code()
        # Still empty means instantiation succeeded (a failure fills it with False).
        if not self.implemented_functions:
            # The code compiled above, so parsing it again cannot fail on syntax.
            tree = ast.parse(server_functions_code)
            # Every "def" in the code counts, wherever it is nested.
            defined_funcs = {node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)}
            for func in self._HOOK_NAMES:
                implemented = func in defined_funcs
                if implemented:
                    logger.info(f"Function '{func}' found—assuming it´s implemented.")
                else:
                    logger.info(f"Function '{func}' not found.")
                self.implemented_functions[func] = implemented
        logger.info(f"Provided function: {self.implemented_functions}")
        return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions)

    def _instansiate_server_functions_code(self):
        """Compile ``self.server_functions_code`` and instantiate ``ServerFunctions``.

        On success ``self.server_functions`` holds the new instance. On any
        exception the error is logged, ``self.server_functions`` is reset to
        ``None`` and every hook is marked as not implemented.

        NOTE(security): this executes code received in a gRPC request, and
        ``serve()`` binds an insecure port. Make sure the service is only
        reachable by trusted peers.
        """
        try:
            namespace = {}
            superseded_filename = self._server_code_filename
            # Synthetic filename shown in tracebacks. hash() of a str is only
            # stable within one interpreter process, which is all that is needed
            # to match frames in _retire_and_log.
            self._server_code_filename = f"server_functions:{hash(self.server_functions_code)}"
            code_obj = compile(self.server_functions_code, self._server_code_filename, "exec")
            # Drop the source of the superseded code so the cache does not grow
            # with every code update.
            if superseded_filename is not None:
                linecache.cache.pop(superseded_filename, None)
            # Prime linecache so traceback can show source lines. An mtime of
            # None keeps linecache.checkcache() from discarding the entry.
            linecache.cache[self._server_code_filename] = (
                len(self.server_functions_code),
                None,
                [line + "\n" for line in self.server_functions_code.splitlines()],
                self._server_code_filename,
            )
            # Names defined by the user code land in `namespace`; name lookups
            # from inside user functions resolve against this module's globals.
            exec(code_obj, globals(), namespace)  # noqa: S102
            exec("server_functions = ServerFunctions()", globals(), namespace)  # noqa: S102
            self.server_functions = namespace.get("server_functions")
        except Exception as e:
            logger.error(f"Exec failed: {e}")
            self.server_functions = None
            self.implemented_functions = dict.fromkeys(self._HOOK_NAMES, False)

    def _retire_and_log(self, func_name: str, err: Exception):
        """Mark a user hook as unavailable and log where it failed.

        Not called from within this class; presumably invoked by the
        ``safe_unary`` / ``safe_streaming`` wrappers -- TODO confirm.

        :param func_name: name of the hook that raised
        :param err: the exception raised by the hook
        """
        # retire the function immediately
        if func_name in self.implemented_functions:
            self.implemented_functions[func_name] = False

        # keep only frames that originate from the compiled user code
        filename = self._server_code_filename
        user_frames = [frame for frame in traceback.extract_tb(err.__traceback__) if filename and frame.filename == filename]
        if user_frames:
            # deepest frame in user code (where it actually failed)
            frame = user_frames[-1]
            # fetch the source line from linecache (primed when the code was compiled)
            src_line = (linecache.getline(frame.filename, frame.lineno) or "").rstrip("\n")
            logger.error(f"User function '{func_name}' crashed at {frame.filename}:{frame.lineno} in {frame.name}()\n> {src_line}\nException: {err!r}")
        else:
            # fallback: full traceback (server + user frames) if we didn't match a user frame.
            # exc_info=err attaches this exception's traceback even when we are
            # not running inside the `except` block that caught it.
            logger.error(f"{func_name} failed, retiring until next code update: {err}", exc_info=err)
def serve(port: int = 12081, max_workers: int = 100):
    """Start the hooks service.

    Creates a gRPC server, registers a :class:`FunctionServiceServicer`, binds
    an insecure (unencrypted) port on all interfaces and blocks until the
    server terminates.

    :param port: TCP port to listen on; defaults to 12081
    :type port: int
    :param max_workers: size of the thread pool serving requests; increase based on expected load
    :type max_workers: int
    """
    # Keepalive settings: these detect if the client is alive
    keepalive_time_ms = 5 * 60 * 1000  # send keepalive ping every 5 minutes
    keepalive_timeout_ms = 20 * 1000  # wait 20 seconds for keepalive ping ack before considering connection dead
    max_connection_idle_ms = 5 * 60 * 1000  # max idle time before server terminates the connection (5 minutes)
    max_message_length = 1 * 1024 * 1024 * 1024  # 1 GB in bytes

    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        options=[
            ("grpc.keepalive_time_ms", keepalive_time_ms),
            ("grpc.keepalive_timeout_ms", keepalive_timeout_ms),
            ("grpc.max_connection_idle_ms", max_connection_idle_ms),
            ("grpc.max_send_message_length", max_message_length),
            # -1 lifts the limit on the size of received messages
            ("grpc.max_receive_message_length", -1),
        ],
    )
    rpc.add_FunctionServiceServicer_to_server(FunctionServiceServicer(), server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    server.wait_for_termination()