1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
import argparse

import torch
from torch.autograd import Variable

from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
# Defaults reproduce the values the export script originally hard-coded.
# NOTE(review): 127 / 287 presumably mirror the tracker's exemplar (template)
# and search crop sizes in the pysot config -- confirm against the config
# being exported.
DEFAULT_EXEMPLAR_SIZE = 127
DEFAULT_SEARCH_SIZE = 287
DEFAULT_OUTPUT = "onnx_test_2.onnx"


def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv: Argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``snapshot``, ``output``,
        ``exemplar_size`` and ``search_size`` attributes.

    Raises:
        SystemExit: if ``--config`` or ``--snapshot`` is missing or an option
            value is malformed (argparse prints the usage message).
    """
    parser = argparse.ArgumentParser(description='onnx demo')
    # Both paths are used unconditionally by main(); require them so a missing
    # one fails with a usage message instead of None reaching
    # cfg.merge_from_file / torch.load.
    parser.add_argument('--config', type=str, required=True,
                        help='config file')
    parser.add_argument('--snapshot', type=str, required=True,
                        help='model name')
    parser.add_argument('--output', type=str, default=DEFAULT_OUTPUT,
                        help='path of the ONNX file to write')
    parser.add_argument('--exemplar-size', type=int,
                        default=DEFAULT_EXEMPLAR_SIZE,
                        help='side length of the square exemplar input')
    parser.add_argument('--search-size', type=int,
                        default=DEFAULT_SEARCH_SIZE,
                        help='side length of the square search input')
    return parser.parse_args(argv)


def make_dummy_inputs(exemplar_size=DEFAULT_EXEMPLAR_SIZE,
                      search_size=DEFAULT_SEARCH_SIZE, device='cpu'):
    """Build the example (exemplar, search) tensors used to trace the model.

    Args:
        exemplar_size: Height and width of the exemplar tensor.
        search_size: Height and width of the search tensor.
        device: Device on which the tensors are allocated.

    Returns:
        Tuple ``(exemplar, search)`` of float tensors shaped
        ``(1, 3, exemplar_size, exemplar_size)`` and
        ``(1, 3, search_size, search_size)``.
    """
    # torch.Tensor(*sizes) hands back uninitialized memory (possibly NaN/inf);
    # use well-defined values. Plain tensors track autograd themselves, so the
    # legacy Variable wrapper is unnecessary.
    exemplar = torch.randn(1, 3, exemplar_size, exemplar_size, device=device)
    search = torch.randn(1, 3, search_size, search_size, device=device)
    return exemplar, search


def main(argv=None):
    """Load a pysot checkpoint on the CPU and export it to ONNX.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` reads the real command line.
    """
    args = parse_args(argv)

    # Merge the config before building the model; ModelBuilder presumably
    # reads the global cfg when it is constructed.
    cfg.merge_from_file(args.config)
    device = torch.device('cpu')

    model = ModelBuilder()
    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
    model.load_state_dict(torch.load(args.snapshot, map_location=device))
    model.eval().to(device)

    dummy_inputs = make_dummy_inputs(args.exemplar_size, args.search_size,
                                     device=device)
    # NOTE(review): the exporter calls model(exemplar, search); confirm that
    # ModelBuilder.forward accepts these two tensors positionally.
    torch.onnx.export(model, dummy_inputs, args.output, export_params=True)
if __name__ == '__main__':
    # Perform the export only when executed directly, never on import.
    main()
|