李沐《动手学深度学习》pycharm运行目标检测数据集报错问题
我在pycharm运行目标检测数据集的时候出现以下报错:1.AttributeError: '_OpNamespace' 'image' object has no attribute 'read_file'2.AttributeError: float
这篇文章是我在学习李沐《动手学深度学习》pytorch版“目标检测数据集”的笔记。
我在pycharm运行目标检测数据集的时候出现以下报错:
1.AttributeError: '_OpNamespace' 'image' object has no attribute 'read_file'
2.AttributeError: float
原因及解决办法如下:
1.AttributeError: '_OpNamespace' 'image' object has no attribute 'read_file'
原因:由于PyTorch版本不匹配导致的,torchvision.io.read_image函数在PyTorch 1.10及以上版本中才可用,而我当前使用的是PyTorch 1.9版本。因此,需要更新PyTorch版本以解决此问题。但是更新PyTorch有可能会带来其他的版本不兼容的问题,建议采用其他方法。不升级PyTorch,可以使用Pillow库(已安装)来代替torchvision.io.read_image函数。
解决办法:
原码:
for img_name, target in csv_data.iterrows():
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')))
修改:
for img_name, target in csv_data.iterrows():
img_path = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}')
image = Image.open(img_path)
images.append(image)
2.AttributeError: float
原因:问题出在BananasDataset类的__getitem__方法中。在该方法中,尝试将图像数据转换为浮点型时引发了错误。这是因为PIL库的图像对象不能直接转换为浮点型张量。
要解决这个问题,可以在__getitem__方法中修改代码,将图像数据转换为张量时使用torchvision.transforms.ToTensor()转换器。这样可以确保正确地将图像转换为张量类型。
解决办法:
原码:
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
def __len__(self):
return len(self.features)
修改:
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
image = self.features[idx]
image = torchvision.transforms.ToTensor()(image)
label = self.labels[idx]
return (image, label)
def __len__(self):
return len(self.features)
以下是我的修改后并成功运行的代码:
import os
import pandas as pd
import torch
import torchvision
import matplotlib.pyplot as plt
from d2l import torch as d2l
from PIL import Image
#@save
d2l.DATA_HUB['banana-detection'] = (
d2l.DATA_URL + 'banana-detection.zip',
'5de26c8fce5ccdea9f91267273464dc968d20d72')
#@save
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
data_dir = d2l.download_extract('banana-detection')
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train
else 'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
img_path = os.path.join(data_dir, 'bananas_train' if is_train else 'bananas_val', 'images', f'{img_name}')
image = Image.open(img_path)
images.append(image)
# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
# 其中所有图像都具有相同的香蕉类(索引为0)
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1) / 256
#@save
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
def __getitem__(self, idx):
image = self.features[idx]
image = torchvision.transforms.ToTensor()(image)
label = self.labels[idx]
return (image, label)
def __len__(self):
return len(self.features)
#@save
def load_data_bananas(batch_size):
"""加载香蕉检测数据集"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
batch_size, shuffle=True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
batch_size)
return train_iter, val_iter
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])
plt.show()
运行结果:
read 1000 training examples
read 100 validation examples

更多推荐




所有评论(0)