0%

Tensorflow-Keras可视化

学习深度学习的方法有很多,其中可视化是一个非常重要的方法。今天我就来简要的介绍一下如何对keras模型进行可视化。

通过tensorboard可视化

tensorboard 是一个最常用的可视化工具,可以查看训练过程中的loss和accuracy,还可以查看网络结构。

首先我们来看看tensorboard的基本使用方法。

  • 首先,安装tensorboard,命令如下:

    1
    pip install tensorflow-tensorboard
  • 以mnist为例,为mnist构造网络模型,代码如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import tensorflow as tf

    #加载mnist数据集
    (train_images, train_labels),(test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    #归一化处理
    train_images, test_images = train_images / 255.0, test_images / 255.0

    #构造网络模型
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

    model.fit(train_images, train_labels, epochs=5)
    model.evaluate(test_images, test_labels)
  • 设置回调函数

    1
    2
    3
    4
    5
    6
    7
    8
    ...
    from keras_visualize import visualize
    ...
    model=...
    ...
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir="logs/fit/")
    model.fit(train_images, train_labels, epochs=5, validated_data=(test_images, test_labels), callbacks=[tensorboard])
    #model.evaluate(test_images, test_labels)
  • 运行tensorboard

    1
    tensorboard --logdir=log/fit

通过上面的操作之后,我们就可以看到mnist的训练过程了,如下图所示:

通过 keras_visualizer 实现可视化

tensorboard 是一个非常常用的可视化工具,但是对于深度学习来说,tensorboard 对网络架构的可视化效果并不好,所以,我推荐使用 keras_visualizer 来实现网络架构的可视化。

不过需要注意的是,从tensorflow2.16开始,默认使用最新的keras3,它与原来的keras2有很大的不同。而keras_visualizer 目前还不支持keras3,所以为了能够使用keras_visualizer,一定要选择安装tensorflow2.15及以下版本。

下面我们来看看如何使用 keras_visualizer 来实现网络架构的可视化的具体步骤:

  • 首先,安装依赖库 graphviz,命令如下:

    1
    sudo apt-get install graphviz

    该库用于绘制dot语言编写的网络结构图。

  • 接下来,安装 keras_visualizer,命令如下:

    1
    2
    pip install keras_visualizer
    pip install graphviz
  • 第三步,创建网络模型,代码如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9

    ...
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
    ])
    ...
  • 最后,通过visualize函数绘制网络结构图,代码如下:

    1
    2
    3
    from keras_visualizer import visualize
    ...
    visualize(model, file_format='png', filename='model_visualization')

最终的显示结果如下图所示:

其它可视化工具

除了tensorboard、keras-visualizer,还有很多其它可视化工具,如:netron、ann_visualizer等,不过这些可视频工具都不同小异,与keras-visualizer差不多,因此我们只要掌握一种就可以了。

小结

本文向你介绍了两种重要的深度学习可视化工具,分别是tensorboard和keras-visualizer。这两种工具与keras结合的都非常好,操作简单方便,功能强大,为我们学习深度学习提供了良好的参考。

唯一需要注意的是keras-visualizer目前只支持到keras2, 对于keras3的版本,需要等待作者更新。

欢迎关注我的其它发布渠道