import copy

from tqdm import trange

import flgo.algorithm.fedbase as fab


class Server(fab.BasicServer):
    def initialize(self, *args, **kwargs):
        # algorithm hyperparameters: the validation metric used to select the
        # best checkpoint, and whether larger values of that metric are better
        self.init_algo_para({'tune_key': 'loss', 'larger_is_better': False})

    def run(self):
        self.gv.logger.time_start('Total Time Cost')
        # evaluate the initial model performance
        self.gv.logger.info("--------------Initial Evaluation--------------")
        self.gv.logger.time_start('Eval Time Cost')
        self.gv.logger.log_once()
        self.gv.logger.time_end('Eval Time Cost')
        # standalone training: each client finetunes its own model locally,
        # with no communication or aggregation between parties
        for c in self.clients:
            c.finetune()
        self.gv.logger.info("=================End==================")
        self.gv.logger.time_end('Total Time Cost')
        # evaluate the final model performance
        self.gv.logger.time_start('Eval Time Cost')
        self.gv.logger.log_once()
        self.gv.logger.time_end('Eval Time Cost')
        # save results as a .json file
        self.gv.logger.save_output_as_json()
        return


class Client(fab.BasicClient):
    def initialize(self, *args, **kwargs):
        # each client keeps its own private copy of the initial global model
        self.model = copy.deepcopy(self.server.model)

    def finetune(self):
        dataloader = self.calculator.get_dataloader(self.train_data, batch_size=self.batch_size)
        optimizer = self.calculator.get_optimizer(self.model, lr=self.learning_rate,
                                                  weight_decay=self.weight_decay, momentum=self.momentum)
        epoch_iter = trange(self.num_epochs + 1)
        # track the best checkpoint on the validation set; epoch 0 evaluates
        # the untrained model, which serves as the baseline
        op_model_dict = copy.deepcopy(self.model.state_dict())
        op_met = self.test(self.model, 'val')
        op_epoch = 0
        for e in epoch_iter:
            if e > 0:
                # train for one epoch over the local training data
                self.model.train()
                for batch_data in dataloader:
                    loss = self.calculator.compute_loss(self.model, batch_data)['loss']
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            # keep the checkpoint whose tune_key metric is best on validation
            val_metric = self.test(self.model, 'val')
            if (self.larger_is_better and val_metric[self.tune_key] > op_met[self.tune_key]) or \
                    ((not self.larger_is_better) and val_metric[self.tune_key] < op_met[self.tune_key]):
                op_met = val_metric
                op_model_dict = copy.deepcopy(self.model.state_dict())
                op_epoch = e
            epoch_iter.set_description(
                "best {} = {:.4f} @ epoch {}".format(self.tune_key, op_met[self.tune_key], op_epoch))
        # restore the best checkpoint found during finetuning
        self.model.load_state_dict(op_model_dict)
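

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming FLGo's quick-start workflow (flgo.gen_task /
# flgo.init). This is not part of the algorithm above: the task path
# './my_task', the benchmark, the partitioner settings, and the option values
# are hypothetical placeholders to adjust for an actual experiment.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import sys
    import flgo

    task = './my_task'  # hypothetical path where the federated task is generated
    config = {
        'benchmark': {'name': 'flgo.benchmark.mnist_classification'},
        'partitioner': {'name': 'IIDPartitioner', 'para': {'num_clients': 10}},
    }
    flgo.gen_task(config, task_path=task)
    # pass this module itself as the algorithm, so FLGo picks up the
    # Server/Client classes defined above
    runner = flgo.init(task, sys.modules[__name__],
                       option={'num_epochs': 5, 'learning_rate': 0.01, 'batch_size': 64})
    runner.run()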