| | import os |
| | from models.module import DiffGeolocalizer |
| | import hydra |
| | import wandb |
| | from os.path import isfile, join |
| | from shutil import copyfile |
| |
|
| | import torch |
| |
|
| | from omegaconf import OmegaConf |
| | from omegaconf import open_dict |
| | from hydra.core.hydra_config import HydraConfig |
| | from hydra.utils import instantiate |
| | from pytorch_lightning.callbacks import LearningRateMonitor |
| | from lightning_fabric.utilities.rank_zero import _get_rank |
| |
|
| | from models.module import DiffGeolocalizer |
| |
|
# Select the "high" float32 matmul precision mode for the whole process.
# Per the PyTorch docs this lets float32 matrix multiplications use faster,
# reduced-precision internal formats where the hardware supports them.
torch.set_float32_matmul_precision("high")
| |
|
| | |
| | |
| | |
# Register an "eval" resolver backed by Python's built-in eval, so config
# values can be written as ${eval:...} expressions.
# NOTE(review): this executes arbitrary Python taken from config text; only
# load configs from trusted sources.
OmegaConf.register_new_resolver("eval", eval)
| |
|
| |
|
def load_model(cfg, dict_config, wandb_id):
    """Build the trainer and restore the model from its checkpoint.

    Args:
        cfg: Config object; ``cfg.logger``, ``cfg.checkpoint``, ``cfg.model``
            and ``cfg.trainer`` are read.
        dict_config: Not used by this function; kept so existing callers
            that pass it keep working.
        wandb_id: Path to a text file whose full contents are passed to the
            logger as its ``id``.

    Returns:
        A ``(trainer, model)`` tuple.
    """
    # Read the id inside a context manager so the file handle is closed
    # deterministically rather than left for the garbage collector.
    with open(wandb_id, "r") as id_file:
        run_id = id_file.read()

    logger = instantiate(cfg.logger, id=run_id, resume="allow")
    model = DiffGeolocalizer.load_from_checkpoint(cfg.checkpoint, cfg=cfg.model)
    trainer = instantiate(cfg.trainer, strategy=cfg.trainer.strategy, logger=logger)
    return trainer, model
| |
|
| |
|
def hydra_boilerplate(cfg):
    """Resolve the config and delegate to ``load_model``.

    Args:
        cfg: Config object; ``cfg.wandb_id`` is forwarded as the id-file path.

    Returns:
        The ``(trainer, model)`` pair produced by ``load_model``.
    """
    # Convert the config to plain containers with interpolations resolved.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    return load_model(cfg, resolved, cfg.wandb_id)
| |
|
| |
|
| | import copy |
| |
|
| |
|
def generate_datamodules(cfg_):
    """Yield one instantiated datamodule per config file in ``cfg_.test_dir``.

    For every file in the directory, a deep copy of ``cfg_`` is made, its
    ``datamodule`` and ``dataset`` sections are replaced by the ones loaded
    from that file, and the dataset's ``test_transform`` is reset to the one
    from ``cfg_``. The input config itself is never modified.

    Args:
        cfg_: Base config; ``cfg_.test_dir`` and
            ``cfg_.dataset.test_transform`` are read.

    Yields:
        The object built by ``instantiate`` from each patched
        ``cfg.datamodule``, in sorted file-name order.
    """
    # os.listdir returns entries in arbitrary order; sort so the datasets are
    # always evaluated in the same sequence.
    for entry in sorted(os.listdir(cfg_.test_dir)):
        entry_path = join(cfg_.test_dir, entry)
        # Sub-directories cannot be loaded as a config file; skip them
        # instead of crashing inside OmegaConf.load.
        if not isfile(entry_path):
            continue

        cfg = copy.deepcopy(cfg_)
        with open_dict(cfg):
            cfg_new = OmegaConf.load(entry_path)
            cfg.datamodule = cfg_new.datamodule
            cfg.dataset = cfg_new.dataset
            # Keep the test transform of the base config, not the loaded one.
            cfg.dataset.test_transform = cfg_.dataset.test_transform

        yield instantiate(cfg.datamodule)
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Insert an extra override right after the program name so the composed
    # config carries Hydra's runtime config sources under ``pt_model_path``.
    sys.argv = [
        sys.argv[0],
        "+pt_model_path=${hydra:runtime.config_sources}",
        *sys.argv[1:],
    ]

    @hydra.main(version_base=None)
    def main(cfg):
        """Run ``trainer.test`` on every datamodule found for this config."""
        with open_dict(cfg):
            # Entry 1 of the config sources is used as the run directory.
            # NOTE(review): presumably the directory given on the command
            # line; confirm against how the script is launched.
            run_dir = cfg.pt_model_path[1]["path"]
            cfg.wandb_id = join(run_dir, "wandb_id.txt")
            cfg.checkpoint = join(run_dir, "last.ckpt")
            cfg.computer.devices = 1

        trainer, model = hydra_boilerplate(cfg)

        for datamodule in generate_datamodules(cfg):
            model.datamodule = datamodule
            model.datamodule.setup()
            print("Testing on", datamodule.test_dataset.class_name)
            trainer.test(model, datamodule=datamodule)

    main()
| |
|