第7章 使用Keras开发神经网络

简介: 第7章 使用Keras开发神经网络 Keras基于Python,开发深度学习模型很容易。Keras将Theano和TensorFlow的数值计算封装好,几句话就可以配置并训练神经网络。本章开始使用Keras开发神经网络。

第7章 使用Keras开发神经网络

Keras基于Python,开发深度学习模型很容易。Keras将Theano和TensorFlow的数值计算封装好,几句话就可以配置并训练神经网络。本章开始使用Keras开发神经网络。本章将:

  • 将CSV数据读入Keras
  • 用Keras配置并编译多层感知器模型
  • 用验证数据集验证Keras模型

我们开始吧。

7.1 简介

虽然代码量不大,但是我们还是慢慢来。大体分几步:

  1. 导入数据
  2. 定义模型
  3. 编译模型
  4. 训练模型
  5. 测试模型
  6. 写出程序

7.2 皮马人糖尿病数据集

我们使用皮马人糖尿病数据集(Pima Indians onset of diabetes),在UCI的机器学习网站可以免费下载。数据集的内容是皮马人的医疗记录,以及过去5年内是否有糖尿病。所有的数据都是数字,问题是(是否有糖尿病是1或0),是二分类问题。数据的数量级不同,有8个属性:

  1. 怀孕次数
  2. 2小时口服葡萄糖耐量试验中的血浆葡萄糖浓度
  3. 舒张压(毫米汞柱)
  4. 2小时血清胰岛素(mu U/ml)
  5. 体重指数(BMI)
  6. 糖尿病血系功能
  7. 年龄(年)
  8. 类别:过去5年内是否有糖尿病

所有的数据都是数字,可以直接导入Keras。本书后面也会用到这个数据集。数据有768行,前5行的样本长这样:

6,148,72,35,0,33.6,0.627,50,1
1,85,66,29,0,26.6,0.351,31,0
8,183,64,0,0,23.3,0.672,32,1
1,89,66,23,94,28.1,0.167,21,0
0,137,40,35,168,43.1,2.288,33,1

数据在本书代码的data 目录,也可以在UCI机器学习的网站下载。把数据和Python文件放在一起,改名:

pima-indians-diabetes.csv

基准准确率是65.1%,在10次交叉验证中最高的正确率是77.7%。在UCI机器学习的网站可以得到数据集的更多资料。

7.3 导入资料

使用随机梯度下降时最好固定随机数种子,这样你的代码每次运行的结果都一致。这种做法在演示结果、比较算法或debug时特别有效。你可以随便选种子:

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

现在导入皮马人数据集。NumPy的loadtxt()函数可以直接带入数据,输入变量是8个,输出1个。导入数据后,我们把数据分成输入和输出两组以便交叉检验:

# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]

这样我们的数据每次结果都一致,可以定义模型了。

7.4 定义模型

Keras的模型由层构成:我们建立一个Sequential模型,一层层加入神经元。第一步是确定输入层的数目正确:在创建模型时用input_dim参数确定。例如,有8个输入变量,就设成8。

隐层怎么设置?这个问题很难回答,需要慢慢试验。一般来说,如果网络够大,即使存在问题也不会有影响。这个例子里我们用3层全连接网络。

全连接层用Dense类定义:第一个参数是本层神经元个数,然后是初始化方式和激活函数。这里的初始化方法是0到0.05的连续型均匀分布(uniform),Keras的默认方法也是这个。也可以用高斯分布进行初始化(normal)。

前两层的激活函数是线性整流函数(relu),最后一层的激活函数是S型函数(sigmoid)。之前大家喜欢用S型和正切函数,但现在线性整流函数效果更好。为了保证输出是0到1的概率数字,最后一层的激活函数是S型函数,这样映射到0.5的阈值函数也容易。前两个隐层分别有12和8个神经元,最后一层是1个神经元(是否有糖尿病)。

# create model
model = Sequential()
model.add(Dense(12, input_dim=8, init='uniform', activation='relu')) model.add(Dense(8, init='uniform', activation='relu')) model.add(Dense(1, init='uniform', activation='sigmoid'))

7.5 编译模型

定义好的模型可以编译:Keras会调用Theano或者TensorFlow编译模型。后端会自动选择表示网络的最佳方法,配合你的硬件。这步需要定义几个新的参数。训练神经网络的意义是:找到最好的一组权重,解决问题。

我们需要定义损失函数和优化算法,以及需要收集的数据。我们使用binary_crossentropy,错误的对数作为损失函数;adam作为优化算法,因为这东西好用。想深入了解请查阅:Adam: A Method for Stochastic Optimization论文。因为这个问题是分类问题,我们收集每轮的准确率。

7.6 训练模型

终于开始训练了!调用模型的fit()方法即可开始训练。

网络按轮训练,通过nb_epoch参数控制。每次送入的数据(批尺寸)可以用batch_size参数控制。这里我们只跑150轮,每次10个数据。多试试就知道了。

# Fit the model
model.fit(X, Y, nb_epoch=150, batch_size=10)

现在CPU或GPU开始煎鸡蛋了。

相关文章
|
24天前
|
机器学习/深度学习 自然语言处理 数据处理
大模型开发:描述长短期记忆网络(LSTM)和它们在序列数据上的应用。
LSTM,一种RNN变体,设计用于解决RNN处理长期依赖的难题。其核心在于门控机制(输入、遗忘、输出门)和长期记忆单元(细胞状态),能有效捕捉序列数据的长期依赖,广泛应用于语言模型、机器翻译等领域。然而,LSTM也存在计算复杂度高、解释性差和数据依赖性强等问题,需要通过优化和增强策略来改进。
|
24天前
|
机器学习/深度学习
大模型开发:解释卷积神经网络(CNN)是如何在图像识别任务中工作的。
**CNN图像识别摘要:** CNN通过卷积层提取图像局部特征,池化层减小尺寸并保持关键信息,全连接层整合特征,最后用Softmax等分类器进行识别。自动学习与空间处理能力使其在图像识别中表现出色。
24 2
|
1月前
|
网络协议 C++
C++ Qt开发:QTcpSocket网络通信组件
`QTcpSocket`和`QTcpServer`是Qt中用于实现基于TCP(Transmission Control Protocol)通信的两个关键类。TCP是一种面向连接的协议,它提供可靠的、双向的、面向字节流的通信。这两个类允许Qt应用程序在网络上建立客户端和服务器之间的连接。Qt 是一个跨平台C++图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍如何运用`QTcpSocket`组件实现基于TCP的网络通信功能。
37 8
C++ Qt开发:QTcpSocket网络通信组件
|
30天前
|
存储 网络安全 C++
C++ Qt开发:QUdpSocket网络通信组件
Qt 是一个跨平台C++图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍如何运用`QUdpSocket`组件实现基于UDP的网络通信功能。与`QTcpSocket`组件功能类似,`QUdpSocket`组件是 Qt 中用于实现用户数据报协议(UDP,User Datagram Protocol)通信的类。UDP 是一种无连接的、不可靠的数据传输协议,它不保证数据包的顺序和可靠性,但具有低延迟和简单的特点。
19 0
C++ Qt开发:QUdpSocket网络通信组件
|
30天前
|
机器学习/深度学习 算法框架/工具 Python
如何使用Python的Keras库构建神经网络模型?
如何使用Python的Keras库构建神经网络模型?
8 0
|
1月前
|
缓存 网络安全 调度
C++ Qt开发:QNetworkAccessManager网络接口组件
Qt 是一个跨平台C++图形界面开发库,利用Qt可以快速开发跨平台窗体应用程序,在Qt中我们可以通过拖拽的方式将不同组件放到指定的位置,实现图形化开发极大的方便了开发效率,本章将重点介绍如何运用`QNetworkAccessManager`组件实现Web网页访问。QNetworkAccessManager是Qt网络模块中的关键类,用于管理网络访问和请求。作为一个网络请求的调度中心,它为Qt应用程序提供了发送和接收各种类型的网络请求的能力,包括常见的GET、POST、PUT、DELETE等。这个模块的核心功能在于通过处理`QNetworkReply`和`QNetworkRequest`来实现
21 0
C++ Qt开发:QNetworkAccessManager网络接口组件
|
1月前
|
监控 C++ 索引
C++ Qt开发:QNetworkInterface网络接口组件
在Qt网络编程中,`QNetworkInterface`是一个强大的类,提供了获取本地网络接口信息的能力。通过`QNetworkInterface`,可以轻松地获取有关网络接口的信息,包括接口的名称、硬件地址、IP地址和子网掩码等。这个类对于需要获取本地网络环境信息的应用程序特别有用,例如网络配置工具、网络监控程序等。`QNetworkInterface`通过提供一致而易于使用的接口,使得网络编程中的任务更加简便和可靠。
26 4
C++ Qt开发:QNetworkInterface网络接口组件
|
1月前
|
JSON Go API
Go语言网络编程:HTTP客户端开发实战
【2月更文挑战第12天】本文将深入探讨使用Go语言开发HTTP客户端的技术细节,包括发送GET和POST请求、处理响应、错误处理、设置请求头、使用Cookie等方面。通过实例演示和代码解析,帮助读者掌握构建高效、可靠的HTTP客户端的关键技术。
|
2月前
|
消息中间件 机器学习/深度学习 安全
盘点网络安全开发中的缩写——M
盘点网络安全开发中的缩写——M
24 0
|
2月前
|
安全 物联网 测试技术
盘点网络安全开发中的缩写——I
盘点网络安全开发中的缩写——I
32 0

热门文章

最新文章