swanlab
diffusers
transformers
runwayml/stable-diffusion-v1-5
文生图实战:Stable Diffusion模型训练入门教程(完整代码)
环境安装与准备
在开始之前,请确保您的计算机已安装Python 3.8 及以上版本。此外,至少拥有NVIDIA GPU(建议显存至少为 22GB),以支持Stable Diffusion模型的运行**。
接着,执行以下命令来安装必需的Python库:
pip install swanlab diffusers datasets accelerate torchvision transformers
确认您使用的库版本符合以下推荐要求:diffusers - 0.29.0, accelerate - 0.30.1, datasets - 2.18.0, transformers - 4.41.2, swanlab - 使用最新版本以确保获得最佳功能与性能。
数据集准备
我们将使用火影忍者数据集作为训练的基础数据。数据集包含约1200对图像与描述文本,大小约为700MB。自动下载数据集的Python代码如下:
from datasets import load_dataset
naruto_dataset = load_dataset("lambdalabs/naruto-blip-captions")
遇到网络问题时,可以手动从百度网盘下载数据集,并解压至与训练脚本同一目录。
模型加载
runwayml/stable-diffusion-v1-5
from diffusers import StableDiffusionPipeline
model_path = "稳定_diffusion_v1_5文件夹路径"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe = pipe.to("cuda") # 将模型加载至GPU
配置训练可视化工具
为了有效监控训练过程并评估模型效果,我们建议使用SwanLab进行可视化。首先,请注册SwanLab账户并获取API Key。在执行训练脚本前,请将API Key粘贴至SwanLab平台。
训练流程
训练Stable Diffusion模型涉及多参数调整以优化图像生成质量。以下示例命令将启动训练:
python train_sd1-5_naruto.py \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--seed=42 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--output_dir="sd-naruto-model"
参数解释:
- --use_ema:应用指数移动平均技术以提升模型泛化能力。 --resolution=512:设置训练图像分辨率为512像素。 --center_crop:进行中心裁剪,忽略图像边缘。 --random_flip:在训练过程中随机翻转图像,增加数据多样性。 --gradient_accumulation_steps=4:梯度累积步数,每4步累积梯度更新一次。 --gradient_checkpointing:减少内存消耗,加速训练过程。 --max_train_steps=15000:最大训练步数设定为15000步。 --learning_rate=1e-05:学习率设置为1e-05。 --max_grad_norm=1:梯度范数的最大限制为1。 --seed=42:设置随机种子为42,确保训练结果可重复。 --lr_scheduler="constant":使用常数学习率调度器,保持学习率恒定。 --lr_warmup_steps=0:不进行学习率预热。 --output_dir="sd-naruto-model":模型输出目录为“sd-naruto-model”。
模型推理与结果展示
训练完成后,模型将保存在“sd-naruto-model”目录下。使用以下代码进行模型推理:
from diffusers import StableDiffusionPipeline
import torch
model_id = "./sd-naruto-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "Lebron James with a hat"
image = pipe(prompt).images[0]
image.save("result.png")
资源与链接
详尽的实验代码、训练日志和最终生成结果均位于GitHub仓库。SwanLab提供了详细的使用指南和日志可视化功能。火影忍者数据集和Stable Diffusion模型的下载路径位于百度网盘。
总结
通过上述指导,您已掌握了使用Stable Diffusion模型在火影忍者数据集上进行微调并生成高质量图像的整个流程。从环境准备到数据集配置、模型加载、训练可视化工具设置,再到高效的训练执行,以及模型推理生成图像,每个环节都提供了详细的代码示例和操作指南。本篇教程旨在帮助您在文生图领域迅速获得实践经验,期待您以此为起点,探索更多图像生成的可能。