政安晨:【Keras机器学习示例演绎】(三十三)—— 知识提炼

目录

设置

构建 Distiller() 类

创建学生和教师模型

准备数据集

培训教师

将教师模型蒸馏给学生模型

从头开始训练学生进行比较


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实施经典知识蒸馏。

知识蒸馏简介

知识蒸馏(Knowledge Distillation)是一种模型压缩程序,在这种程序中,一个小的(学生)模型被训练成与一个大的预训练(教师)模型相匹配。通过最小化损失函数,将知识从教师模型转移到学生模型,目的是匹配软化的教师对数以及地面实况标签。

通过在 softmax 中应用 "温度 "缩放函数来软化对数,从而有效地平滑概率分布,并揭示教师所学的类间关系。

设置

import os

import keras
from keras import layers
from keras import ops
import numpy as np

构建 Distiller() 类

自定义 Distiller() 类覆盖了编译、计算损失和调用模型方法。

为了使用蒸馏器,我们需要:

一个训练有素的教师模型
一个要训练的学生模型
一个关于学生预测与地面实况之间差异的学生损失函数
学生软预测与教师软标签之间差值的蒸馏损失函数以及温度
用于加权学生损失和蒸馏损失的 alpha 因子
学生的优化器和(可选)性能评估指标

在 compute_loss 方法中,我们对教师和学生进行前向传递,计算损失,并分别按 alpha 和 1 - alpha 对 student_loss 和 distillation_loss 进行加权。注意:只更新学生权重。

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)
        student_loss = self.student_loss_fn(y, y_pred)

        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)

        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss

    def call(self, x):
        return self.student(x)

创建学生和教师模型


首先,我们创建一个教师模型和一个较小的学生模型。这两个模型都是使用 Sequential() 创建的卷积神经网络,但也可以是任何 Keras 模型。

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

准备数据集


用于训练教师和提炼教师的数据集是 MNIST,如果选择合适的模型,该过程也可用于任何其他数据集,如 CIFAR-10。学生和教师都在训练集上进行训练,并在测试集上进行评估。

# Prepare the train and test dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

培训教师


在知识提炼过程中,我们假定教师是经过训练且固定不变的。因此,我们首先按照常规方法在训练集上训练教师模型。

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
Epoch 1/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 0.2408 - sparse_categorical_accuracy: 0.9259
Epoch 2/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0912 - sparse_categorical_accuracy: 0.9726
Epoch 3/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.0758 - sparse_categorical_accuracy: 0.9777
Epoch 4/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9797
Epoch 5/5
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.0582 - sparse_categorical_accuracy: 0.9825
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0931 - sparse_categorical_accuracy: 0.9760

[0.09044107794761658, 0.978100061416626]

将教师模型蒸馏给学生模型


我们已经训练了教师模型,现在只需初始化 Distiller(student, teacher) 实例,用所需的损失、超参数和优化器编译()它,然后将教师模型蒸馏为学生模型。

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=3)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 3ms/step - loss: 1.8752 - sparse_categorical_accuracy: 0.7357
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9475
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - loss: 0.0223 - sparse_categorical_accuracy: 0.9621
 313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 4ms/step - loss: 0.0189 - sparse_categorical_accuracy: 0.9629

[0.017046602442860603, 0.969200074672699]

从头开始训练学生进行比较


我们还可以在没有教师的情况下,从头开始训练一个等效的学生模型,以评估通过知识提炼获得的性能增益。

# Train student as doen usually
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)
Epoch 1/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 1ms/step - loss: 0.5111 - sparse_categorical_accuracy: 0.8460
Epoch 2/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.1039 - sparse_categorical_accuracy: 0.9687
Epoch 3/3
 1875/1875 ━━━━━━━━━━━━━━━━━━━━ 3s 1ms/step - loss: 0.0748 - sparse_categorical_accuracy: 0.9780
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0744 - sparse_categorical_accuracy: 0.9737

[0.0629437193274498, 0.9778000712394714]

如果教师训练了 5 个完整的历元,学生在教师的基础上提炼了 3 个完整的历元,那么在这个示例中,与从头开始训练相同的学生模型相比,甚至与教师本身相比,您都会体验到性能的提升。

教师的准确率应该在 97.6% 左右,从头开始训练的学生的准确率应该在 97.6% 左右,而经过提炼的学生的准确率应该在 98.1% 左右。删除或尝试不同的种子,使用不同的权重初始化。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/591241.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Spring Boot的医疗服务系统设计与实现

基于Spring Boot的医疗服务系统设计与实现 开发语言:Java框架:springbootJDK版本:JDK1.8数据库工具:Navicat11开发软件:eclipse/myeclipse/idea 系统部分展示 医疗服务系统首页界面图,公告信息、医疗地图…

快速幂笔记

快速幂即为快速求出一个数的幂&#xff0c;这样可以避免TLE&#xff08;超时&#xff09;的错误。 传送门&#xff1a;快速幂模板 前置知识&#xff1a; 1) 又 2) 代码&#xff1a; #include <bits/stdc.h> using namespace std; int quickPower(int a, int b) {int…

容斥原理以及Nim基础(异或,SG函数)

容斥原理&#xff1a; 容斥的复杂度为O&#xff08;2^m)&#xff0c;所以可以通过&#xff0c;对于实现&#xff0c;一共2^n-1种&#xff0c;我们可以用二进制来实现 下面是AC代码&#xff1a; #include<bits/stdc.h> using namespace std; typedef long long LL; cons…

uniapp 文字转语音(文字播报、语音合成)、震动提示插件 Ba-TTS

简介&#xff08;下载地址&#xff09; Ba-TTS 是一款uniapp语音合成&#xff08;tts&#xff09;插件&#xff0c;支持文本转语音&#xff08;无服务费&#xff09;&#xff0c;支持震动提示。 支持语音合成&#xff0c;文本转语音支持震动&#xff08;可自定义任意震动效果…

【云原生】Docker 实践(四):使用 Dockerfile 文件的综合案例

【Docker 实践】系列共包含以下几篇文章&#xff1a; Docker 实践&#xff08;一&#xff09;&#xff1a;在 Docker 中部署第一个应用Docker 实践&#xff08;二&#xff09;&#xff1a;什么是 Docker 的镜像Docker 实践&#xff08;三&#xff09;&#xff1a;使用 Dockerf…

【LinuxC语言】信号的基本概念与基本使用

文章目录 前言一、信号的概念二、信号的使用2.1 基本的信号类型2.2 signal函数 总结 前言 在Linux环境下&#xff0c;信号是一种用于通知进程发生了某种事件的机制。这些事件可能是由操作系统、其他进程或进程本身触发的。对于C语言编程者来说&#xff0c;理解信号的基本概念和…

36.Docker-Dockerfile自定义镜像

镜像结构 镜像是将应用程序及其需要的系统函数库、环境、配置、依赖打包而成。 镜像是分层机构&#xff0c;每一层都是一个layer BaseImage层&#xff1a;包含基本的系统函数库、环境变量、文件系统 EntryPoint:入口&#xff0c;是镜像中应用启动的命令 其他&#xff1a;在…

从0开始学习制作一个微信小程序 学习部分(6)组件与事件绑定

系列文章目录 学习篇第一篇我们讲了编译器下载&#xff0c;项目、环境建立、文件说明与简单操作&#xff1a;第一篇链接 第二、三篇分析了几个重要的配置json文件&#xff0c;是用于对小程序进行的切换页面、改变图标、控制是否能被搜索到等的操作第二篇链接、第三篇链接 第四…

QT的TcpServer

Server服务器端 QT版本5.6.1 界面设计 工程文件&#xff1a; 添加 network 模块 头文件引入TcpServer类和TcpSocket&#xff1a;QTcpServer和QTcpSocket #include <QTcpServer> #include <QTcpSocket>创建server对象并实例化&#xff1a; /*h文件中*/QTcpServer…

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统

基于SSM SpringBoot vue宾馆网上预订综合业务服务系统 系统功能 首页 图片轮播 宾馆信息 饮食美食 休闲娱乐 新闻资讯 论坛 留言板 登录注册 个人中心 后台管理 登录注册 个人中心 用户管理 客房登记管理 客房调整管理 休闲娱乐管理 类型信息管理 论坛管理 系统管理 新闻资讯…

Docker-Compose编排LNMP并部署WordPress

前言 随着云计算和容器化技术的快速发展&#xff0c;使用 Docker Compose 编排 LNMP 环境已经成为快速部署 Web 应用程序的一种流行方式。LNMP 环境由 Linux、Nginx、MySQL 和 PHP 组成&#xff0c;为运行 Web 应用提供了稳定的基础。本文将介绍如何通过 Docker Compose 编排 …

BUUCTF---misc---被偷走的文件

1、题目描述 2、下载附件&#xff0c;是一个流量包&#xff0c;拿去wireshark分析&#xff0c;依次点开流量&#xff0c;发现有个流量的内容显示flag.rar 3、接着在kali中分离出压缩包&#xff0c;使用下面命令&#xff0c;将压缩包&#xff0c;分离出放在out3文件夹中 4、在文…

Java -- (part21)

一.File类 1.概述 表示文件或者文件夹的路径抽象表示形式 2.静态成员 static String pathSeparator:路径分隔符:; static String separator:名称分隔符:\ 3.构造方法 File(String parent,String child) File(File parent,String child) Flie(String path) 4.方法 获…

在M1芯片安装鸿蒙闪退解决方法

在M1芯片安装鸿蒙闪退解决方法 前言下载鸿蒙系统安装完成后&#xff0c;在M1 Macos14上打开闪退解决办法接下来就是按照提示一步一步安装。 前言 重新安装macos系统后&#xff0c;再次下载鸿蒙开发软件&#xff0c;竟然发现打不开。 下载鸿蒙系统 下载地址&#xff1a;http…

Android使用kts上传aar到JitPack仓库

Android使用kts上传aar到JitPack 之前做过sdk开发&#xff0c;需要将仓库上传到maven、JitPack或JCenter,但是JCenter已停止维护&#xff0c;本文是讲解上传到JitPack的方式,使用KTS语法&#xff0c;记录使用过程中遇到的一些坑. 1.创建项目(library方式) 由于之前用鸿神的w…

每日OJ题_贪心算法二⑥_力扣409. 最长回文串

目录 力扣409. 最长回文串 解析代码 力扣409. 最长回文串 409. 最长回文串 难度 简单 给定一个包含大写字母和小写字母的字符串 s &#xff0c;返回 通过这些字母构造成的 最长的回文串 。 在构造过程中&#xff0c;请注意 区分大小写 。比如 "Aa" 不能当做一个…

【Java从入门到精通】Java继承

继承的概念 继承是java面向对象编程技术的一块基石&#xff0c;因为它允许创建分等级层次的类。 继承就是子类继承父类的特征和行为&#xff0c;使得子类对象&#xff08;实例&#xff09;具有父类的实例域和方法&#xff0c;或子类从父类继承方法&#xff0c;使得子类具有父…

Linux 麒麟系统安装

国产麒麟系统官网地址&#xff1a; https://www.openkylin.top/downloads/ 下载该镜像后&#xff0c;使用VMware部署一个虚拟机&#xff1a; 完成虚拟机创建。点击&#xff1a;“开启此虚拟机” 选择“试用试用开放麒麟而不安装&#xff08;T&#xff09;”&#xff0c;进入op…

React Native支持Tailwind CSS 语法

React Native支持Tailwind CSS 语法 一、前沿背景 回想下我们平时按照官方的规范进行书写样式是什么样&#xff1f; 是像下面这样&#xff1a; const MyComponent () > {return (<View><Text style{{ fontSize: 20 }}>开发者演示专用</Text></View…

【论文阅读】Tutorial on Diffusion Models for Imaging and Vision

1.The Basics: Variational Auto-Encoder 1.1 VAE Setting 自动编码器有一个输入变量x和一个潜在变量z Example. 获得图像的潜在表现并不是一件陌生的事情。回到jpeg压缩&#xff0c;使用离散余弦变换&#xff08;dct&#xff09;基φn对图像的底层图像/块进行编码。如果你给…
最新文章