| import time |
| import random |
| import argparse |
| import torch |
|
|
| from greedrl import Problem, Solution, Solver |
|
|
|
|
def _parse_args():
    """Parse the command-line options used by ``run``.

    Unrecognised options are ignored (``parse_known_args``), so this can share
    a command line with other scripts' flags.

    Returns:
        argparse.Namespace with ``device``, ``batch_size``, ``agent_file``,
        ``valid_steps`` and ``max_steps``.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--device', default=None, type=str)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--agent_file', default=None, type=str)
    parser.add_argument('--valid_steps', default=5, type=int)
    parser.add_argument('--max_steps', default=10000000, type=int)
    args, _unknown = parser.parse_known_args()
    return args


def run(make_problem, mask_task_ratio=0.1):
    """Train a greedrl solver on generated problems, then solve them.

    Seeds ``random`` and ``torch`` with 123 for reproducibility and reads the
    options from the command line (see ``_parse_args``). It trains with
    ``Solver.train``, using the same problem list for both training and
    validation. It then solves every problem and prints the total solve time.

    Args:
        make_problem: callable; ``make_problem(1)`` must return the problems
            passed to ``Solver.train`` and iterated for ``Solver.solve``
            (presumably ``greedrl.Problem`` objects -- TODO confirm).
        mask_task_ratio: currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        list: the value returned by ``Solver.solve`` for each problem, in
        iteration order.
    """
    random.seed(123)
    torch.manual_seed(123)
    problem_list = make_problem(1)

    args = _parse_args()
    for key, value in vars(args).items():
        print("arg: {} = {}".format(key, value))

    trainer = Solver(device=args.device)
    trainer.train(args.agent_file, problem_list, problem_list,
                  batch_size=args.batch_size, valid_steps=args.valid_steps,
                  max_steps=args.max_steps)

    if args.agent_file is None:
        # No file was written to reload from, so the trained agent exists
        # only in memory: solve with the trainer itself instead of a fresh,
        # agent-less Solver. (Assumes Solver.train keeps the trained agent on
        # the instance -- TODO confirm against greedrl.)
        solver = trainer
    else:
        # Reload the agent that training saved to disk, as before.
        solver = Solver(device=args.device)
        solver.load_agent(args.agent_file)

    print("solve ...")
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    solutions = [solver.solve(problem, batch_size=args.batch_size)
                 for problem in problem_list]
    print("time: {}s".format(time.perf_counter() - start))
    return solutions
|
|