提交 4049f4a5 编写于 作者: H hypox64

0.14142

上级 d9aef5a2
import os import os
import csv import csv
import numpy as np import numpy as np
import random
# import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
# load description_txt # load description_txt
...@@ -40,6 +41,11 @@ for i in range(len(colon_indexs)-1): ...@@ -40,6 +41,11 @@ for i in range(len(colon_indexs)-1):
descriptions.append(mapping) descriptions.append(mapping)
# print(descriptions) # print(descriptions)
def match_random(a,b):
state = np.random.get_state()
np.random.shuffle(a)
np.random.set_state(state)
np.random.shuffle(b)
def normlize(npdata,justprice = False): def normlize(npdata,justprice = False):
_min = np.min(npdata) _min = np.min(npdata)
......
...@@ -28,7 +28,7 @@ def RMSE(records_real,records_predict): ...@@ -28,7 +28,7 @@ def RMSE(records_real,records_predict):
return None return None
def main(): def main():
my_price = load_submission('./result/0.04145_0.15960.csv') my_price = load_submission('./datasets/sample_submission.csv')
print(eval_test(my_price)) print(eval_test(my_price))
if __name__ == '__main__': if __name__ == '__main__':
main() main()
\ No newline at end of file
此差异已折叠。
...@@ -4,22 +4,24 @@ import dataloader ...@@ -4,22 +4,24 @@ import dataloader
import model import model
import evaluation import evaluation
from torch import nn, optim from torch import nn, optim
import time
#parameter #parameter
LR = 0.0001 LR = 0.0001
EPOCHS = 100 EPOCHS = 1000
BATCHSIZE = 1 BATCHSIZE = 64
CONTINUE = False CONTINUE = False
use_gpu = True use_gpu = True
SAVE_FRE = 5 SAVE_FRE = 5
#train 0:1200 dev 1200:1460 test 1460: 2919
#load data #load data
train_desc,train_price,test_desc = dataloader.load_all() train_desc,train_price,test_desc = dataloader.load_all()
train_desc.tolist()
train_price.tolist()
#def network #def network
net = model.Linear(79,1024,1) net = model.Linear(79,256,1)
print(net) print(net)
if CONTINUE: if CONTINUE:
...@@ -32,13 +34,16 @@ if use_gpu: ...@@ -32,13 +34,16 @@ if use_gpu:
optimizer = torch.optim.Adam(net.parameters(), lr=LR ) optimizer = torch.optim.Adam(net.parameters(), lr=LR )
criterion = nn.MSELoss() criterion = nn.MSELoss()
test_loss_list = []
for epoch in range(EPOCHS): for epoch in range(EPOCHS):
print('Epoch {}/{}.'.format(epoch + 1, EPOCHS)) print('Epoch {}/{}.'.format(epoch + 1, EPOCHS))
t1 = time.time()
net.train() net.train()
price_pres = [] price_pres = []
price_trues = [] price_trues = []
dataloader.match_random(train_desc, train_price)
for i in range(int(len(train_desc)/BATCHSIZE)): for i in range(int(len(train_desc)/BATCHSIZE)):
desc = np.zeros((BATCHSIZE,79), dtype=np.float32) desc = np.zeros((BATCHSIZE,79), dtype=np.float32)
price = np.zeros((BATCHSIZE,1), dtype=np.float32) price = np.zeros((BATCHSIZE,1), dtype=np.float32)
...@@ -61,8 +66,6 @@ for epoch in range(EPOCHS): ...@@ -61,8 +66,6 @@ for epoch in range(EPOCHS):
price_pres.append(dataloader.convert2price(price_pre.cpu().detach().numpy()[j][0])) price_pres.append(dataloader.convert2price(price_pre.cpu().detach().numpy()[j][0]))
train_loss = evaluation.RMSE(price_trues,price_pres) train_loss = evaluation.RMSE(price_trues,price_pres)
net.eval() net.eval()
price_pres = [] price_pres = []
for i in range(len(test_desc)): for i in range(len(test_desc)):
...@@ -72,8 +75,14 @@ for epoch in range(EPOCHS): ...@@ -72,8 +75,14 @@ for epoch in range(EPOCHS):
price_pres.append(dataloader.convert2price(price_pre.cpu().detach().numpy()[0][0])) price_pres.append(dataloader.convert2price(price_pre.cpu().detach().numpy()[0][0]))
test_loss = evaluation.eval_test(price_pres) test_loss = evaluation.eval_test(price_pres)
test_loss_list.append(test_loss)
dataloader.write_csv(price_pres, './result/result_epoch'+str(epoch+1)+'.csv')
dataloader.write_csv(price_pres, './result_epoch'+str(epoch+1)+'.csv') t2 = time.time()
print('--- Epoch train_loss:','%.6f'%train_loss,' test_loss:','%.6f'%test_loss,' cost time:','%.3f'%(t2-t1),'s')
t1 = time.time()
print('--- Epoch train_loss:',train_loss,' test_loss:',test_loss) min_loss = min(test_loss_list)
index_epoch = test_loss_list.index(min_loss)
print('\nmin_loss:',min_loss,'epoch:',index_epoch+1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册