Vengineerの妄想(準備期間)

人生は短いけど、長いです。人生を楽しみましょう!

Nervana neon で BNN


最近のDNN関連では、Binarized Neural Netoworkに関する論文や実装が多いですよね。
でも実装コードってあんまりないんですよ。

いろいろと調べてみると、Intelに買収されたNervanaのneonの実装があるのを知りました。



そして、実装は、ここ

train.pyの中には、BinaryAffineというレイヤーを使っているみたい。

layer.pyにBinaryLinerというクラスがありました。
引用
class BinaryLinear(Linear):

    """
    A binary fully connected layer implemented as the dot product of inputs
    and binarized weights.
    Arguments:
        nout (int, tuple): Desired size or shape of layer output
        init (Initializer, optional): Initializer object to use for
            initializing layer weights
        name (str, optional): Layer name. Defaults to "BinaryLinearLayer"
    """

    def __str__(self):
        return "BinaryLinear Layer '%s': %d inputs, %d outputs" % (
               self.name, self.nin, self.nout)

    def allocate(self, shared_outputs=None):
        super(BinaryLinear, self).allocate(shared_outputs)
        self.Wb = self.be.empty_like(self.W)

    def fprop(self, inputs, inference=False, beta=0.0):
        self.inputs = inputs
        self.be.binarize(self.W, self.Wb, stochastic=False)

        not_binarized = self.be.zeros(self.inputs.shape)
        not_binarized[:] = self.be.not_equal(self.be.absolute(self.inputs), 1)
        if np.any(not_binarized.get()):
            gemm = self.be.compound_dot
        else:
            gemm = self.be.xnor_compound_dot

        if self.actual_bsz is None and self.actual_seq_len is None:
            gemm(A=self.Wb, B=self.inputs, C=self.outputs, beta=beta,
                 bsum=self.batch_sum)
        else:
            bsz = self.be.bsz if self.actual_bsz is None else self.actual_bsz
            steps = self.nsteps if self.actual_seq_len is None else self.actual_seq_len

            gemm(A=self.Wb,
                 B=self.inputs[:, :bsz * steps],
                 C=self.outputs[:, :bsz * steps],
                 beta=beta,
                 bsum=self.batch_sum)

        return self.outputs

    def bprop(self, error, alpha=1.0, beta=0.0):
        if self.deltas:
            self.be.compound_dot(A=self.Wb.T, B=error, C=self.deltas, alpha=alpha, beta=beta)
        self.be.compound_dot(A=error, B=self.inputs.T, C=self.dW)
        return self.deltas


gemmの部分で、
        if np.any(not_binarized.get()):
            gemm = self.be.compound_dot
        else:
            gemm = self.be.xnor_compound_dot
のように、binarizedかどうかを調べて、binarizedの場合は、xnor_compound_dotを行っている