VGG16 训练猫狗数据集
一、准备数据集
1. 准备数据
准备数据应该是一件比较麻烦的过程,所以一般都去找那种公开的数据集。在网上找到的可以用于猫狗分类的数据集有 Kaggle 的 “Dogs vs. Cats”数据集,还有牛津大学提供的 Oxford-IIIT Pet 数据集,包含猫和狗的图片,都是非常适合做猫狗分类任务的公开数据集。
这里我就选择 Kaggle 中的 Cat VS Dog 数据集,在 Kaggle 中搜一下就搜到了
Kaggle 提供的原始数据集结构是这样的:
data/
├── Cat/
│ ├── 1.jpg
│ ├── 2.jpg
│ ├── ...
├── Dog/
│ ├── 1.jpg
│ ├── 2.jpg
│ ├── ...
我们下载后传到板子上,然后解压下载的数据集:
解压之后简单移动一下 Dog 和 Cat 目录:
但 PyTorch 的 ImageFolder
需要数据按照 train/
和 val/
分类存放,所以我们要将数据整理成如下格式:
data_split/
├── train/
│ ├── cats/
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ ├── dogs/
│ ├── 1.jpg
│ ├── 2.jpg
├── val/
│ ├── cats/
│ │ ├── 1001.jpg
│ │ ├── 1002.jpg
│ ├── dogs/
│ ├── 1001.jpg
│ ├── 1002.jpg
2. Python 代码:划分数据集
运行以下代码,它会自动 创建 train/
和 val/
目录,并按 80%:20% 的比例划分数据:
import os
import shutil
import random
# 设置路径
original_data_dir = "data" # 你的原始 Cat/ 和 Dog/ 目录所在路径
base_dir = "data_split" # 训练/验证集存放路径
train_dir = os.path.join(base_dir, "train")
val_dir = os.path.join(base_dir, "val")
# 创建 train 和 val 目录
for split in ["train", "val"]:
os.makedirs(os.path.join(train_dir, "cats"), exist_ok=True)
os.makedirs(os.path.join(train_dir, "dogs"), exist_ok=True)
os.makedirs(os.path.join(val_dir, "cats"), exist_ok=True)
os.makedirs(os.path.join(val_dir, "dogs"), exist_ok=True)
# 获取所有猫和狗的图片
all_cats = [f for f in os.listdir(os.path.join(original_data_dir, "Cat")) if f.endswith(".jpg")]
all_dogs = [f for f in os.listdir(os.path.join(original_data_dir, "Dog")) if f.endswith(".jpg")]
# 随机打乱数据集
random.seed(42)
random.shuffle(all_cats)
random.shuffle(all_dogs)
# 计算 80% 训练,20% 验证
train_size = int(0.8 * len(all_cats))
train_cats, val_cats = all_cats[:train_size], all_cats[train_size:]
train_dogs, val_dogs = all_dogs[:train_size], all_dogs[train_size:]
# 复制猫图片到新的目录
for fname in train_cats:
shutil.copy(os.path.join(original_data_dir, "Cat", fname), os.path.join(train_dir, "cats", fname))
for fname in val_cats:
shutil.copy(os.path.join(original_data_dir, "Cat", fname), os.path.join(val_dir, "cats", fname))
# 复制狗图片到新的目录
for fname in train_dogs:
shutil.copy(os.path.join(original_data_dir, "Dog", fname), os.path.join(train_dir, "dogs", fname))
for fname in val_dogs:
shutil.copy(os.path.join(original_data_dir, "Dog", fname), os.path.join(val_dir, "dogs", fname))
print("数据集划分完成!")
发布者:admin,转转请注明出处:http://www.yc00.com/web/1748171971a4741788.html
评论列表(0条)