提交 c24059bd 编写于 作者: W willzhang4a58

fix bug : append to record

上级 6db713e5
dlnet_filepath: "./net.prototxt"
resource_filepath: "./resource.prototxt"
placement_filepath: "./placement.prototxt"
model_load_snapshot_path: ""
model_load_snapshot_path: "/willzhang/snapshots/snapshot_10"
piece_size: 1000
default_data_type: kFloat
use_async_cpu_stream: false
max_data_id_length: 0
max_data_id_length: 16
global_fs_conf {
hdfs_conf {
namenode: "hdfs://192.168.1.11:9000"
......
......@@ -27,31 +27,23 @@ op {
op {
name: "conv"
model_load_dir: "/willzhang/snapshots/snapshot_3/conv"
convolution_conf {
in: "feature/out"
out: "out"
out_num: 1
has_bias_term: true
pad_h: 0
pad_w: 0
kernel_h: 5
kernel_w: 5
stride_h: 1
stride_w: 1
dilation_h: 1
dilation_w: 1
}
}
op {
name: "ip10"
model_load_dir: "/willzhang/snapshots/snapshot_3/ip10"
innerproduct_conf {
in: "conv/out"
out: "out"
out_num: 10
has_bias_term: false
has_bias_term: true
}
}
......
dlnet_filepath: "./net.prototxt"
resource_filepath: "./resource.prototxt"
placement_filepath: "./placement.prototxt"
model_load_snapshot_path: ""
piece_size: 1000
default_data_type: kFloat
use_async_cpu_stream: false
piece_size: 100
global_fs_conf {
hdfs_conf {
namenode: "hdfs://192.168.1.11:9000"
......@@ -13,16 +10,15 @@ global_fs_conf {
train_conf {
num_of_pieces_in_batch: 10
model_save_snapshots_path: "/willzhang/snapshots"
num_of_batches_in_snapshot: 6
staleness: 0
total_batch_num: 15
num_of_batches_in_snapshot: 60
total_batch_num: 600
default_fill_conf {
gaussian_conf {
mean: 0.0
std: 0.1
}
}
piece_num_of_record_loss: 10
piece_num_of_record_loss: 100
normal_mdupdt_conf {
learning_rate: 0.01
}
......
......@@ -32,14 +32,8 @@ op {
out: "out"
out_num: 1
has_bias_term: true
pad_h: 0
pad_w: 0
kernel_h: 5
kernel_w: 5
stride_h: 1
stride_w: 1
dilation_h: 1
dilation_w: 1
}
}
......@@ -49,7 +43,7 @@ op {
in: "conv/out"
out: "out"
out_num: 10
has_bias_term: false
has_bias_term: true
}
}
......
......@@ -14,19 +14,11 @@ placement_group {
op_set {
op_name: "conv"
op_name: "ip10"
}
parallel_conf {
policy: kDataParallel
device_name: "192.168.1.11:0-3"
}
}
placement_group {
op_set {
op_name: "softmax_loss"
}
parallel_conf {
policy: kDataParallel
device_name: "192.168.1.11:0-3"
device_name: "192.168.1.13:0-3"
}
}
......@@ -7,6 +7,18 @@ std::string RuntimeCtx::GetCtrlAddr(int64_t machine_id) const {
return mchn.addr() + ":" + std::to_string(mchn.port());
}
PersistentOutStream* RuntimeCtx::GetPersistentOutStream(
const std::string& filepath) {
auto iter = filepath2ostream_.find(filepath);
if (iter != filepath2ostream_.end()) {
return iter->second.get();
} else {
auto ostream_ptr = new PersistentOutStream(GlobalFS(), filepath);
filepath2ostream_[filepath].reset(ostream_ptr);
return ostream_ptr;
}
}
RuntimeCtx::RuntimeCtx(const std::string& name) {
this_machine_id_ = IDMgr::Singleton()->MachineID4MachineName(name);
LOG(INFO) << "this machine name: " << name;
......
......@@ -4,6 +4,7 @@
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/persistence/persistent_in_stream.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
namespace oneflow {
......@@ -25,6 +26,8 @@ class RuntimeCtx final {
BlockingCounter& mut_active_actor_cnt() { return active_actor_cnt_; }
BlockingCounter& mut_inactive_actor_cnt() { return inactive_actor_cnt_; }
PersistentOutStream* GetPersistentOutStream(const std::string& filepath);
private:
RuntimeCtx(const std::string& name);
......@@ -34,6 +37,8 @@ class RuntimeCtx final {
BlockingCounter active_actor_cnt_;
BlockingCounter inactive_actor_cnt_;
HashMap<std::string, std::unique_ptr<PersistentOutStream>> filepath2ostream_;
};
} // namespace oneflow
......
#include "oneflow/core/kernel/record_kernel.h"
#include "oneflow/core/job/runtime_context.h"
namespace oneflow {
......@@ -7,8 +8,6 @@ namespace {
template<typename T>
void RecordBlobImpl(PersistentOutStream& out_stream, const Blob* blob) {
CHECK_EQ(GetDataType<T>::val, blob->data_type());
blob->shape().SerializeWithTextFormat(out_stream);
out_stream << '\n';
const T* dptr = blob->dptr<T>();
for (int64_t i = 0; i < blob->shape().At(0); ++i) {
if (blob->has_data_id()) {
......@@ -41,7 +40,7 @@ void RecordKernel::Forward(
std::function<Blob*(const std::string&)> BnInOp2Blob) const {
int64_t parallel_id = reinterpret_cast<int64_t>(kernel_ctx.other);
const std::string& root_path = op()->op_conf().record_conf().record_path();
OF_CALL_ONCE(root_path, GlobalFS()->CreateDirIfNotExist(root_path));
OF_CALL_ONCE(root_path, GlobalFS()->MakeEmptyDir(root_path));
for (const std::string& ibn : op()->input_bns()) {
const std::string& lbn = op()->Lbn4BnInOp(ibn);
const Blob* blob = BnInOp2Blob(ibn);
......@@ -55,9 +54,11 @@ void RecordKernel::Forward(
std::string bn_in_op_dir = JoinPath(op_dir, bn_in_op);
OF_CALL_ONCE(bn_in_op_dir, GlobalFS()->CreateDir(bn_in_op_dir));
std::string file_path =
JoinPath(bn_in_op_dir, "part_" + std::to_string(parallel_id));
PersistentOutStream out_stream(GlobalFS(), file_path);
RecordBlob(out_stream, blob);
JoinPath(bn_in_op_dir, "part-" + std::to_string(parallel_id));
auto out_stream =
RuntimeCtx::Singleton()->GetPersistentOutStream(file_path);
RecordBlob(*out_stream, blob);
out_stream->Flush();
});
}
}
......
......@@ -23,6 +23,11 @@ std::string FileSystem::TranslateName(const std::string& name) const {
return CleanPath(name);
}
void FileSystem::MakeEmptyDir(const std::string& dirname) {
if (IsDirectory(dirname)) { RecursivelyDeleteDir(dirname); }
CreateDir(dirname);
}
void FileSystem::RecursivelyDeleteDir(const std::string& dirname) {
CHECK(FileExists(dirname));
std::deque<std::string> dir_q; // Queue for the BFS
......
......@@ -124,6 +124,8 @@ class FileSystem {
// subdirectories.
virtual void RecursivelyCreateDir(const std::string& dirname);
void MakeEmptyDir(const std::string& dirname);
// Deletes the specified directory.
virtual void DeleteDir(const std::string& dirname) = 0;
......
......@@ -14,4 +14,6 @@ PersistentOutStream& PersistentOutStream::Write(const char* s, size_t n) {
return *this;
}
void PersistentOutStream::Flush() { file_->Flush(); }
} // namespace oneflow
......@@ -18,6 +18,8 @@ class PersistentOutStream final {
// Inserts the first n characters of the array pointed by s into the stream.
PersistentOutStream& Write(const char* s, size_t n);
void Flush();
private:
std::unique_ptr<fs::WritableFile> file_;
};
......
......@@ -9,12 +9,8 @@ SnapshotMgr::SnapshotMgr(const Plan& plan) {
num_of_model_blobs_ = 0;
if (JobDesc::Singleton()->is_train()) {
model_save_snapshots_path_ = JobDesc::Singleton()->md_save_snapshots_path();
OF_CALL_ONCE(model_save_snapshots_path_, {
if (GlobalFS()->IsDirectory(model_save_snapshots_path_)) {
GlobalFS()->RecursivelyDeleteDir(model_save_snapshots_path_);
}
GlobalFS()->CreateDir(model_save_snapshots_path_);
});
OF_CALL_ONCE(model_save_snapshots_path_,
GlobalFS()->MakeEmptyDir(model_save_snapshots_path_));
HashSet<std::string> model_blob_set;
for (const OperatorProto& op_proto : plan.op()) {
if (op_proto.op_conf().has_model_save_conf()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册