首页 » 人工智能 »

无视无惧numpy计算时的runtime warning...

2018年12月3日 / 15次阅读

打开支付宝首页,搜索“529018372”,即可领取红包!可重复领。

特色图片

之前总结过一篇在训练神经网络的时候,出现大量numpy的runtime warning的分析。今天这篇继续这个话题,我想说的时候,我们其实可以通过一些手段,来无视无惧这些warning。

1, 无惧sigmoid函数的np.exp函数出现的overflow warning:

>>> import numpy as np
>>> def sigmoid(z):
...     return 1.0/(1.0+np.exp(z))
... 
>>> z = np.arange(10).reshape(10,1)
>>> z
array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7],
       [8],
       [9]])
>>> a = sigmoid(z)
>>> a
array([[5.00000000e-01],
       [2.68941421e-01],
       [1.19202922e-01],
       [4.74258732e-02],
       [1.79862100e-02],
       [6.69285092e-03],
       [2.47262316e-03],
       [9.11051194e-04],
       [3.35350130e-04],
       [1.23394576e-04]])
>>> z[4] = 800
>>> z
array([[  0],
       [  1],
       [  2],
       [  3],
       [800],
       [  5],
       [  6],
       [  7],
       [  8],
       [  9]])
>>> b = sigmoid(z)
__main__:2: RuntimeWarning: overflow encountered in exp
>>> b
array([[5.00000000e-01],
       [2.68941421e-01],
       [1.19202922e-01],
       [4.74258732e-02],
       [0.00000000e+00],
       [6.69285092e-03],
       [2.47262316e-03],
       [9.11051194e-04],
       [3.35350130e-04],
       [1.23394576e-04]])
>>> b[4]
array([0.])
>>> 

因为出现np.exp overflow之后,计算的结果是对的,可以接受的。

2, 改写cross-entropy函数来计算total cost:

cross-entropy函数在计算total cost的时候,会调用np.log函数,这个函数的输入如果是0,就会出现一个divided by zero的warning,计算的结果就是inf。我们可以改写cross-entropy函数来规避这个问题,下面是我改写的版本:

def crossentropy_cost_x(y, a):
    """cross-entropy cost function for one x"""
    # return np.sum(-y*np.log(a)-(1-y)*np.log(1-a))
    # safe version. runtime warning can be ignored completely.
    # And total cost can be compute correctly.
    pos = np.argmax(y)
    x = 0.0
    bn = 1e-100  # arbitrary big number
    for i in range(y.shape[0]):
        if i == pos:
            r = -np.log(a[i])
            if np.isinf(r):
                r = -np.log(bn)
        if i != pos:
            r = -np.log(1-a[i])
            if np.isinf(r):
                r = -np.log(bn)
        x += r
    return x

注意:

(1) 以上改写还规避了另外一个warning:在计算y*np.log(a)的时候,如果y是0,a也是0,计算的结果是nan。

(2) 以上改写使用了一个arbitrary number, 1e-100,用来模拟一个比较大的cost值;因此,用这种改写得到的cost值不值准确的,但是总比计算出inf来的好一点,至少还有参考价值。

(3) 这种改写的方式,个人认为可以推广,numpy还有一个函数,np.isnan,用来判断是否为nan。

3, 使用numba

使用numba加速之后,我发现我再也没有见过np.exp的overflow了,似乎numba对某些warning有抑制作用。当然,这种抑制并不影响最后的结果。

4, 最后,良好设计和配置的神经网络,天然不会出现这些问题!

本文链接:http://www.maixj.net/ai/numpy-runtime-warning-19458
云上小悟 麦新杰(QQ:1093023102)

评论是美德

无力满足评论实名制,评论对非实名注册用户关闭,有事QQ:1093023102.


前一篇:
后一篇:

栏目精选

云上小悟,麦新杰的独立博客

Ctrl+D 收藏本页

栏目

AD

ppdai

©Copyright 麦新杰 Since 2014 云上小悟独立博客版权所有 备案号:苏ICP备14045477号-1。云上小悟网站部分内容来源于网络,转载目的是为了整合信息,收藏学习,服务大家,有些转载内容也难以判断是否有侵权问题,如果侵犯了您的权益,请及时联系站长,我会立即删除。

网站二维码
go to top