# Source code for pyhealth.data.rnn_reader


import os.path
from .base_dataset import BaseDataset
import pandas as pd
import numpy as np
import torch
import pickle

def time_series_get(fpath):
    """Load a comma-separated time-series file as a NumPy array.

    The first line is consumed as the column header (``pandas.read_csv``
    default), so only the data rows are returned.

    Args:
        fpath: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        numpy.ndarray with one row per CSV data row and one column per CSV
        column.
    """
    csv_frame = pd.read_csv(fpath, delimiter=",")
    return csv_frame.values
class DatasetReader(BaseDataset):
    """Map-style dataset that loads one time-series CSV per sample.

    Each sample file is read with ``time_series_get``. Column 0 is turned into
    successive differences (presumably timestamps -> time gaps; confirm with
    the data producer) and the remaining columns are used as features. Every
    sample is truncated / zero-padded to the common length ``self.seq_len``.

    Args:
        data: mapping with keys
            ``'x'`` -- indexable sequence of CSV paths, one per sample;
            ``'y'`` -- indexable sequence of labels, parallel to ``'x'``;
            ``'l'`` -- per-sample lengths, used only to choose ``seq_len``
            (presumably the row count of each CSV -- verify against caller).
        max_seq_len: upper bound on the padded length. Defaults to 200, the
            cap that was previously hard-coded.
    """

    def __init__(self, data, max_seq_len=200):
        BaseDataset.__init__(self, data)
        self.series_paths = data['x']
        self.label_list = data['y']
        # Pad to the longest declared length, but never beyond max_seq_len.
        self.seq_len = min(max(data['l']), max_seq_len)

    def __getitem__(self, index):
        """Load sample ``index`` and return padded model inputs.

        Returns:
            dict with
              ``'X'``: float array ``(seq_len, n_features)``, zero-padded features;
              ``'M'``: float array ``(seq_len,)``, 1. on real steps, 0. on padding;
              ``'cur_M'``: float array ``(seq_len,)``, 1. only at the last real step;
              ``'Y'``: the label wrapped in ``np.array``;
              ``'T'``: float array ``(seq_len,)``, difference of column 0 from the
              previous step (0. at step 0), zero-padded.

        Raises:
            ValueError: if the sample file contains no data rows. Previously
                this surfaced as an unrelated IndexError.
        """
        xpath = self.series_paths[index]
        s_data = time_series_get(xpath)

        # Truncate to seq_len; column 0 is the time column, the rest are features.
        features = s_data[:self.seq_len, 1:]
        n_steps, n_features = np.shape(features)
        if n_steps == 0:
            raise ValueError(
                'time series file has no data rows: {}'.format(xpath))
        timestamps = s_data[:self.seq_len, 0]

        # Step 0 has no predecessor, so its delta stays 0.
        x_time = np.zeros(self.seq_len)
        x_time[1:n_steps] = np.diff(timestamps)

        x_series = np.zeros([self.seq_len, n_features])
        x_series[:n_steps, :] = features

        x_mask = np.zeros(self.seq_len)
        x_mask[:n_steps] = 1.

        # One-hot marker of the last real (non-padded) step.
        x_mask_cur = np.zeros(self.seq_len)
        x_mask_cur[n_steps - 1] = 1.

        label = self.label_list[index]
        return {'X': x_series,
                'M': x_mask,
                'cur_M': x_mask_cur,
                'Y': np.array(label),
                'T': x_time}

    def __len__(self):
        """Return the number of samples (one per CSV path)."""
        return len(self.series_paths)