4

I am working on a time series problem. The training data for different time series is stored in a single large JSON file of about 30 GB. In TensorFlow I know how to use TFRecords. Is there a similar way in PyTorch?

Amir
Shamane Siriwardhana

2 Answers

10

I suppose IterableDataset (docs) is what you need, because:

  1. you probably want to traverse the files without random access;
  2. the number of samples in the JSON files is not pre-computed.

I've made a minimal usage example under the assumption that every line of the dataset file is a JSON object itself, but you can change the logic.

import json
from torch.utils.data import DataLoader, IterableDataset


class JsonDataset(IterableDataset):
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        # stream the files line by line; each line holds one JSON-encoded sample
        for json_file in self.files:
            with open(json_file) as f:
                for sample_line in f:
                    sample = json.loads(sample_line)
                    yield sample['x'], sample['time'], ...

...

dataset = JsonDataset(['data/1.json', 'data/2.json', ...])
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
    y = model(batch)
roman
    A problem will arise if you use more workers: every worker will start reading from the same file, causing duplicates. – theo2021 Apr 08 '23 at 18:33
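
If you do want num_workers > 1 with an IterableDataset, one common fix is to split the file list per worker inside __iter__ using torch.utils.data.get_worker_info(). A minimal sketch, assuming the line-per-JSON layout from the answer above (the ShardedJsonDataset name and the round-robin split are only illustrative):

import json
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class ShardedJsonDataset(IterableDataset):
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        worker = get_worker_info()  # None in the main process
        files = self.files
        if worker is not None:
            # give each worker every num_workers-th file, so no file is read twice
            files = files[worker.id::worker.num_workers]
        for json_file in files:
            with open(json_file) as f:
                for sample_line in f:
                    sample = json.loads(sample_line)
                    yield sample['x'], sample['time']

dataloader = DataLoader(ShardedJsonDataset(['data/1.json', 'data/2.json']),
                        batch_size=32, num_workers=4)

With this split each file is visited by exactly one worker, which removes the duplicates the comment above describes.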
2

Generally, you do not need to change/overload the default data.DataLoader.

What you should look into is how to create a custom data.Dataset.
Once you have your own Dataset that knows how to extract items one by one from the json files, you feed it to the "vanilla" data.DataLoader, and all the batching/multi-processing etc. is done for you based on the dataset you provide.

If, for example, you have a folder with several json files, each containing a list of several examples, you can have a Dataset that looks like:

import bisect
import json
import os
from torch.utils import data

class MyJsonsDataset(data.Dataset):
  def __init__(self, jfolder):
    super(MyJsonsDataset, self).__init__()
    self.filenames = []  # keep track of the json files you need to load
    self.cumulative_sizes = [0]  # keep track of number of examples seen so far
    for fname in sorted(os.listdir(jfolder)):
      if not fname.endswith('.json'):
        continue
      jsonfile = os.path.join(jfolder, fname)
      with open(jsonfile) as f:
        # assuming each file holds a json list of examples
        n_examples = len(json.load(f))
      self.filenames.append(jsonfile)
      self.cumulative_sizes.append(self.cumulative_sizes[-1] + n_examples)
    # discard the leading zero
    self.cumulative_sizes.pop(0)

  def __len__(self):
    return self.cumulative_sizes[-1]

  def __getitem__(self, idx):
    # first you need to know which of the files holds the idx-th example
    jfile_idx = bisect.bisect_right(self.cumulative_sizes, idx)
    if jfile_idx == 0:
      sample_idx = idx
    else:
      sample_idx = idx - self.cumulative_sizes[jfile_idx - 1]
    # now retrieve the `sample_idx` example from self.filenames[jfile_idx]
    with open(self.filenames[jfile_idx]) as f:
      retrieved_example = json.load(f)[sample_idx]
    return retrieved_example
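
Once you have this Dataset, using it is just the standard DataLoader call; a hypothetical usage sketch (the folder path, batch size and model are placeholders):

from torch.utils.data import DataLoader

dataset = MyJsonsDataset('data/jsons')  # hypothetical folder of json files
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for batch in dataloader:
  output = model(batch)  # batching/collation is handled by the DataLoader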

Shai