提交 c42d85e9 编写于 作者: M mamingjie-China

support for python2

上级 e47ffeec
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
# limitations under the License. # limitations under the License.
from x2paddle.core.graph import GraphNode from x2paddle.core.graph import GraphNode
import collections
from x2paddle.core.util import * from x2paddle.core.util import *
import collections
import six
class Layer(object): class Layer(object):
...@@ -28,7 +29,7 @@ class Layer(object): ...@@ -28,7 +29,7 @@ class Layer(object):
def get_code(self): def get_code(self):
layer_code = "" layer_code = ""
if self.output is not None: if self.output is not None:
if isinstance(self.output, str): if isinstance(self.output, six.string_types):
layer_code = self.output + " = " layer_code = self.output + " = "
else: else:
layer_code = self.output.layer_name + " = " layer_code = self.output.layer_name + " = "
...@@ -47,7 +48,7 @@ class Layer(object): ...@@ -47,7 +48,7 @@ class Layer(object):
"[{}]".format(input.index) + ", ") "[{}]".format(input.index) + ", ")
else: else:
in_list += (input.layer_name + ", ") in_list += (input.layer_name + ", ")
elif isinstance(input, str): elif isinstance(input, six.string_types):
in_list += (input + ", ") in_list += (input + ", ")
else: else:
raise Exception( raise Exception(
...@@ -72,7 +73,7 @@ class Layer(object): ...@@ -72,7 +73,7 @@ class Layer(object):
"[{}]".format(self.inputs.index) + ", ") "[{}]".format(self.inputs.index) + ", ")
else: else:
layer_code += (self.inputs.layer_name + ", ") layer_code += (self.inputs.layer_name + ", ")
elif isinstance(self.inputs, str): elif isinstance(self.inputs, six.string_types):
layer_code += (self.inputs + ", ") layer_code += (self.inputs + ", ")
else: else:
raise Exception("Unknown type of inputs.") raise Exception("Unknown type of inputs.")
...@@ -119,6 +120,6 @@ class FluidCode(object): ...@@ -119,6 +120,6 @@ class FluidCode(object):
for layer in self.layers: for layer in self.layers:
if isinstance(layer, Layer): if isinstance(layer, Layer):
codes.append(layer.get_code()) codes.append(layer.get_code())
elif isinstance(layer, str): elif isinstance(layer, six.string_types):
codes.append(layer) codes.append(layer)
return codes return codes
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
from __future__ import division
import collections import collections
import copy as cp import copy as cp
......
...@@ -236,11 +236,7 @@ class CaffeDecoder(object): ...@@ -236,11 +236,7 @@ class CaffeDecoder(object):
data.MergeFromString(open(self.model_path, 'rb').read()) data.MergeFromString(open(self.model_path, 'rb').read())
pair = lambda layer: (layer.name, self.normalize_pb_data(layer)) pair = lambda layer: (layer.name, self.normalize_pb_data(layer))
layers = data.layers or data.layer layers = data.layers or data.layer
import time
start = time.time()
self.params = [pair(layer) for layer in layers if layer.blobs] self.params = [pair(layer) for layer in layers if layer.blobs]
end = time.time()
print('cost:', str(end - start))
def normalize_pb_data(self, layer): def normalize_pb_data(self, layer):
transformed = [] transformed = []
......
...@@ -94,7 +94,7 @@ class ONNXOpMapper(OpMapper): ...@@ -94,7 +94,7 @@ class ONNXOpMapper(OpMapper):
print(op) print(op)
return False return False
def directly_map(self, node, *args, name='', **kwargs): def directly_map(self, node, name='', *args, **kwargs):
inputs = node.layer.input inputs = node.layer.input
outputs = node.layer.output outputs = node.layer.output
op_type = node.layer_type op_type = node.layer_type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册