Source code for flagevalmm.server.server_dataset

from torch.utils.data import Dataset


class ServerDataset(Dataset):
    """Dataset whose samples are fetched on demand through a task manager.

    The task manager must provide ``get_meta_info(task_name)``, returning a
    mapping with ``"name"`` and ``"length"`` entries, and
    ``get_data(task_name, index)``, returning the record for one sample.
    """

    def __init__(
        self,
        task_name: str,
        task_manager,
        task_type: str = "vqa",
    ) -> None:
        """Store the task settings and read the dataset name and length.

        Args:
            task_name: Identifier passed to every task-manager call.
            task_manager: Object used to fetch meta info and samples.
                NOTE(review): presumably a ``TaskManager`` from
                ``flagevalmm.models.base_model_adapter`` -- confirm.
            task_type: Stored on the instance as-is; not used by this class.
        """
        self.task_manager = task_manager
        self.task_name = task_name
        self.task_type = task_type

        meta_info = self.task_manager.get_meta_info(task_name)
        self.datasetname = meta_info["name"]
        self.length: int = meta_info["length"]

    def __len__(self) -> int:
        """Return the sample count reported by the task manager's meta info."""
        return self.length

    def __getitem__(self, index):
        """Return ``(question_id, multi_modal_data, question)`` for one sample.

        ``multi_modal_data`` is ``{"video": path}`` when the record has a
        truthy ``"video_path"``; otherwise it is ``{"image": path}`` built
        from the record's ``"img_path"``.
        """
        sample = self.get_data(index)
        question_id = sample["question_id"]
        question = sample["question"]

        # A missing or empty video path means the sample is image-based.
        video_path = sample.get("video_path")
        multi_modal_data = (
            {"video": video_path} if video_path else {"image": sample["img_path"]}
        )
        return question_id, multi_modal_data, question

    def get_data(self, index: int):
        """Fetch the raw record at ``index`` for this task from the task manager."""
        return self.task_manager.get_data(self.task_name, index)