Skip to content

Commit af9042e

Browse files
mxberlot and Orbax Authors
authored and committed
Internal change
PiperOrigin-RevId: 872295675
1 parent fdd439b commit af9042e

5 files changed

Lines changed: 135 additions & 76 deletions

File tree

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -648,9 +648,11 @@ async def async_save(
648648
use_compression=self._use_compression,
649649
use_zarr3=self._use_zarr3,
650650
)
651-
assert all(
652-
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
653-
)
651+
# TODO(b/425293362): Add validation for PathAwaitingCreation.
652+
if isinstance(directory, epath.Path):
653+
assert all(
654+
leaf.parent_dir is directory for leaf in jax.tree.leaves(param_infos)
655+
)
654656

655657
serialize_ops = [] # List of (coros -> List of futures)
656658
batch_requests = batched_serialization_requests(
@@ -710,7 +712,6 @@ async def async_save(
710712
future.CommitFutureAwaitingContractedSignals(
711713
self._write_metadata_after_commits(
712714
commit_futures,
713-
checkpoint_dir=directory,
714715
param_infos=param_infos,
715716
save_args=save_args,
716717
custom_metadata=custom_metadata,
@@ -790,7 +791,7 @@ async def _maybe_deserialize(
790791
flat_metadata = tree_utils.to_flat_dict(metadata)
791792
byte_limiter = limits.get_byte_limiter(self._restore_concurrent_bytes)
792793
param_infos = jax.tree.map(
793-
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
794+
lambda info: info.replace(byte_limiter=byte_limiter),
794795
param_infos,
795796
)
796797
batch_requests = batched_serialization_requests(
@@ -1156,9 +1157,8 @@ def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo:
11561157
f'No ArrayMetadata found for param_info: {param_info}, checkpoint'
11571158
f' directory: {checkpoint_dir}, process_index={process_index}.'
11581159
)
1159-
return dataclasses.replace(
1160-
param_info,
1161-
write_shape=array_metadatas_cache[param_info.name].write_shape,
1160+
return param_info.replace(
1161+
write_shape=array_metadatas_cache[param_info.name].write_shape
11621162
)
11631163

11641164
return jax.tree.map(update_param_info, param_infos)
@@ -1210,7 +1210,6 @@ async def _write_metadata_file(
12101210
async def _write_metadata_after_commits(
12111211
self,
12121212
commit_futures: List[future.Future],
1213-
checkpoint_dir: epath.Path,
12141213
*,
12151214
param_infos: PyTree,
12161215
save_args: PyTree,
@@ -1225,6 +1224,11 @@ async def _write_metadata_after_commits(
12251224
for commit_future in commit_futures:
12261225
await asyncio.to_thread(commit_future.result)
12271226

1227+
await asyncio.gather(
1228+
*[info.await_path_creation() for info in jax.tree.leaves(param_infos)]
1229+
)
1230+
checkpoint_dir = jax.tree.leaves(param_infos)[0].parent_dir
1231+
12281232
commit_time = time.time()
12291233
# `write_shape` is extracted from ArrayMetadata store saved during
12301234
# materialization of commit_futures. Then it is written to the pytree

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ async def _maybe_deserialize(
680680
"""Deserializes values or gets them from the aggregate file."""
681681
byte_limiter = limits.get_byte_limiter(self._restore_concurrent_bytes)
682682
param_infos = jax.tree.map(
683-
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
683+
lambda info: info.replace(byte_limiter=byte_limiter),
684684
param_infos,
685685
)
686686

checkpoint/orbax/checkpoint/_src/serialization/jax_array_handlers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ async def _async_serialize_shardings(
155155
for sharding, info in zip(shardings, infos):
156156
if sharding is None:
157157
continue
158+
await info.await_path_creation()
158159
if info.parent_dir is None:
159160
raise ValueError('parent_dir cannot be None')
160161
tspec_sharding = ts_utils.get_sharding_tensorstore_spec(
@@ -568,6 +569,7 @@ async def _async_serialize_replica_slices(
568569
'Replica_separate_folder is disabled as OCDBT is not enabled.',
569570
1,
570571
)
572+
await info.await_path_creation()
571573
array_write_spec = ts_utils.build_array_write_spec(
572574
info=info,
573575
arg=arg,
@@ -1003,6 +1005,7 @@ async def metadata(
10031005
open_ops = []
10041006
sharding_open_ops = []
10051007
shardings = []
1008+
await asyncio.gather(*[info.await_path_creation() for info in infos])
10061009
if infos[0].parent_dir is None:
10071010
raise ValueError('parent_dir cannot be None')
10081011
sharding_file_path = infos[0].parent_dir / _SHARDING_FILE_NAME
@@ -1197,6 +1200,7 @@ async def deserialize(
11971200
if args is None:
11981201
raise ValueError('Must provide ArrayRestoreArgs to restore as jax.Array.')
11991202
types.check_input_arguments(infos, args)
1203+
await asyncio.gather(*[info.await_path_creation() for info in infos])
12001204
if infos[0].parent_dir is None:
12011205
raise ValueError('parent_dir cannot be None')
12021206
sharding_file_path = infos[0].parent_dir / _SHARDING_FILE_NAME

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def metadata(
8585
) -> Sequence[value_metadata.ArrayMetadata]:
8686
open_ops = []
8787
for info in infos:
88+
await info.await_path_creation()
8889
# Use OCDBT flag from the existing checkpoint.
8990
use_ocdbt = info.is_ocdbt_checkpoint
9091
array_read_spec = ts_utils.build_array_read_spec(
@@ -122,6 +123,7 @@ async def _background_serialize(
122123
"""Serializes numpy arrays in a background thread."""
123124
write_coros = []
124125
for value, info, arg in zip(values, infos, args):
126+
await info.await_path_creation()
125127
array_write_spec = ts_utils.build_array_write_spec(
126128
info=info,
127129
arg=arg,
@@ -175,6 +177,7 @@ async def deserialize(
175177
types.check_input_arguments(infos, args)
176178
open_futures = []
177179
for info, arg in zip(infos, args):
180+
await info.await_path_creation()
178181
if not info.is_ocdbt_checkpoint:
179182
await ts_utils.assert_parameter_files_exist(
180183
info.parent_dir / info.name, self._metadata_key, info.use_zarr3
@@ -308,6 +311,8 @@ def typestr(self) -> str:
308311
async def metadata(
309312
self, infos: Sequence[types.ParamInfo]
310313
) -> Sequence[StringMetadata]:
314+
for info in infos:
315+
await info.await_path_creation()
311316
return [
312317
StringMetadata(name=info.name, directory=info.parent_dir)
313318
for info in infos
@@ -330,6 +335,7 @@ async def _background_serialize(
330335
info,
331336
value,
332337
) in zip(infos, values):
338+
await info.await_path_creation()
333339
tspec = self._get_json_tspec(info)
334340
if multihost.process_index() == 0:
335341
t = await ts.open(
@@ -368,6 +374,7 @@ async def deserialize(
368374
open_futures = []
369375

370376
for info in infos:
377+
await info.await_path_creation()
371378
tspec = self._get_json_tspec(info)
372379
open_future = ts.open(
373380
tspec, open=True, read=True, context=self._ts_context

checkpoint/orbax/checkpoint/_src/serialization/types.py

Lines changed: 110 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import abc
20+
import copy
2021
import dataclasses
2122
from typing import Any, Callable, Optional, Protocol, Sequence, Tuple
2223

@@ -30,6 +31,7 @@
3031
from orbax.checkpoint._src.metadata import empty_values
3132
from orbax.checkpoint._src.metadata import pytree_metadata_options as pytree_metadata_options_lib
3233
from orbax.checkpoint._src.metadata import value as value_metadata
34+
from orbax.checkpoint._src.path import types as path_types
3335
from orbax.checkpoint._src.serialization import limits
3436
import tensorstore as ts
3537

@@ -60,81 +62,123 @@ def check_input_arguments(*args):
6062
raise ValueError('Found input args with mismatched lengths.')
6163

6264

63-
@dataclasses.dataclass(kw_only=True)
6465
class ParamInfo:
6566
"""Information describing a parameter in a PyTree.
6667
6768
Note that ParamInfo is distinct from SaveArgs and RestoreArgs in that in
6869
represents information not provided by a user, and should be computed
6970
internally.
7071
71-
name:
72-
Name of the parameter.
73-
parent_dir:
74-
A path providing location where all files under the same checkpoint should
75-
be saved under. All `ParamInfo` provided to a given TypeHandler should have
76-
the same `parent_dir`. The parent_dir is assumed to be a directory.
77-
path:
78-
Do not provide directly. Automatically set to `parent_dir / name`.
79-
skip_deserialize:
80-
If specified, skips deserialization of the given parameter using the
81-
TypeHandler. This may be for multiple different reasons, including that the
82-
parameter may have been aggregated, or it will be unneeded after
83-
transformations. Note: this parameter is handled by PyTreeCheckpointHandler,
84-
so it is unnecessary for TypeHandler implementations to deal with it.
85-
byte_limiter:
86-
Object to limit the number of bytes that can be read or written in
87-
parallel.
88-
device_host_byte_limiter:
89-
Object to limit the number of bytes that can be transferred from device to
90-
host memory in parallel.
91-
is_ocdbt_checkpoint:
92-
Indicates whether the checkpoint path uses OCDBT format
93-
or not. Only used for restoration.
94-
use_compression:
95-
When True, turn on zstd compression. Default is True.
96-
use_zarr3:
97-
If True, use Zarr ver3 otherwise ver2.
98-
ocdbt_target_data_file_size:
99-
Specifies the target size (in bytes) of each OCDBT data file. If set to 0,
100-
data file size is not limited. If omitted (None), the TensorStore default
101-
is used.
102-
ts_context:
103-
Tensorstore context to use for reading/writing.
104-
value_typestr: stores the original value's typestr (from TypeHandler).
105-
Only required when saving.
106-
enable_pinned_host_transfer:
107-
True by default. If False, disables transfer to pinned host when copying
108-
from device to host, regardless of the presence of pinned host memory.
109-
raise_array_data_missing_error:
110-
Only used for restoring. See documentation in `tensorstore_utils.py`. Comes
111-
from tree metadata and should be the same across all parameters.
112-
write_shape:
113-
Shape of the array shard. Used in the subchunking context.
114-
is_prioritized_key_fn: See `IsPrioritizedKeyFn` definition.
72+
Attributes:
73+
name: Name of the parameter.
74+
parent_dir: A path providing location where all files under the same
75+
checkpoint should be saved under. All ParamInfo provided to a given
76+
TypeHandler should have the same parent_dir. The parent_dir is assumed to
77+
be a directory. Accessing this property raises ValueError if the
78+
underlying path is a PathAwaitingCreation that has not been resolved via
79+
await_path_creation().
80+
path: Automatically set to parent_dir / name. Accessing this property raises
81+
ValueError if the underlying path is a PathAwaitingCreation that has not
82+
been resolved via await_path_creation().
83+
skip_deserialize: If specified, skips deserialization of the given parameter
84+
using the TypeHandler.
85+
byte_limiter: Object to limit the number of bytes that can be read or
86+
written in parallel.
87+
device_host_byte_limiter: Object to limit the number of bytes that can be
88+
transferred from device to host memory in parallel.
89+
is_ocdbt_checkpoint: Indicates whether the checkpoint path uses OCDBT format
90+
or not. Only used for restoration.
91+
use_compression: When True, turn on zstd compression. Default is True.
92+
use_zarr3: If True, use Zarr ver3 otherwise ver2.
93+
ocdbt_target_data_file_size: Specifies the target size (in bytes) of each
94+
OCDBT data file.
95+
ts_context: Tensorstore context to use for reading/writing.
96+
value_typestr: Stores the original value's typestr (from TypeHandler).
97+
enable_pinned_host_transfer: If False, disables transfer to pinned host.
98+
raise_array_data_missing_error: Only used for restoring.
99+
write_shape: Shape of the array shard. Used in the subchunking context.
100+
is_prioritized_key_fn: See ``IsPrioritizedKeyFn`` definition.
101+
keypath: Tuple of keys identifying the parameter's position in the PyTree.
115102
"""
116103

117-
name: str
118-
parent_dir: epath.Path
119-
path: Optional[epath.Path] = None
120-
keypath: Optional[Tuple[Any, ...]] = None
121-
skip_deserialize: Optional[bool] = None
122-
byte_limiter: Optional[limits.ByteLimiter] = None
123-
device_host_byte_limiter: Optional[limits.ByteLimiter] = None
124-
is_ocdbt_checkpoint: Optional[bool] = None
125-
use_compression: bool | None = True
126-
use_zarr3: Optional[bool] = False
127-
ocdbt_target_data_file_size: Optional[int] = None
128-
ts_context: Optional[ts.Context] = None
129-
value_typestr: Optional[str] = None
130-
enable_pinned_host_transfer: bool = False
131-
raise_array_data_missing_error: bool = True
132-
write_shape: arrays_types.Shape | None = None
133-
is_prioritized_key_fn: Optional[IsPrioritizedKeyFn] = None
134-
135-
def __post_init__(self):
136-
if self.path is None:
137-
self.path = self.parent_dir / self.name
104+
def __init__(
105+
self,
106+
*,
107+
name: str,
108+
parent_dir: epath.Path | path_types.PathAwaitingCreation,
109+
path: epath.Path | path_types.PathAwaitingCreation | None = None,
110+
keypath: Optional[Tuple[Any, ...]] = None,
111+
skip_deserialize: Optional[bool] = None,
112+
byte_limiter: Optional[limits.ByteLimiter] = None,
113+
device_host_byte_limiter: Optional[limits.ByteLimiter] = None,
114+
is_ocdbt_checkpoint: Optional[bool] = None,
115+
use_compression: bool | None = True,
116+
use_zarr3: Optional[bool] = False,
117+
ocdbt_target_data_file_size: Optional[int] = None,
118+
ts_context: Optional[ts.Context] = None,
119+
value_typestr: Optional[str] = None,
120+
enable_pinned_host_transfer: bool = False,
121+
raise_array_data_missing_error: bool = True,
122+
write_shape: arrays_types.Shape | None = None,
123+
is_prioritized_key_fn: Optional[IsPrioritizedKeyFn] = None,
124+
):
125+
self.name = name
126+
self._parent_dir = parent_dir
127+
self._path = path if path is not None else parent_dir / name
128+
self.keypath = keypath
129+
self.skip_deserialize = skip_deserialize
130+
self.byte_limiter = byte_limiter
131+
self.device_host_byte_limiter = device_host_byte_limiter
132+
self.is_ocdbt_checkpoint = is_ocdbt_checkpoint
133+
self.use_compression = use_compression
134+
self.use_zarr3 = use_zarr3
135+
self.ocdbt_target_data_file_size = ocdbt_target_data_file_size
136+
self.ts_context = ts_context
137+
self.value_typestr = value_typestr
138+
self.enable_pinned_host_transfer = enable_pinned_host_transfer
139+
self.raise_array_data_missing_error = raise_array_data_missing_error
140+
self.write_shape = write_shape
141+
self.is_prioritized_key_fn = is_prioritized_key_fn
142+
143+
@property
144+
def parent_dir(self) -> epath.Path:
145+
if isinstance(self._parent_dir, path_types.PathAwaitingCreation):
146+
raise ValueError(
147+
'parent_dir is a PathAwaitingCreation and has not been resolved yet. '
148+
'Call `await info.await_path_creation()` first.'
149+
)
150+
return self._parent_dir
151+
152+
@parent_dir.setter
153+
def parent_dir(self, value: epath.Path | path_types.PathAwaitingCreation):
154+
self._parent_dir = value
155+
156+
@property
157+
def path(self) -> epath.Path:
158+
if isinstance(self._path, path_types.PathAwaitingCreation):
159+
raise ValueError(
160+
'path is a PathAwaitingCreation and has not been resolved yet. '
161+
'Call `await info.await_path_creation()` first.'
162+
)
163+
if self._path is None:
164+
raise ValueError('path is None.')
165+
return self._path
166+
167+
@path.setter
168+
def path(self, value: epath.Path | path_types.PathAwaitingCreation | None):
169+
self._path = value
170+
171+
def replace(self, **kwargs) -> ParamInfo:
172+
new = copy.copy(self)
173+
for k, v in kwargs.items():
174+
setattr(new, k, v)
175+
return new
176+
177+
async def await_path_creation(self) -> None:
178+
if isinstance(self._parent_dir, path_types.PathAwaitingCreation):
179+
resolved_path = await self._parent_dir.await_creation()
180+
self._parent_dir = resolved_path
181+
self._path = resolved_path / self.name
138182

139183

140184
@dataclasses.dataclass

0 commit comments

Comments
 (0)