Source code for habitat.analysis.mlp.dataset

import numpy as np
import torch

from torch.utils.data import Dataset

from habitat.analysis.mlp.dataset_process import get_dataset


[docs]class HabitatDataset(Dataset): def __init__(self, dataset_path, features): self.x, self.y = get_dataset(dataset_path, features) # input normalization self.x = np.array(self.x) self.mu = np.mean(self.x, axis=0) self.sigma = np.std(self.x, axis=0) self.x = np.divide(np.subtract(self.x, self.mu), self.sigma)
[docs] def __len__(self): return len(self.y)
[docs] def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() return torch.from_numpy(np.array(self.x[idx]).astype(np.float32), ), float(self.y[idx])