get_data.py 723 B

123456789101112131415161718192021222324252627
  1. # -*- encoding:utf-8 -*-
  2. import numpy as np
  3. def load_data(path='mnist.npz'):
  4. f = np.load(path)
  5. x_train, y_train = f['x_train'], f['y_train']
  6. x_test, y_test = f['x_test'], f['y_test']
  7. print(len(y_train))
  8. print(len(y_test))
  9. f.close()
  10. return (x_train, y_train), (x_test, y_test)
  11. print("开始读取数据")
  12. (train_X, train_y), (test_X, test_y) = load_data()
  13. print("读取结束")
  14. train_data=zip(train_X,train_y)
  15. train_data=[str([x.tolist(),y]) for [x,y] in train_data]
  16. test_data=zip(test_X,test_y)
  17. test_data=[str([x.tolist(),y]) for [x,y] in test_data]
  18. with open("train_data","w") as f:
  19. f.writelines("\n".join(train_data))
  20. with open("test_data","w") as f:
  21. f.writelines("\n".join(test_data))