博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow中的batch_normalization实现
阅读量:4663 次
发布时间:2019-06-09

本文共 1282 字,大约阅读时间需要 4 分钟。

  tensorflow中实现batch_normalization的函数主要有两个:

    1)tf.nn.moments

    2)tf.nn.batch_normalization

  tf.nn.moments主要是用来计算均值mean和方差variance的值,这两个值被用在之后的tf.nn.batch_normalization中

  tf.nn.moments(x, axis,...)

  主要有两个参数:输入的batchs数据;进行求均值和方差的维度axis,axis的值是一个列表,可以传入多个维度

  返回值:mean和variance

  tf.nn.batch_normalization(x, mean, variance, offset, scala, variance_epsilon)

  主要参数:输入的batchs数据;mean;variance;offset和scala,这两个参数是要学习的参数,所以只要给出初始值,一般offset=0,scala=1;variance_epsilon是为了保证variance为0时,除法仍然可行,设置为一个较小的值即可

  输出:bn处理后的数据

  具体代码如下:    

import tensorflow as tfimport numpy as npX = tf.constant(np.random.uniform(1, 10, size=(3, 3)), dtype=tf.float32)axis = list(range(len(X.get_shape()) - 1))mean, variance = tf.nn.moments(X, axis)print(axis)X_batch = tf.nn.batch_normalization(X, mean, variance, 0, 1, 0.001)init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    mean, variance, X_batch = sess.run([mean, variance, X_batch])    print(mean)    print(variance)    print(X_batch)输出:

axis: [0]

mean: [5.124098 3.0998185 4.723417 ]
variance: [3.7908943 1.7062012 3.8243492]
X_batch: [[-0.32879925 -1.3645337 0.39226937]
      [-1.0266179 0.36186576 -1.3726556 ]
      [ 1.355417 1.0026684 0.98038626]]

 

 

转载于:https://www.cnblogs.com/jiangxinyang/p/9394353.html

你可能感兴趣的文章
Windows10系统在VMware中安装CentOS7操作系统并实现图形化用户界面Gnome
查看>>
分布式文件系统
查看>>
线程同步工具 Semaphore类的基础使用
查看>>
Bug的等级程度(Blocker, Critical, Major, Minor/Trivial)及修复优先级
查看>>
js多图预览及上传功能
查看>>
Mac下安装ionic和cordova,并生成iOS项目
查看>>
caffe介绍
查看>>
MongoDB副本集和分片模式安装
查看>>
Spring Mvc Url和参数名称忽略大小写
查看>>
Python之常用模块学习(一)
查看>>
CSS------如何让大小不一样的div顶部对齐
查看>>
SOCKET.IO 前后端使用
查看>>
CodeForces 122G Lucky Array(一脸懵逼的树状数组)
查看>>
【开发实例】C#调用SAPI实现语音合成的两种方法
查看>>
Django实战(15):Django实现RESTful web service
查看>>
离散实验二
查看>>
使用sharepoint里Open with explorer功能
查看>>
通过模糊来弱化背景
查看>>
The Fourth Day
查看>>
NSString 比较(转)
查看>>