-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
78 lines (71 loc) · 2.6 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os

import h5py
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
# Directories holding one ``<id>.npy`` file per sample: get_mini_batch reads
# its input array from inpDir and (outside test mode) its target from outDir.
inpDir='/scratch/dsw310/CCA/data/DM+halos/'
outDir='/scratch/dsw310/CCA/data/smoothed/HI/'

# Sample ids are flat positions i = a*3969 + b*63 + c with 3969 == 63*63,
# presumably a 63x63x63 grid of sub-volumes (TODO confirm). An id is kept
# when at least one of a, b, c is below 22, i.e. only the corner block with
# all three coordinates >= 22 is excluded (presumably a held-out region).
# IndList is in ascending order and holds numpy integer scalars.
_all_ids = np.arange(0, 3969 * 63)
_keep = (
    ((_all_ids // 3969) < 22)
    | (((_all_ids % 3969) // 63) < 22)
    | ((_all_ids % 63) < 22)
)
IndList = list(_all_ids[_keep])
del _all_ids, _keep
class SimuData(Dataset):
    """Map-style dataset over positions ``lIndex``..``hIndex - 1`` of ``IndList``.

    The sample ids at those positions are resolved once, at construction.
    Each item is then loaded on demand by ``get_mini_batch`` using the
    ``hod``/``aug``/``test`` flags stored here.
    """

    def __init__(self, lIndex, hIndex, hod=0, aug=0, test=0):
        self.hod = hod
        self.aug = aug
        self.test = test
        # Every position is looked up individually (not sliced), so an
        # out-of-range position raises IndexError instead of being dropped.
        self.datafiles = [IndList[pos] for pos in np.arange(lIndex, hIndex)]

    def __getitem__(self, index):
        sample_id = self.datafiles[index]
        return get_mini_batch(sample_id, self.hod, self.aug, self.test)

    def __len__(self):
        return len(self.datafiles)
# Orderings of the three spatial axes (1, 2, 3) used by augmentation; axis 0
# always stays first. Together with the identity ordering (used when no
# entry is selected) this covers all six permutations. The entry at 0-based
# position k is chosen when the permutation draw lies in [k/6, (k+1)/6).
_AUG_AXIS_ORDERS = (
    (0, 2, 3, 1),
    (0, 2, 1, 3),
    (0, 1, 3, 2),
    (0, 3, 1, 2),
    (0, 3, 2, 1),
)


def _augment_pair(inp, out):
    """Apply the same random flips and axis permutation to ``inp`` and ``out``.

    Consumes exactly four draws from the global ``np.random`` stream: one
    coin toss per spatial axis 1, 2, 3 (flip that axis when the draw is
    < 0.5), then one draw that picks an axis ordering. Returns views, which
    may be non-contiguous or have negative strides.
    """
    for axis in (1, 2, 3):
        if np.random.rand() < .5:
            inp = np.flip(inp, axis=axis)
            out = np.flip(out, axis=axis)
    prand = np.random.rand()
    for k, axes in enumerate(_AUG_AXIS_ORDERS, start=1):
        if prand < k / 6.:
            return np.transpose(inp, axes=axes), np.transpose(out, axes=axes)
    # prand >= 5/6: keep the original axis order.
    return inp, out


def get_mini_batch(ind, hod, aug, test, inp_dir=None, out_dir=None):
    """Load sample ``ind`` from disk and return it as float32 torch tensors.

    The input array is read from ``<inp_dir>/<ind>.npy``. In training mode
    (``test`` falsy) the target array is also read, from ``<out_dir>/<ind>.npy``.

    Args:
        ind: Sample id; ``str(ind)`` is the file stem.
        hod: If truthy, keep the whole input and divide ``inp[1]`` by 5.
            Otherwise keep only the second half along axis 0
            (``np.split(inp, 2, axis=0)[1]``). Presumably axis 0 holds
            exactly two channels; confirm against the data files.
        aug: Training mode only. When ``aug == 1`` (exactly), the same random
            flips and axis permutation are applied to input and target.
        test: If truthy, no target is loaded and ``ind`` is returned instead.
        inp_dir: Input directory; defaults to the module-level ``inpDir``.
        out_dir: Target directory; defaults to the module-level ``outDir``.

    Returns:
        test and hod:         (inp[0], inp[1], ind)
        test and not hod:     (inp, ind)
        training and hod:     (first half, second half, target), halves taken
                              with ``np.split(inp, 2, axis=0)``
        training and not hod: (inp, target)
        Every array is converted with ``torch.from_numpy(...).float()``.
    """
    if inp_dir is None:
        inp_dir = inpDir
    inp = np.load(os.path.join(inp_dir, str(ind) + '.npy'))
    if hod:
        inp[1] = inp[1] / 5.
    else:
        inp = np.split(inp, 2, axis=0)[1]

    if test:
        if hod:
            # NOTE(review): inp[0]/inp[1] drop axis 0, whereas the training
            # branch below uses np.split and keeps a length-1 axis 0. Confirm
            # that callers expect this difference in rank.
            return (torch.from_numpy(inp[0]).float(),
                    torch.from_numpy(inp[1]).float(),
                    ind)
        return torch.from_numpy(inp).float(), ind

    if out_dir is None:
        out_dir = outDir
    out = np.load(os.path.join(out_dir, str(ind) + '.npy'))
    if aug == 1:
        inp, out = _augment_pair(inp, out)

    # .copy() materialises the flipped/transposed views: flipped views have
    # negative strides, which torch.from_numpy does not accept.
    if hod:
        first, second = np.split(inp, 2, axis=0)
        return (torch.from_numpy(first.copy()).float(),
                torch.from_numpy(second.copy()).float(),
                torch.from_numpy(out.copy()).float())
    return torch.from_numpy(inp.copy()).float(), torch.from_numpy(out.copy()).float()