人工智能综合项目开发4--农业病虫害识别model_train.py
tf.keras.layers.Dense(256, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.001)),# 添加全连接层,256个神经元,使用ReLU激活函数,kernel_regularizer添加L2正则化,进一步防止过拟合。tf.keras.layers.Dropout(0.4),# 添加Dropou
·
# 基于InceptionV3的9分类图像模型训练脚本(可直接运行)
# 依赖:需提前安装tensorflow、matplotlib、numpy,且dataclean.py文件在同目录
# --------------------------
# 1、导入所需三方库
# --------------------------
# 导入InceptionV3预训练模型,用于图像特征提取
from tensorflow.keras.applications.inception_v3 import InceptionV3
import tensorflow as tf # 导入tensorflow深度学习框架,用于构建和训练模型
import matplotlib.pyplot as plt # 导入matplotlib的pyplot模块,用于数据可视化
import os # 导入os模块,用于文件和目录操作
import numpy as np # 导入numpy库,用于数值计算
# 导入ImageDataGenerator,用于图像数据增强
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 从dataclean.py文件导入预处理后的训练集和验证集数据
from dataclean import train_images, train_labels, val_images, val_labels
# --------------------------
# 2、构建InceptionV3模型架构
# --------------------------
# 定义InceptionV3模型要求的输入图像尺寸(宽度, 高度, 通道数)
# InceptionV3模型固定要求输入尺寸为299x299,3表示RGB三通道
input_shape = (299, 299, 3)
# 加载预训练的InceptionV3基础网络
# weights="imagenet"表示使用在ImageNet数据集上预训练的权重
# include_top=False表示不包含模型顶部的全连接层,只使用特征提取部分
# input_shape指定输入图像的尺寸
iv3_base = InceptionV3(
weights="imagenet",
include_top=False,
input_shape=input_shape
)
# 冻结预训练网络的所有层,这样可以保留预训练模型学到的通用图像特征
for layer in iv3_base.layers:
layer.trainable = False # 将每层的可训练属性设置为False,初始训练阶段不更新这些层的参数
# 构建完整的分类模型,使用Sequential顺序模型
model = tf.keras.models.Sequential([
iv3_base, # 添加预训练的InceptionV3基础网络作为特征提取器
tf.keras.layers.GlobalAveragePooling2D(),# 特征压缩层:添加全局平均池化2D层,对每个通道的所有空间位置取平均值,将特征图转换为一维向量,用于减少特征维度并保留通道信息
tf.keras.layers.Dropout(0.4), # 添加Dropout层,随机丢弃40%的神经元,抑制过拟合
tf.keras.layers.Dense(256, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.001)),# 添加全连接层,256个神经元,使用ReLU激活函数,kernel_regularizer添加L2正则化,进一步防止过拟合
tf.keras.layers.Dropout(0.4), # 再次添加Dropout层,增强抑制过拟合的效果
tf.keras.layers.Dense(9, activation="softmax"),# 输出层,9个神经元对应9个类别,使用softmax激活函数,softmax确保输出值为概率分布,所有类别概率之和为1
tf.keras.layers.Flatten() # 注:此层无实际作用,可保留原逻辑或后续删除
])
# --------------------------
# 3、模型编译(配置训练参数)
# --------------------------
model.summary() # 打印模型的详细结构信息,包括各层名称、输出形状和参数数量
# 编译模型,配置训练时需要的优化器、损失函数和评估指标
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=5e-5), # 使用RMSprop优化器,学习率设置为5e-5(较小的学习率有助于稳定训练)
loss=tf.keras.losses.SparseCategoricalCrossentropy(),# 损失函数使用SparseCategoricalCrossentropy,适用于标签为整数形式(而非one-hot编码)的多分类问题
metrics=['accuracy'] # 训练过程中监控的指标,这里使用准确率
)
# --------------------------
# 4、数据增强与生成器配置
# --------------------------
# 创建训练集数据增强生成器,通过对图像进行随机变换来扩充训练数据
train_datagen = ImageDataGenerator(
rescale=1/255.0, # 将像素值从0-255缩放到0-1范围,加速模型收敛
rotation_range=15, # 随机旋转图像,角度范围为±15度
width_shift_range=0.1, # 随机水平平移图像,平移范围为宽度的10%
height_shift_range=0.1, # 随机垂直平移图像,平移范围为高度的10%
horizontal_flip=True # 随机水平翻转图像
)
# 创建验证集数据生成器,仅进行归一化处理,不做数据增强
# 验证集用于评估模型性能,需要保持数据的真实性
val_datagen = ImageDataGenerator(rescale=1/255.0)
# 创建训练数据生成器,按批次生成增强后的训练数据
# batch_size=32表示每次生成32个样本
train_generator = train_datagen.flow(train_images, train_labels, batch_size=32)
# 创建验证数据生成器,按批次生成归一化后的验证数据
val_generator = val_datagen.flow(val_images, val_labels, batch_size=32)
# --------------------------
# 5、模型训练
# --------------------------
# 训练模型,返回训练历史记录(包含每轮的损失值和准确率)
history = model.fit(
train_generator, # 训练数据生成器
validation_data=val_generator, # 验证数据生成器
epochs=30, # 训练的总轮数,即整个训练集将被训练30次
verbose=1 # 训练过程的日志显示模式,1表示显示进度条和实时指标
)
# --------------------------
# 6、训练过程可视化
# --------------------------
# 绘制训练集准确率曲线
plt.plot(history.history['accuracy'], label='Train Accuracy')
# 绘制训练集损失值曲线
plt.plot(history.history['loss'], label='Train Loss')
# 添加图例,用于区分不同曲线
plt.legend(loc='best')
# 添加图表标题
plt.title('Model Accuracy & Loss')
# 显示图表
plt.show()
# --------------------------
# 7、模型保存
# --------------------------
# 检查是否存在model文件夹,如果不存在则创建
if not os.path.exists('./model'):
os.mkdir('./model')
# 将训练好的模型保存到指定路径,保存格式为H5
model.save('./model/model.h5')
# 打印模型保存成功的提示信息
print("模型已保存至 ./model/model.h5")

openvela 操作系统专为 AIoT 领域量身定制,以轻量化、标准兼容、安全性和高度可扩展性为核心特点。openvela 以其卓越的技术优势,已成为众多物联网设备和 AI 硬件的技术首选,涵盖了智能手表、运动手环、智能音箱、耳机、智能家居设备以及机器人等多个领域。
更多推荐

所有评论(0)