提交 8282e7f9 编写于 作者: S Sam Gross

Create torch.Tensor classes for all defined types

上级 058ccd21
......@@ -169,8 +169,8 @@ bool THPModule_isTensor(PyObject *obj)
PyObject * THPModule_setDefaultTensorType(PyObject *_unused, PyObject *type)
{
HANDLE_TH_ERRORS
torch::tensor::py_set_default_tensor_type(type);
THPDefaultTensorClass = type;
torch::tensor::set_default_tensor_type(type);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
......
......@@ -3,6 +3,7 @@
#include <structmember.h>
#include <mutex>
#include <pybind11/pybind11.h>
#include <sstream>
#include "torch/csrc/assertions.h"
#include "torch/csrc/Exceptions.h"
......@@ -11,6 +12,7 @@
#include "torch/csrc/autograd/generated/VariableType.h"
#include "torch/csrc/utils/python_strings.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_types.h"
namespace torch { namespace tensor {
......@@ -23,7 +25,7 @@ struct PyTensorType {
bool is_cuda;
bool is_sparse;
bool is_default;
std::string name;
char name[64];
};
static_assert(std::is_standard_layout<PyTensorType>::value, "PyTensorType must be standard layout");
......@@ -36,6 +38,9 @@ static void py_bind_tensor_types(const std::vector<PyTensorType>& tensor_types);
static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) {
HANDLE_TH_ERRORS
auto& tensor_type = *((PyTensorType*)type);
if (!tensor_type.aten_type) {
throw TypeError("type %s not available", tensor_type.name);
}
if (tensor_type.is_cuda) {
std::call_once(init_cuda_flag, []() {
pybind11::module::import("torch.cuda").attr("init")();
......@@ -117,20 +122,25 @@ static const char* get_module(Backend backend) {
}
}
static std::string get_name(Type& aten_type) {
std::string name = at::toString(aten_type.scalarType());
name += "Tensor";
std::string module_name = get_module(aten_type.backend());
return module_name + "." + name;
static std::string get_name(Backend backend, ScalarType scalarType) {
std::ostringstream ss;
ss << get_module(backend) << "." << at::toString(scalarType) << "Tensor";
return ss.str();
}
static void set_type(PyTensorType& type_obj, Type& aten_type) {
auto backend = aten_type.backend();
type_obj.aten_type = &aten_type;
static void set_type(PyTensorType& type_obj, Backend backend, ScalarType scalarType) {
auto baseType = globalContext().type_registry[static_cast<int>(backend)][static_cast<int>(scalarType)].get();
type_obj.aten_type = baseType ? torch::autograd::VariableType::getType(*baseType) : nullptr;
type_obj.is_cuda = backend == kCUDA || backend == kSparseCUDA;
type_obj.is_sparse = backend == kSparseCPU || backend == kSparseCUDA;
}
static void set_name(PyTensorType& type_obj, const std::string& name) {
size_t n = sizeof(type_obj.name);
strncpy(type_obj.name, name.c_str(), n);
type_obj.name[n - 1] = '\0';
}
static PyObject* get_variable_dict() {
auto autograd = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
if (!autograd) throw python_error();
......@@ -144,14 +154,20 @@ static PyObject* get_variable_dict() {
static std::vector<PyTensorType> tensor_types;
static void initialize_aten_types(std::vector<PyTensorType>& tensor_types) {
auto var_types = VariableType::allTypes();
tensor_types.resize(var_types.size() + 1);
for (size_t i = 0; i < var_types.size(); i++) {
set_type(tensor_types[i], *var_types[i]);
tensor_types[i].name = get_name(*var_types[i]);
// includes CUDA types even when PyTorch is not built with CUDA
auto declared_types = torch::utils::all_declared_types();
tensor_types.resize(declared_types.size() + 1);
for (size_t i = 0, end = declared_types.size(); i != end; i++) {
auto& tensor_type = tensor_types[i];
Backend backend = declared_types[i].first;
ScalarType scalar_type = declared_types[i].second;
set_type(tensor_type, backend, scalar_type);
set_name(tensor_type, get_name(backend, scalar_type));
}
set_type(tensor_types.back(), *VariableType::getType(CPU(kFloat)));
tensor_types.back().name = "torch.Tensor";
set_type(tensor_types.back(), kCPU, kFloat);
set_name(tensor_types.back(), "torch.Tensor");
tensor_types.back().is_default = true;
}
......@@ -173,8 +189,7 @@ void initialize_python_bindings(PyObject* module) {
// Initialize each Python type object torch.FloatTensor, torch.DoubleTensor,
// etc. and the "default" type object torch.Tensor.
for (auto& tensor_type : tensor_types) {
const char* name = tensor_type.name.c_str();
py_initialize_tensor_type(tensor_type.py_type, name, var_dict);
py_initialize_tensor_type(tensor_type.py_type, tensor_type.name, var_dict);
}
// The type object for torch.Tensor is at the end.
......@@ -194,9 +209,10 @@ static void py_bind_tensor_types(const std::vector<PyTensorType>& tensor_types)
if (!tensor_classes) throw python_error();
for (auto& tensor_type : tensor_types) {
auto idx = tensor_type.name.rfind(".");
auto type_name = tensor_type.name.substr(idx + 1);
auto module_name = tensor_type.name.substr(0, idx);
auto name = std::string(tensor_type.name);
auto idx = name.rfind(".");
auto type_name = name.substr(idx + 1);
auto module_name = name.substr(0, idx);
auto module_obj = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
if (!module_obj) throw python_error();
......@@ -221,12 +237,20 @@ static bool PyTensorType_Check(PyObject* obj) {
static at::Type* THPDefaultATenType;
void set_default_tensor_type(PyObject* type_obj) {
if (!PyTensorType_Check(type_obj)) {
void set_default_tensor_type(const at::Type& type) {
set_type(*default_tensor_type, type.backend(), type.scalarType());
THPDefaultATenType = default_tensor_type->aten_type;
}
void py_set_default_tensor_type(PyObject* obj) {
if (!PyTensorType_Check(obj)) {
throw TypeError("invalid type object");
}
set_type(*default_tensor_type, *((PyTensorType*)type_obj)->aten_type);
THPDefaultATenType = default_tensor_type->aten_type;
auto type = (PyTensorType*)obj;
if (!type->aten_type) {
throw TypeError("invalid type object");
}
set_default_tensor_type(*type->aten_type);
}
at::Type& get_default_tensor_type() {
......
......@@ -10,7 +10,10 @@ namespace torch { namespace tensor {
void initialize_python_bindings(PyObject* module);
// Sets the concrete type constructed by calls to torch.Tensor()
void set_default_tensor_type(PyObject* type_obj);
void set_default_tensor_type(const at::Type& type);
// Same as set_default_tensor_type() but takes a PyObject*
void py_set_default_tensor_type(PyObject* type_obj);
// Gets the ATen type object for the default tensor type. Note that the
// returned value will be a VariableType instance.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册