-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path data.py
More file actions
23 lines (18 loc) · 824 Bytes
/
data.py
File metadata and controls
23 lines (18 loc) · 824 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
def get_mnist_dls(data_dir, batch_size, n_cpus, *, pin_memory=None):
    """Build train and test DataLoaders for MNIST.

    Both splits are downloaded into ``data_dir`` if not already present and
    are transformed with ``ToTensor`` followed by ``Normalize`` with mean and
    std of 0.5. This maps pixel values from [0, 1] to [-1, 1].

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.MNIST``.
        batch_size: Number of samples per batch, used for both loaders.
        n_cpus: Number of DataLoader worker processes, used for both loaders.
        pin_memory: Whether both loaders should use pinned host memory.
            ``None`` (the default) enables it only when CUDA is available.
            Pinning only speeds up host-to-GPU copies, and recent PyTorch
            versions warn when it is requested without an accelerator.

    Returns:
        A ``(train_dl, test_dl)`` tuple. The training loader shuffles and
        drops the last incomplete batch. The test loader keeps dataset order
        and yields every sample.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # MNIST images have a single channel, so mean/std are given as
    # one-element per-channel sequences, the form Normalize documents.
    transformer = T.Compose(
        [T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))]
    )
    train_ds = MNIST(root=data_dir, train=True, download=True, transform=transformer)
    test_ds = MNIST(root=data_dir, train=False, download=True, transform=transformer)

    # Settings shared by both loaders, kept in one place so they cannot drift.
    common = {
        "batch_size": batch_size,
        "num_workers": n_cpus,
        "pin_memory": pin_memory,
    }
    train_dl = DataLoader(train_ds, shuffle=True, drop_last=True, **common)
    test_dl = DataLoader(test_ds, shuffle=False, drop_last=False, **common)
    return train_dl, test_dl
if __name__ == "__main__":
    # Hard-coded, machine-specific dataset root (a local home-directory path).
    # NOTE(review): the visible script only assigns this and never uses it;
    # confirm whether a call such as get_mnist_dls(data_dir, ...) was intended.
    data_dir = "/Users/jongbeomkim/Documents/datasets"