Skip to content

Commit

Permalink
url and fragment indexing!
Browse files Browse the repository at this point in the history
  • Loading branch information
ehayeshaiper committed Oct 27, 2024
1 parent 04b5978 commit 7a3cc41
Showing 1 changed file with 6 additions and 25 deletions.
31 changes: 6 additions & 25 deletions wids/wids.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def group_by_key(names):
         key, ext = splitname(fname)
         groups_dict[key].append(i)
 
-    # Convert the dictionary values to a list of lists for the desired output format
-    return list(groups_dict.values())
+    return groups_dict



Expand Down Expand Up @@ -471,36 +470,18 @@ def check_cache_misses(self):
)
)

-    def get_shard(self, index):
-        """Get the shard and index within the shard corresponding to the given index."""
-        # Find the shard corresponding to the given index.
-        shard_idx = np.searchsorted(self.cum_lengths, index, side="right")
-
-        # Figure out which index within the shard corresponds to the
-        # given index.
-        if shard_idx == 0:
-            inner_idx = index
-        else:
-            inner_idx = index - self.cum_lengths[shard_idx - 1]
+    def __getitem__(self, url_index):
+        """Return the sample corresponding to the given index."""
 
-        # Get the shard and return the corresponding element.
-        desc = self.shards[shard_idx]
-        url = desc["url"]
+        url, index = url_index
         shard = self.cache.get_shard(url)
-        return shard, inner_idx, desc
-
-    def __getitem__(self, index):
-        """Return the sample corresponding to the given index."""
-        shard, inner_idx, desc = self.get_shard(index)
-        sample = shard[inner_idx]
+        sample = shard[index]
 
         # Check if we're missing the cache too often.
         self.check_cache_misses()
 
-        sample["__dataset__"] = desc.get("dataset")
         sample["__index__"] = index
-        sample["__shard__"] = desc["url"]
-        sample["__shardindex__"] = inner_idx
+        sample["__shard__"] = url
 
         # Apply transformations
         for transform in self.transformations:
Expand Down

0 comments on commit 7a3cc41

Please sign in to comment.