import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from scheduler import DDPMScheduler

class DiffusionModule(nn.Module):
    """Base class pairing a network with a ``DDPMScheduler``.

    Subclasses implement ``get_loss`` and ``sample``. This base class only
    stores the two components and provides checkpoint save/load that keeps
    both the pickled ``network``/``scheduler`` objects and this module's
    ``state_dict``.
    """

    # The annotation is quoted so it is not evaluated at definition time.
    def __init__(self, network, scheduler: "DDPMScheduler"):
        """Store the network and scheduler.

        Args:
            network: Model used by subclasses. When it is an ``nn.Module`` it
                is registered as a submodule, so its parameters appear in
                ``state_dict()``.
            scheduler: Scheduler object used by subclasses.
        """
        super().__init__()
        self.network = network
        self.scheduler = scheduler

    def get_loss(self, x0, noise=None):
        """Compute the training loss for ``x0``. Must be overridden.

        ``noise`` defaults to ``None``; presumably it lets callers supply
        pre-sampled noise instead of drawing it inside the subclass (confirm
        against concrete subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @torch.no_grad()
    def sample(self, batch_size, guidance_scale=7.5):
        """Generate ``batch_size`` samples. Must be overridden.

        Note: ``@torch.no_grad()`` wraps only this base implementation; a
        subclass that overrides ``sample`` must re-apply the decorator if it
        wants gradient tracking disabled. ``guidance_scale`` is presumably a
        guidance weight (TODO confirm with subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def save_model(self, file_path):
        """Write a checkpoint that ``load_model`` can restore.

        The saved object is
        ``{"params": {"model": network, "scheduler": scheduler},
        "state": state_dict()}``. ``network`` and ``scheduler`` are pickled
        as whole Python objects, so their classes must be importable at
        load time.

        Args:
            file_path: Destination passed straight to ``torch.save``.
        """
        params = {"model": self.network, "scheduler": self.scheduler}
        torch.save({"params": params, "state": self.state_dict()}, file_path)

    def load_model(self, file_path, map_location="cpu"):
        """Restore network, scheduler and weights from a ``save_model`` file.

        The saved ``network`` and ``scheduler`` objects replace the current
        ones, then the saved ``state_dict`` is loaded (strict) on top.

        Args:
            file_path: Source passed straight to ``torch.load``.
            map_location: Forwarded to ``torch.load``. It defaults to
                ``"cpu"``, so the restored components live on CPU unless
                another location is requested.

        Warning:
            The checkpoint contains pickled Python objects and must be loaded
            with ``weights_only=False``, which can execute arbitrary code from
            the file. Only load checkpoints from trusted sources.
        """
        # weights_only=False is required: PyTorch >= 2.6 defaults to
        # weights_only=True, which refuses to unpickle the nn.Module and
        # scheduler objects stored under "params". The kwarg exists since
        # torch 1.13.
        data = torch.load(file_path, map_location=map_location, weights_only=False)
        params = data["params"]
        state = data["state"]
        self.network = params["model"]
        self.scheduler = params["scheduler"]
        self.load_state_dict(state)