Commit 84b3c3bb authored by zhaoxiaomeng's avatar zhaoxiaomeng

update: add mfr cost time each batch of dataloader

parent 9302caa3
...@@ -144,9 +144,12 @@ class CustomPEKModel: ...@@ -144,9 +144,12 @@ class CustomPEKModel:
dataloader = DataLoader(dataset, batch_size=128, num_workers=0) dataloader = DataLoader(dataset, batch_size=128, num_workers=0)
mfr_res = [] mfr_res = []
for imgs in dataloader: for imgs in dataloader:
start = time.time()
imgs = imgs.to(self.device) imgs = imgs.to(self.device)
output = self.mfr_model.generate({'image': imgs}) output = self.mfr_model.generate({'image': imgs})
mfr_res.extend(output['pred_str']) mfr_res.extend(output['pred_str'])
cost = time.time() - start
logger.info(f"batch size: {len(imgs)}, cost time: {round(cost, 2)}")
for res, latex in zip(latex_filling_list, mfr_res): for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex) res['latex'] = latex_rm_whitespace(latex)
b = time.time() b = time.time()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment