Source code for habitat.analysis.mlp.train

import argparse
import random
import torch
import numpy

from habitat.analysis.mlp.mlp import RuntimePredictor


def main(argv=None):
    """Parse command-line arguments and train an MLP runtime predictor.

    Positional arguments:
        operation: forwarded as the first argument to ``RuntimePredictor``.
        dataset_path: forwarded to ``RuntimePredictor.train_with_dataset``.

    Options (all integers):
        --layers      forwarded to ``RuntimePredictor`` (default 8)
        --layer_size  forwarded to ``RuntimePredictor`` (default 1024)
        --epochs      forwarded to ``train_with_dataset`` (default 80)
        --seed        seed for the ``random``, ``torch`` and ``numpy``
                      RNGs (default 1337)

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) preserves the original
            command-line behaviour; passing a list lets other code or tests
            invoke training without touching ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="MLP Training Script")
    parser.add_argument(
        "operation", type=str,
        help="operation name passed to RuntimePredictor",
    )
    parser.add_argument(
        "dataset_path", type=str,
        help="dataset path passed to RuntimePredictor.train_with_dataset",
    )
    parser.add_argument(
        "--layers", type=int, default=8,
        help="layer count passed to RuntimePredictor (default: %(default)s)",
    )
    parser.add_argument(
        "--layer_size", type=int, default=1024,
        help="layer size passed to RuntimePredictor (default: %(default)s)",
    )
    parser.add_argument(
        "--epochs", type=int, default=80,
        help="number of training epochs (default: %(default)s)",
    )
    parser.add_argument(
        "--seed", type=int, default=1337,
        help="seed for all random number generators (default: %(default)s)",
    )
    # parse_args(None) falls back to sys.argv[1:], so existing CLI use is
    # unchanged while callers may inject their own argument list.
    args = parser.parse_args(argv)

    # Ensure reproducibility: seed every RNG the training stack may draw from.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    numpy.random.seed(args.seed)

    predictor = RuntimePredictor(args.operation, args.layers, args.layer_size)
    predictor.train_with_dataset(args.dataset_path, epochs=args.epochs)
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()