A very simple NetChain example causes an error when run in MXNet

First, define the net:

SeedRandom[1234];
net = NetInitialize@NetChain[{5, 3}, "Input" -> 2];
{net[{1, 2}], net[{0, 0}], net[{0.1, 0.3}]}

{{1.27735, -1.21455, -1.02647}, {0., 0., 0.}, {0.141054, -0.145882, -0.099945}}
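
To check an exported copy of this net outside Mathematica, it helps to know what it should compute. Assuming NetChain[{5, 3}] expands to two LinearLayers with no activation in between (integer specs in NetChain are read as LinearLayer[n]), the output is W2.(W1.x + b1) + b2. The plain-Python sketch below does that computation. The weights in the demo are made up for easy arithmetic and are not the ones NetInitialize produced here. Weights are stored as (outputs x inputs), which is the layout of LinearLayer's "Weights" and of MXNet's FullyConnected weight, so no transpose is needed between the two.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Layer = Tuple[Sequence[Sequence[float]], Sequence[float]]  # (weights, bias)


def linear(weights: Sequence[Sequence[float]], bias: Sequence[float],
           x: Sequence[float]) -> Vector:
    """Affine map W.x + b, with W stored as (outputs x inputs)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def chain_forward(layers: Sequence[Layer], x: Sequence[float]) -> Vector:
    """Apply each (weights, bias) layer in order, with no activation in between."""
    out = list(x)
    for weights, bias in layers:
        out = linear(weights, bias, out)
    return out


if __name__ == "__main__":
    # Hypothetical 5x2 and 3x5 weights with zero biases, chosen for easy arithmetic.
    w1 = [[1, 0], [0, 1], [1, 1], [1, -1], [2, 0]]
    w2 = [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 1, 1]]
    layers = [(w1, [0.0] * 5), (w2, [0.0] * 3)]
    print(chain_forward(layers, [1, 2]))  # [1.0, 2.0, 4.0]
    print(chain_forward(layers, [0, 0]))  # [0.0, 0.0, 0.0]
```

With the real arrays from plan["ArgumentArrays"] substituted in, chain_forward should reproduce the Mathematica outputs above, which gives a reference to compare the MXNet (or C++) result against. The {0., 0., 0.} result for input {0, 0} is what this formula gives when both bias vectors are zero.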

Export it, using the approach from a great answer to "How to export an MXNet?":

<< MXNetLink`;
<< NeuralNetworks`;
<< GeneralUtilities`;

jsonPath = "simple model-symbol.json";
paraPath = "simple model-0000.params";
Export[jsonPath, ToMXJSON[net][[1]], "String"]

(* Rename parameter keys such as "<layer>.Arrays.Weights" to MXNet's "arg:<layer>_weight" form *)
f[str_] := 
 If[StringFreeQ[str, "Arrays"], str, StringReplace[
      StringSplit[str, ".Arrays."] /. {a_, b_} :> 
        StringJoin[{"arg:", a, "_", b}], {"Weights" -> "weight", "Biases" -> "bias"}]]

plan = ToMXPlan[net];
NDArrayExport[paraPath, NDArrayCreate /@ KeyMap[f, plan["ArgumentArrays"]]]
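
The renaming done by f can be mirrored in Python, which is handy for checking parameter names on the MXNet side. This is a hypothetical helper (mx_param_name is not part of MXNet), written under the assumption that the keys of plan["ArgumentArrays"] look like "1.Arrays.Weights". The "arg:" prefix is how MXNet checkpoint .params files mark argument arrays, as opposed to "aux:" for auxiliary states.

```python
def mx_param_name(key: str) -> str:
    """Hypothetical mirror of the Mathematica function f.

    Assumes keys of the form "<layer>.Arrays.<Array>", e.g. "1.Arrays.Weights",
    and maps them to MXNet's .params naming, e.g. "arg:1_weight".
    """
    layer, sep, array = key.partition(".Arrays.")
    if not sep:
        return key  # no ".Arrays." separator: leave the key unchanged
    name = f"arg:{layer}_{array}"
    # Same substitutions as {"Weights" -> "weight", "Biases" -> "bias"} in f
    return name.replace("Weights", "weight").replace("Biases", "bias")


if __name__ == "__main__":
    for key in ["1.Arrays.Weights", "1.Arrays.Biases", "2.Arrays.Weights"]:
        print(key, "->", mx_param_name(key))
```

For mod.set_params to pick these arrays up, each name after the "arg:" prefix has to match an argument of the exported symbol exactly (for example "1_weight", if the layers are named "1" and "2" in the JSON).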

Python code to load the net

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint('simple model', 0)
mod = mx.mod.Module(symbol=sym)  # the error is raised here
mod.bind(for_training=False, data_shapes=[('data', (1, 2))])
mod.set_params(arg_params, aux_params)

It then throws an error at mod = mx.mod.Module(symbol=sym):

[Image: the error message]

The JSON file (which defines the net structure) is:

[Image: the exported JSON file]
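
Because the JSON is only available as an image, here is one way to inspect it programmatically. mx.mod.Module uses data_names=('data',) by default and raises an error when the symbol has no argument of that name, so one plausible cause (not confirmed from the post) is that the exported graph names its input something other than "data". The sketch below uses only the standard json module and MXNet's symbol format, in which variables (data inputs and parameters alike) are nodes with "op": "null". It lists the variables that the .params file does not supply, which should be the data inputs. The function names are hypothetical.

```python
import json
from typing import Iterable, List


def variable_names(symbol_json: str) -> List[str]:
    """Names of all variable nodes (op == "null") in an MXNet symbol JSON string."""
    graph = json.loads(symbol_json)
    return [node["name"] for node in graph["nodes"] if node["op"] == "null"]


def input_names(symbol_json: str, param_keys: Iterable[str]) -> List[str]:
    """Variables not supplied by the .params file, i.e. the data inputs.

    param_keys are the keys of the saved parameter dict, e.g. "arg:1_weight";
    the "arg:"/"aux:" prefix is stripped before comparing with node names.
    """
    params = {key.split(":", 1)[-1] for key in param_keys}
    return [name for name in variable_names(symbol_json) if name not in params]
```

With the real files, pass open("simple model-symbol.json").read() and the keys of mx.nd.load("simple model-0000.params"). If the result is, say, ["Input"] rather than ["data"], the module could be created with mx.mod.Module(symbol=sym, data_names=["Input"], label_names=None) and bound with data_shapes=[("Input", (1, 2))].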

I really need to use this net in my C++ project, so this is a real problem for me.

Thank you!

Recent Questions – Mathematica Stack Exchange