博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Python3读取深度学习CIFAR-10数据集出现的若干问题解决
阅读量:6862 次
发布时间:2019-06-26

本文共 3258 字,大约阅读时间需要 10 分钟。

今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集。当我兴高采烈的运行代码时,却发现了一些错误:

# -*- coding: utf-8 -*-import pickle as pimport numpy as npimport osdef load_CIFAR_batch(filename):""" 载入cifar数据集的一个batch """with open(filename, 'r') as f:datadict = p.load(f)X = datadict['data']Y = datadict['labels']X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")Y = np.array(Y)return X, Ydef load_CIFAR10(ROOT):""" 载入cifar全部数据 """xs = []ys = []for b in range(1, 6):f = os.path.join(ROOT, 'data_batch_%d' % (b,))X, Y = load_CIFAR_batch(f)xs.append(X)ys.append(Y)Xtr = np.concatenate(xs)Ytr = np.concatenate(ys)del X, YXte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))return Xtr, Ytr, Xte, Yte复制代码

错误代码如下:

'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence复制代码
于是乎开始各种搜索问题,问大佬,网上的答案都是类似:

然而并没有解决问题!还是错误的!(我大概搜索了一下午吧,都是上面的答案)

哇,就当我很绝望的时候,我终于发现了一个新奇的答案,抱着试一试的态度,尝试了一下:

def load_CIFAR_batch(filename):""" 载入cifar数据集的一个batch """with open(filename, 'rb') as f:datadict = p.load(f, encoding='latin1')X = datadict['data']Y = datadict['labels']X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")Y = np.array(Y)return X, Y复制代码
竟然成功了,这里没有报错了!欣喜之余,我就很好奇,encoding='latin1'到底是啥玩意呢,以前没有见过啊?于是,我搜索了一下,了解到:
 Latin1是ISO-8859-1的别名,有些环境下写作Latin-1。ISO-8859-1编码是单字节编码,向下兼容ASCII,其编码范围是0x00-0xFF,0x00-0x7F之间完全和ASCII一致,0x80-0x9F之间是控制字符,0xA0-0xFF之间是文字符号。
因为ISO-8859-1编码范围使用了单字节内的所有空间,在支持ISO-8859-1的系统中传输和存储其他任何编码的字节流都不会被抛弃。换言之,把其他任何编码的字节流当作ISO-8859-1编码看待都没有问题。这是个很重要的特性,MySQL数据库默认编码是Latin1就是利用了这个特性。ASCII编码是一个7位的容器,ISO-8859-1编码是一个8位的容器。
还没等我高兴起来,运行后,又发现了一个问题:

memory error复制代码
什么鬼?内存错误!哇,原来是数据大小的问题。

X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")复制代码
这告诉我们每批数据都是10000 * 3 * 32 * 32,相当于超过3000万个浮点数。 float数据类型实际上与float64相同,意味着每个数字大小占8个字节。这意味着每个批次占用至少240 MB。你加载6这些(5训练+ 1测试)在总产量接近1.4 GB的数据。

for b in range(1, 2):f = os.path.join(ROOT, 'data_batch_%d' % (b,))X, Y = load_CIFAR_batch(f)xs.append(X)ys.append(Y)复制代码

所以如有可能,如上代码所示只能一次运行一批。
到此为止,错误基本搞定,下面贴出正确代码:
# -*- coding: utf-8 -*-import pickle as pimport numpy as npimport osdef load_CIFAR_batch(filename):""" 载入cifar数据集的一个batch """with open(filename, 'rb') as f:datadict = p.load(f, encoding='latin1')X = datadict['data']Y = datadict['labels']X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")Y = np.array(Y)return X, Ydef load_CIFAR10(ROOT):""" 载入cifar全部数据 """xs = []ys = []for b in range(1, 2):f = os.path.join(ROOT, 'data_batch_%d' % (b,))X, Y = load_CIFAR_batch(f)xs.append(X) #将所有batch整合起来ys.append(Y)Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)Ytr = np.concatenate(ys)del X, YXte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))return Xtr, Ytr, Xte, Yteimport numpy as npfrom julyedu.data_utils import load_CIFAR10import matplotlib.pyplot as pltplt.rcParams['figure.figsize'] = (10.0, 8.0)plt.rcParams['image.interpolation'] = 'nearest'plt.rcParams['image.cmap'] = 'gray'# 载入CIFAR-10数据集cifar10_dir = 'julyedu/datasets/cifar-10-batches-py'X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)# 看看数据集中的一些样本:每个类别展示一些print('Training data shape: ', X_train.shape)print('Training labels shape: ', y_train.shape)print('Test data shape: ', X_test.shape)print('Test labels shape: ', y_test.shape)复制代码

顺便看一下CIFAR-10数据组成:

更多内容,可关注我的个人公众号

                                                   

转载地址:http://mbhyl.baihongyu.com/

你可能感兴趣的文章
安卓多线程的实现
查看>>
【现在还没补的比赛及题解】
查看>>
C#截取字符串按字节截取SubString
查看>>
MAVLink v1.0详解——结构
查看>>
Office 365离线安装
查看>>
服务器负载暴涨以后...
查看>>
【物联网智能网关-15】WAV播放器(WinForm+WavPlay库实例)
查看>>
实战:将静态路由发布到动态路由
查看>>
Linux桌面新彩虹-Fedora 14 炫酷应用新体验
查看>>
灵活管理Hadoop各发行版的运维利器 - vSphere Big Data Extensions
查看>>
Data Protection Manager 2010 系列之安装部署
查看>>
【SeaJS】【3】seajs.data相关的源码阅读
查看>>
[PHP] 访问MySQL
查看>>
linux下redmine3.3迁移、升级、插件备忘录
查看>>
Hadoop原理及部署初探
查看>>
Oracle 11g R2 常见问题处理
查看>>
windows下expdp自动备份脚本
查看>>
WPF-009:WPF窗体的拖动
查看>>
MDT2012部署系列之10 Win7镜像捕获与系统安装
查看>>
Windows 2003 AD升级至Windows 2012 AD之DHCP服务器迁移
查看>>