Load safetensors checkpoints per the target sharding in v1 #3238

Closed
mridul-sahu wants to merge 0 commits into google:main from mridul-sahu:main

@mridul-sahu
Collaborator

Summary

In v1, SafetensorsLayout loaded a .safetensors checkpoint on multiple
hosts by splitting the file into contiguous byte bundles (one ~1/N slab
per host, chosen by file layout, not by the user's sharding), building a
(num_hosts, *shape) transient array, and collapsing it to the target
sharding with a per-tensor jnp.sum reshard collective.
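
As a rough illustration of the old scheme, the NumPy sketch below emulates all hosts in one process. It assumes that each host's slice of the (num_hosts, *shape) transient holds that host's slab and zeros elsewhere, which is what lets a sum over the hosts axis recover the tensor; that detail, and the helper name `old_style_load`, are assumptions for illustration and not taken from the PR.

```python
import numpy as np


def old_style_load(flat_values, shape, num_hosts):
    """Emulate the old load: contiguous 1/N slabs, then a sum over hosts.

    Hypothetical helper. Zero-filling outside each host's slab is an
    assumption made so that the sum reproduces the original values.
    """
    n = flat_values.size
    # Slab boundaries follow the flat file layout, not the target sharding.
    bounds = [n * h // num_hosts for h in range(num_hosts + 1)]
    per_host = np.zeros((num_hosts, n), dtype=flat_values.dtype)
    for h in range(num_hosts):
        lo, hi = bounds[h], bounds[h + 1]
        per_host[h, lo:hi] = flat_values[lo:hi]
    # The transient is num_hosts times larger than the tensor itself.
    stacked = per_host.reshape((num_hosts, *shape))
    # Stand-in for the per-tensor jnp.sum collective over the hosts axis.
    return stacked.sum(axis=0)


values = np.arange(12, dtype=np.float32)
restored = old_style_load(values, (3, 4), num_hosts=4)
```

The sketch makes the two costs visible: a transient that scales with the number of hosts, and a reduction per tensor.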

This replaces that with a sharding-driven load: each process reads
exactly the bytes of the shards its own devices need.

Old vs. new

| | Old load_multi_host | New |
| --- | --- | --- |
| What a host reads | a file-contiguity 1/N slab, unrelated to the sharding | exactly its devices' shards |
| Cross-host traffic | jnp.sum over a hosts axis (one collective per tensor) | none |
| Compilation | jax.jit(lambda …) built inline, recompiled per tensor | none |
| Single- vs. multi-host | two separate code paths | one unified path |

How it works

For each tensor, every process:

  1. takes the target jax.sharding.Sharding and resolves which shards its
    own devices hold — devices_indices_map(...) filtered by
    device.process_index;
  2. maps each shard's index domain to byte ranges in the file via
    index_domain_to_byte_runs (one contiguous run for a
    leading-dimension shard, strided runs for an inner-dimension shard;
    coalesce then merges nearby runs into fewer reads);
  3. reads exactly those ranges and assembles the array with
    jax.make_array_from_single_device_arrays.
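
Step 2 can be sketched in pure Python. The functions below are a minimal illustration assuming a row-major tensor and unit-step slices. The names `byte_runs_for_slices` and `merge_runs`, their signatures, and the `max_gap` parameter are hypothetical and are not the PR's `index_domain_to_byte_runs` or `coalesce`.

```python
from itertools import product
from math import prod


def byte_runs_for_slices(shape, itemsize, index, base_offset=0):
    """Return sorted (offset, length) byte runs for `index` in a row-major tensor.

    `index` is a tuple of unit-step slices, one per dimension.
    """
    bounds = []
    for sl, dim in zip(index, shape):
        start, stop, step = sl.indices(dim)
        if step != 1:
            raise ValueError("only unit-step slices are supported")
        bounds.append((start, stop))
    if any(stop <= start for start, stop in bounds):
        return []  # empty shard
    # Row-major element strides: stride[d] = product of the inner dimensions.
    strides = [prod(shape[d + 1:]) for d in range(len(shape))]
    # Find the innermost dimension that is not fully covered.
    split = len(shape) - 1
    while split >= 0 and bounds[split] == (0, shape[split]):
        split -= 1
    if split < 0:
        # Whole tensor (or scalar): one contiguous run.
        return [(base_offset, prod(shape) * itemsize)]
    start, stop = bounds[split]
    run_len = (stop - start) * strides[split] * itemsize
    runs = []
    # One run per combination of indices in the dimensions outside `split`.
    for outer in product(*(range(lo, hi) for lo, hi in bounds[:split])):
        elem = sum(i * s for i, s in zip(outer, strides)) + start * strides[split]
        runs.append((base_offset + elem * itemsize, run_len))
    return runs


def merge_runs(runs, max_gap=0):
    """Merge runs that overlap or lie within `max_gap` bytes of each other."""
    merged = []
    for offset, length in sorted(runs):
        if merged and offset <= merged[-1][0] + merged[-1][1] + max_gap:
            last_off, last_len = merged[-1]
            end = max(last_off + last_len, offset + length)
            merged[-1] = (last_off, end - last_off)
        else:
            merged.append((offset, length))
    return merged


# A (4, 6) float32 tensor: a leading-dimension shard is one contiguous run ...
leading = byte_runs_for_slices((4, 6), 4, (slice(0, 2), slice(None)))
# ... while an inner-dimension shard is one strided run per row.
inner = byte_runs_for_slices((4, 6), 4, (slice(None), slice(0, 3)))
```

In this sketch, `leading` is `[(0, 48)]` and `inner` is `[(0, 12), (24, 12), (48, 12), (72, 12)]`. Calling `merge_runs(inner, max_gap=12)` collapses the four strided runs into the single read `[(0, 84)]`, at the cost of reading the gap bytes as well.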

This mirrors how the v0 array deserializer
(_src/serialization/serialization.py) drives reads from the sharding —
with hand-computed byte ranges standing in for TensorStore's chunk index,
since a safetensors tensor is a single contiguous blob.
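
Steps 1 and 3 might look roughly like the following JAX sketch. An in-memory NumPy array stands in for the tensor's bytes in the file, and slicing it stands in for the ranged read. The helper name `load_with_sharding` is hypothetical, and the sketch runs on whatever devices are available, including a single CPU device.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def load_with_sharding(full_array, sharding):
    """Assemble a jax.Array by reading only this process's shards.

    Hypothetical helper: `full_array` stands in for the file contents.
    """
    shape = full_array.shape
    # Maps every device in the sharding to the index (tuple of slices) it holds.
    index_map = sharding.devices_indices_map(shape)
    cache = {}   # one read per distinct index domain
    shards = []
    for device, index in index_map.items():
        if device.process_index != jax.process_index():
            continue  # another process is responsible for this shard
        # Slices are not hashable on older Python versions, so build a key.
        key = tuple((s.start, s.stop) for s in index)
        if key not in cache:
            # Stand-in for "map the index domain to byte runs and read them".
            cache[key] = np.ascontiguousarray(full_array[index])
        shards.append(jax.device_put(cache[key], device))
    return jax.make_array_from_single_device_arrays(shape, sharding, shards)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x", None))  # shard the leading dimension
data = np.arange(len(devices) * 4 * 3, dtype=np.float32).reshape(-1, 3)
arr = load_with_sharding(data, sharding)
```

No collective or jit-compiled function is involved: each shard is placed directly on its device and the global array is assembled from the local pieces.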

The file is not read multiple times

A non-replicated jax.sharding.Sharding partitions a tensor, so the
processes' byte ranges are disjoint and their union is the whole tensor.
Each process issues ranged reads (seek + read(length), which becomes an
HTTP Range GET on object storage) for only its shards. Within a process,
a shard shared by several local devices is read once (reads are
de-duplicated by index domain) and placed on each device. Total bytes
pulled from storage across the cluster are therefore ≈ 1× the file. (A
replicated tensor is necessarily read by every host that holds it; this
is inherent to replication, not redundant reading.)
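
The accounting above can be checked with a small pure-Python sketch. It uses io.BytesIO in place of a real file, emulates two processes with two local devices each, and de-duplicates reads by an index-domain key. The helper names and the dictionary layout are assumptions for illustration only.

```python
import io


def read_runs(f, runs):
    """Read each (offset, length) run with seek + read and join the pieces."""
    chunks = []
    for offset, length in runs:
        f.seek(offset)
        chunks.append(f.read(length))
    return b"".join(chunks)


def read_local_shards(f, shards_by_device):
    """Read each distinct index domain once and hand it to every local device.

    `shards_by_device` maps a device name to (index_key, runs). Hypothetical
    layout: `index_key` identifies the shard's index domain.
    """
    cache = {}
    out = {}
    bytes_read = 0
    for device, (index_key, runs) in shards_by_device.items():
        if index_key not in cache:
            cache[index_key] = read_runs(f, runs)
            bytes_read += len(cache[index_key])
        out[device] = cache[index_key]
    return out, bytes_read


blob = bytes(range(16))  # stand-in for one tensor's bytes in the file
# Process 0: both local devices hold rows 0..2; process 1: rows 2..4.
process0 = {"dev0": (((0, 2),), [(0, 8)]), "dev1": (((0, 2),), [(0, 8)])}
process1 = {"dev2": (((2, 4),), [(8, 8)]), "dev3": (((2, 4),), [(8, 8)])}
out0, read0 = read_local_shards(io.BytesIO(blob), process0)
out1, read1 = read_local_shards(io.BytesIO(blob), process1)
total_read = read0 + read1
```

Here `total_read` equals `len(blob)`: four devices receive data, but each byte of the tensor is read from storage exactly once.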

Tests

  • index_domain_to_byte_runs, coalesce, byte-stride helpers — scalar,
    whole-tensor, leading-dimension (one run), inner-dimension (strided),
    non-zero offsets, 3-D, non-4-byte itemsize; coalescing of overlapping,
    unsorted, exact-gap and partial runs.
  • _read_shard_bytes — coalesced vs. uncoalesced extraction agree.
  • Layout: shape mismatch, dtype cast, nested abstract state, scalar and
    multi-dimensional tensors, empty abstract state, duplicate tensor
    across files, and a dtype sweep (int8 / int32 / float16 / bfloat16 /
    float8_e4m3fn).
  • The multi-process suite passes across mesh, shape and sharding-type
    combinations, including ignore_load_sharding.

@github-actions github-actions Bot added the pull ready Ready to be pulled from GitHub into Google label May 22, 2026
@mridul-sahu mridul-sahu removed the pull ready Ready to be pulled from GitHub into Google label May 23, 2026
@mridul-sahu mridul-sahu added the pull ready Ready to be pulled from GitHub into Google label May 23, 2026
@mridul-sahu mridul-sahu force-pushed the main branch 4 times, most recently from 7ce2a07 to 734f828 Compare May 23, 2026 11:32