Load safetensors checkpoints per the target sharding in v1#3238
Closed
mridul-sahu wants to merge 0 commit into
Closed
Load safetensors checkpoints per the target sharding in v1#3238mridul-sahu wants to merge 0 commit into
mridul-sahu wants to merge 0 commit into
Conversation
7ce2a07 to
734f828
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
In v1,
SafetensorsLayoutloaded a.safetensorscheckpoint on multiplehosts by splitting the file into contiguous byte bundles (one ~1/N slab
per host, chosen by file layout, not by the user's sharding), building a
(num_hosts, *shape)transient array, and collapsing it to the targetsharding with a per-tensor
jnp.sumreshard collective.This replaces that with a sharding-driven load: each process reads
exactly the bytes of the shards its own devices need.
Old vs. new
load_multi_hostjnp.sumover ahostsaxis — a collective per tensorjax.jit(lambda …)built inline, recompiled per tensorHow it works
For each tensor, every process:
jax.sharding.Shardingand resolves which shards itsown devices hold —
devices_indices_map(...)filtered bydevice.process_index;index_domain_to_byte_runs(one contiguous run for aleading-dimension shard, strided runs for an inner-dimension shard;
coalescethen merges nearby runs into fewer reads);jax.make_array_from_single_device_arrays.This mirrors how the v0 array deserializer
(
_src/serialization/serialization.py) drives reads from the sharding —with hand-computed byte ranges standing in for TensorStore's chunk index,
since a safetensors tensor is a single contiguous blob.
The file is not read multiple times
A
jax.shardingpartitions a tensor, so the processes' byte ranges aredisjoint and their union is the whole tensor. Each process issues ranged
reads (
seek+read(length)— an HTTPRangeGET on object storage)for only its shards, and within a process a shard shared by several local
devices is read once (reads are de-duplicated by index domain) and
placed on each device. Total bytes pulled from storage across the cluster
≈ 1× the file. (A replicated tensor is necessarily read by every host
that holds it — inherent to replication, not redundant reading.)
Tests
index_domain_to_byte_runs,coalesce, byte-stride helpers — scalar,whole-tensor, leading-dimension (one run), inner-dimension (strided),
non-zero offsets, 3-D, non-4-byte itemsize; coalescing of overlapping,
unsorted, exact-gap and partial runs.
_read_shard_bytes— coalesced vs. uncoalesced extraction agree.multi-dimensional tensors, empty abstract state, duplicate tensor
across files, and a dtype sweep (int8 / int32 / float16 / bfloat16 /
float8_e4m3fn).
combinations, including
ignore_load_sharding.