CV icon indicating copy to clipboard operation
CV copied to clipboard

作者您好,为什么分类预测结果一直为daisy

Open ZZHOO1 opened this issue 1 year ago • 5 comments

ZZHOO1 avatar Jul 16 '24 09:07 ZZHOO1

Uploading predict.jpg…

ZZHOO1 avatar Jul 16 '24 09:07 ZZHOO1

权重载入了吗

codecat0 avatar Jul 17 '24 14:07 codecat0

我检查了代码,问题出在加载数据时:加载的不是全部数据集,而只是第一次循环(第一个类别)的数据。

ZZHOO1 avatar Jul 17 '24 14:07 ZZHOO1

# Note from the issue thread: the dataset split was wrong because the `return`
# statement of read_split_data sat at the end of the first class-loop
# iteration, so the function processed only the first class and returned
# early. The summary prints, the optional plot and the `return` below are
# placed after the loop so that the split covers every class.
def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Split an image-folder dataset into training and validation lists.

    Every sub-directory of ``root`` is treated as one class. The class names
    are sorted, numbered from 0, and the index -> name mapping is written to
    ``class_indices.json`` in the current working directory.

    Args:
        root: Dataset directory containing one sub-directory per class.
        val_rate: Fraction of each class's images sampled into the validation
            set (``int(len(images) * val_rate)`` images per class).
        plot_image: When true, show a bar chart of the per-class image counts.

    Returns:
        A 4-tuple ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` of parallel lists: image paths and their class
        indices for the training and validation sets.

    Raises:
        AssertionError: If ``root`` does not exist.
    """
    # Seed the global RNG so the sampled validation set is repeatable.
    random.seed(0)
    assert os.path.exists(root), f'dataset root {root} does not exist.'

    # One folder per class; sorted so the class -> index mapping is stable.
    flower_classes = sorted(
        cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))
    )

    # Encode class names as integer indices and persist the reverse mapping.
    class_indices = {name: index for index, name in enumerate(flower_classes)}
    json_str = json.dumps(
        {index: name for name, index in class_indices.items()}, indent=4
    )
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    # Paths and class indices of all training images.
    train_images_path, train_images_label = [], []

    # Paths and class indices of all validation images.
    val_images_path, val_images_label = [], []

    # Number of images found for each class, in class order.
    every_class_num = []

    # Accepted image file extensions (matched exactly, case-sensitive).
    images_format = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in flower_classes:
        cla_path = os.path.join(root, cla)

        # os.listdir returns entries in arbitrary order, so sort them:
        # otherwise the seeded sample below is not reproducible.
        images = sorted(
            os.path.join(cla_path, i) for i in os.listdir(cla_path)
            if os.path.splitext(i)[-1] in images_format
        )

        image_class = class_indices[cla]
        every_class_num.append(len(images))

        # Randomly pick the validation images for this class; a set keeps the
        # membership test in the loop below O(1) per image.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    # Everything below runs once, after all classes have been processed.
    print(f"{sum(every_class_num)} images found in dataset.")
    print(f"{len(train_images_path)} images for training.")
    print(f"{len(val_images_path)} images for validation.")

    if plot_image:
        plt.bar(range(len(flower_classes)), every_class_num, align='center')
        plt.xticks(range(len(flower_classes)), flower_classes)
        # Label each bar with its count, slightly above the bar top.
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

ZZHOO1 avatar Jul 17 '24 14:07 ZZHOO1

就是把 data_utils.py 里的 77-91 行改一下缩进就好了:

print(f"{sum(every_class_num)} images found in dataset.")
print(f"{len(train_images_path)} images for training.")
print(f"{len(val_images_path)} images for validation.")

if plot_image:
    plt.bar(range(len(flower_classes)), every_class_num, align='center')
    plt.xticks(range(len(flower_classes)), flower_classes)
    for i, v in enumerate(every_class_num):
        plt.text(x=i, y=v + 5, s=str(v), ha='center')
    plt.xlabel('image class')
    plt.ylabel('number of images')
    plt.title('flower class distribution')
    plt.show()

return train_images_path, train_images_label, val_images_path, val_images_label

yuguoliua avatar Nov 26 '24 12:11 yuguoliua