Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from datasets import load_dataset | |
| from huggingface_hub import Repository | |
| from huggingface_hub import HfApi, HfFolder, Repository, create_repo | |
| import os | |
| token = os.getenv('NEW_TOKEN') | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| from small_256_model import UNet as small_UNet | |
| from big_1024_model import UNet as big_UNet | |
# Device configuration: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# The large model variant is selected whenever the device is not the CPU.
big = device.type != 'cpu'

# Parameters: image size and batch size depend on the selected variant.
if big:
    IMG_SIZE, BATCH_SIZE = 1024, 16
else:
    IMG_SIZE, BATCH_SIZE = 256, 4
EPOCHS = 12
LR = 0.0002

# Hub identifiers of the training dataset and of the model repository.
dataset_id = "K00B404/pix2pix_flux_set"
model_repo_id = "K00B404/pix2pix_flux"
| # Create dataset and dataloader | |
class Pix2PixDataset(torch.utils.data.Dataset):
    """Paired image dataset built from a label-tagged split.

    Rows of ``ds["train"]`` whose ``label`` is 0 are kept as "original"
    images and rows whose ``label`` is 1 as "target" images. The i-th
    original is paired with the i-th target, so pairing depends on both
    groups appearing in matching order within the split.
    NOTE(review): confirm the source dataset guarantees that ordering.
    """

    def __init__(self, ds, transform):
        """Split the rows by label and remember the transform.

        Args:
            ds: Mapping with a ``"train"`` entry that iterates over rows;
                each row is indexed with ``'label'`` and ``'image'``.
            transform: Callable applied to every RGB-converted image.

        Raises:
            ValueError: If the number of label-0 and label-1 rows differ.
        """
        originals, targets = [], []
        # Single pass over the split instead of one full iteration per label.
        for row in ds["train"]:
            label = row['label']
            if label == 0:
                originals.append(row)
            elif label == 1:
                targets.append(row)

        # Explicit check rather than ``assert`` so it still runs under ``-O``.
        if len(originals) != len(targets):
            raise ValueError("Mismatch in number of original and target images.")

        self.originals = originals
        self.targets = targets

        # Debug: print dataset size
        print(f"Number of original images: {len(self.originals)}")
        print(f"Number of target images: {len(self.targets)}")

        self.transform = transform

    def __len__(self):
        """Return the number of (original, target) pairs."""
        return len(self.originals)

    def __getitem__(self, idx):
        """Return the transformed ``(original, target)`` pair at ``idx``."""
        original = self.originals[idx]['image'].convert('RGB')
        target = self.targets[idx]['image'].convert('RGB')
        return self.transform(original), self.transform(target)
class UNetWrapper:
    """Pair a UNet model with the Hub repository it is uploaded to.

    The access token is read from the ``NEW_TOKEN`` environment variable.
    """

    def __init__(self, unet_model, repo_id):
        """Store the model and repository id and create the Hub API client.

        Args:
            unet_model: Model whose ``state_dict()`` will be uploaded.
            repo_id: Identifier of the destination model repository.
        """
        self.model = unet_model
        self.repo_id = repo_id
        self.token = os.getenv('NEW_TOKEN')  # Make sure this environment variable is set
        self.api = HfApi(token=self.token)

    @staticmethod
    def _render_model_card(is_big, architecture):
        """Build the README (model card) text for the uploaded checkpoint.

        Args:
            is_big: True when the model is the 1024-pixel variant.
            architecture: Printable description of the model.

        Returns:
            Markdown text beginning with the YAML metadata header.
        """
        fence = "```"
        card_lines = [
            "---",
            "tags:",
            "- unet",
            "- pix2pix",
            "library_name: pytorch",
            "---",
            "",
            "# Pix2Pix UNet Model",
            "",
            "## Model Description",
            "",
            "Custom UNet model for Pix2Pix image translation.",
            "",
            f"- Image Size: {1024 if is_big else 256}",
            f"- Model Type: {'Big (1024)' if is_big else 'Small (256)'}",
            "",
            "## Usage",
            "",
            f"{fence}python",
            "import torch",
            "from small_256_model import UNet as small_UNet",
            "from big_1024_model import UNet as big_UNet",
            "",
            "# Load the model",
            "checkpoint = torch.load('model_weights.pth', map_location='cpu')",
            "model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()",
            "model.load_state_dict(checkpoint['model_state_dict'])",
            "model.eval()",
            # Close the usage fence so the sections below render as markdown.
            fence,
            "",
            "## Model Architecture",
            "",
            fence,
            architecture,
            fence,
        ]
        return "\n".join(card_lines) + "\n"

    def push_to_hub(self):
        """Upload the checkpoint and a model card to ``self.repo_id``.

        The checkpoint is a dict with ``model_state_dict``, ``model_config``
        (``big`` and ``img_size``) and ``model_architecture``, uploaded as
        ``model_weights.pth``. Failures are reported with ``print`` rather
        than raised.

        Returns:
            True when both uploads completed, False otherwise.
        """
        import tempfile

        pth_name = 'model_weights.pth'
        try:
            is_big = isinstance(self.model, big_UNet)
            architecture = str(self.model)

            # Save model state and configuration
            save_dict = {
                'model_state_dict': self.model.state_dict(),
                'model_config': {
                    'big': is_big,
                    'img_size': 1024 if is_big else 256,
                },
                'model_architecture': architecture,
            }

            # Create repo if it doesn't exist; a failure here is only noted
            # because the uploads below will report any real problem.
            try:
                create_repo(
                    repo_id=self.repo_id,
                    token=self.token,
                    exist_ok=True,
                )
            except Exception as e:
                print(f"Repository creation note: {e}")

            # Write the checkpoint into a private temporary directory so it
            # is removed even when the upload fails and never collides with
            # files in the working directory.
            with tempfile.TemporaryDirectory() as workdir:
                local_path = os.path.join(workdir, pth_name)
                torch.save(save_dict, local_path)
                self.api.upload_file(
                    path_or_fileobj=local_path,
                    path_in_repo=pth_name,
                    repo_id=self.repo_id,
                    token=self.token,
                    repo_type="model",
                )

            # Upload the model card from memory; writing README.md to the
            # working directory and deleting it afterwards would destroy any
            # README that already lives there.
            model_card = self._render_model_card(is_big, architecture)
            self.api.upload_file(
                path_or_fileobj=model_card.encode("utf-8"),
                path_in_repo="README.md",
                repo_id=self.repo_id,
                token=self.token,
                repo_type="model",
            )

            print(f"Model successfully uploaded to {self.repo_id}")
            return True
        except Exception as e:
            print(f"Error uploading model: {e}")
            return False
| # Training function | |
def _load_or_create_model():
    """Return a UNet on ``device``, resuming from the Hub checkpoint if possible.

    Tries to download ``model_weights.pth`` from ``model_repo_id`` and load
    its ``model_state_dict``. Any failure (missing repo or file, variant
    mismatch, incompatible weights) falls back to a freshly initialised model.
    """
    def new_model():
        return big_UNet() if big else small_UNet()

    model = new_model()
    try:
        from huggingface_hub import hf_hub_download

        weights_path = hf_hub_download(
            repo_id=model_repo_id,
            filename='model_weights.pth',
            token=os.getenv('NEW_TOKEN'),
        )
        # NOTE(security): torch.load unpickles the file; weights_only=True
        # restricts it to tensors and plain containers.
        checkpoint = torch.load(weights_path, map_location='cpu', weights_only=True)
        if bool(checkpoint['model_config']['big']) != big:
            raise ValueError("checkpoint was saved for the other model variant")
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Resumed from checkpoint in {model_repo_id}")
    except Exception as e:
        print(f"Starting from a new model: {e}")
        # Rebuild so a partially applied state dict cannot leak through.
        model = new_model()
    return model.to(device)


def train_model(epochs):
    """Train the UNet on the paired dataset and return it with a text log.

    Each batch yields ``(original, target)``; the model receives ``target``
    and is optimised with L1 loss so that its output matches ``original``.

    Args:
        epochs: Number of passes over the dataloader.

    Returns:
        Tuple ``(model, log)`` where ``log`` joins the status lines recorded
        every 10th step with newlines.
    """
    # Load the dataset
    ds = load_dataset(dataset_id)
    print(f"ds{ds}")

    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    dataset = Pix2PixDataset(ds, transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize model, loss function, and optimizer
    model = _load_or_create_model()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=LR)

    output_text = []
    # Training loop
    for epoch in range(epochs):
        model.train()
        for i, (original, target) in enumerate(dataloader):
            original, target = original.to(device), target.to(device)
            optimizer.zero_grad()

            # Forward pass: generate from the target image, compare with original
            output = model(target)
            loss = criterion(output, original)

            # Backward pass
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.8f}"
                print(status)
                output_text.append(status)

    return model, "\n".join(output_text)
| # Push model to Hugging Face Hub | |
def push_model_to_hub(model, repo_id):
    """Upload ``model`` to the Hub repository ``repo_id`` through ``UNetWrapper``."""
    UNetWrapper(model, repo_id).push_to_hub()
| # Push the model to the Hugging Face hub | |
| #model.push_to_hub(repo_name) | |
| # Gradio interface function | |
def gradio_train(epochs):
    """Train for ``epochs`` epochs, push the model, and return a text report.

    ``epochs`` is converted with ``int()`` before training. The returned text
    is the training log followed by a summary sentence; the summary is
    produced whether or not the upload step reported an error.
    """
    epoch_count = int(epochs)
    trained_model, log_text = train_model(epoch_count)
    push_model_to_hub(trained_model, model_repo_id)
    summary = (
        f"Model trained for {epochs} epochs on the {dataset_id} dataset "
        f"and pushed to Hugging Face Hub {model_repo_id} repository."
    )
    return f"{log_text}\n\n{summary}"
| # Gradio Interface | |
# Gradio UI: a numeric epoch count goes in, the training report comes out.
_epochs_input = gr.Number(label="Number of Epochs")
_progress_output = gr.Textbox(label="Training Progress", lines=10)
gr_interface = gr.Interface(
    gradio_train,
    _epochs_input,
    _progress_output,
    title="Pix2Pix Model Training",
    description="Train the Pix2Pix model and push it to the Hugging Face Hub repository.",
)
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script.
    gr_interface.launch()