Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions pyterrier/_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load(cls, path: str, **kwargs : Any) -> 'Artifact':
ValueError: If no implementation is found that supports the artifact at the specified path.
"""
if cls is Artifact:
return _load(path)
return _load(path, **kwargs)
else:
# called SomeArtifact.load(path), so load this specific artifact
# TODO: add error message if not loaded
Expand Down Expand Up @@ -103,7 +103,7 @@ def from_url(cls, url: str, *, expected_sha256: Optional[str] = None, **kwargs :
headers.update(resolver_result.headers)

if parsed_url.scheme == '' and os.path.exists(url):
return cls.load(url) # already resolved, load it
return cls.load(url, **kwargs) # already resolved, load it

# buid local path
base_path = os.path.join(pt.io.pyterrier_home(), 'artifacts')
Expand Down Expand Up @@ -166,7 +166,7 @@ def from_url(cls, url: str, *, expected_sha256: Optional[str] = None, **kwargs :
print(f'extracting {member.path} [{pt.utils.byte_count_to_human_readable(member.size)}]')
tar_in.extract(member, dout, set_attrs=False)

return cls.load(path)
return cls.load(path, **kwargs)

def _package_files(self) -> Iterator[Tuple[str, Union[str, io.BytesIO]]]:
assert not isinstance(self.path, _NoPath), "package cannot be built for artifacts without a path"
Expand Down Expand Up @@ -312,19 +312,22 @@ def manage_maxsize(_: None):
return package_path

@classmethod
def from_dataset(cls, dataset: str, variant: str, *, expected_sha256: Optional[str] = None) -> 'Artifact':
def from_dataset(cls, dataset: str, variant: str, *, expected_sha256: Optional[str] = None, **kwargs) -> 'Artifact':
"""Load an artifact from a PyTerrier dataset.

Args:
dataset: The name of the dataset.
variant: The variant of the dataset.
expected_sha256: The expected SHA-256 hash of the artifact. If provided, the downloaded artifact will be
verified against this hash and an error will be raised if the hash does not match.
**kwargs: arguments that will be passed to the constructor of the artifact class
"""
return cls.from_hf(
repo='pyterrier/from-dataset',
branch=f'{dataset}.{variant}',
expected_sha256=expected_sha256)
expected_sha256=expected_sha256,
**kwargs
)

# -------------------------------------------------
# HuggingFace Datasets Integration
Expand Down Expand Up @@ -625,7 +628,7 @@ def from_p2p(cls, code: str, path: str, *, expected_sha256: Optional[str] = None
if (member.isfile() or member.isdir()) and pt.io.path_is_under_base(member.path, dout):
print(f'extracting {member.path} [{pt.utils.byte_count_to_human_readable(member.size)}]')
tar_in.extract(member, dout, set_attrs=False)
return cls.load(path)
return cls.load(path, **kwargs)

def to_p2p(self):
"""Send this artifact directly to a peer using Magic Wormhole.
Expand Down
Loading