get_data.py 471 B

1234567891011121314151617181920212223
  1. import random
  2. def generate_data(num):
  3. results=[]
  4. count=0
  5. while True:
  6. x1=random.uniform(0,10)
  7. x2=random.uniform(0,10)
  8. if abs(x1-x2)<2:
  9. continue
  10. y=1 if x1>x2 else 0
  11. results.append(str([[x1,x2],[y]]))
  12. count+=1
  13. if count>=num:
  14. break
  15. return results
  16. train_data=generate_data(500)
  17. test_data=generate_data(100)
  18. with open("train_data","w") as f :
  19. f.writelines("\n".join(train_data))
  20. with open("test_data","w") as f :
  21. f.writelines("\n".join(test_data))