
AttributeError: 'ArrayImpl'  #4

@GB14KR


Hi nshepperd,
I just found this Colab notebook and was trying to test it, but the following error comes up when I run it.
Is it due to a different version of JAX, Python, etc.?
I'm really not good at debugging these scripts, so I'd be very happy if you could help.
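To help answer the version question, a small cell like the one below could be run in the same Colab session and its output pasted into the issue. This is a suggested sketch, not something the notebook provides; the helper name `environment_report` is hypothetical, and it only uses `platform.python_version()` and `jax.__version__`.

```python
import platform

import jax


def environment_report():
    """Hypothetical helper: collect the Python and JAX versions for a bug report."""
    return {
        "python": platform.python_version(),  # e.g. the Colab runtime's Python
        "jax": jax.__version__,  # installed JAX version
    }


print(environment_report())
```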

CELL

```python
title = 'a beautiful fantasy land forest of glass, trending on ArtStation'
prompt = txt(title)
style_embed = norm1(jnp.array(cborfile('jax-guided-diffusion/data/openimages_512x_png_embed224.cbor'))) - norm1(jnp.array(cborfile('jax-guided-diffusion/data/imagenet_512x_jpg_embed224.cbor')))
batch_size = 1
clip_guidance_scale = 2000
style_guidance_scale = 300
tv_scale = 150
cutn = 32 # effective cutn is 4x this because we do 4 iterations in base_cond_fn
cut_pow = 0.5
style_cutn = 32
n_batches = 4
init_image = None
skip_timesteps = 0
seed = 5
```


ERROR

```
AttributeError Traceback (most recent call last)
in <cell line: 2>()
1 title = 'a beautiful fantasy land forest of glass, trending on ArtStation'
----> 2 prompt = txt(title)
3 style_embed = norm1(jnp.array(cborfile('jax-guided-diffusion/data/openimages_512x_png_embed224.cbor'))) - norm1(jnp.array(cborfile('jax-guided-diffusion/data/imagenet_512x_jpg_embed224.cbor')))
4 batch_size = 1
5 clip_guidance_scale = 2000

1 frames
in txt(prompt)
85 text = clip_jax.tokenize([prompt])
86 text_embed = text_fn(clip_params, text)
---> 87 return norm1(text_embed.reshape(512))
88
89 @jax.jit

in norm1(x)
68 def norm1(x):
69 """Normalize to the unit sphere."""
---> 70 return x / x.square().sum(axis=-1, keepdims=True).sqrt()
71
72 def spherical_dist_loss(x, y):

AttributeError: 'ArrayImpl' object has no attribute 'square'
```
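A possible workaround, sketched under assumptions: the traceback shows `norm1` calling torch-style methods `.square()` and `.sqrt()` on the array. NumPy-style arrays, including JAX arrays, do not normally define those methods, so they were presumably added by something else the notebook depends on, and that no longer applies to the `ArrayImpl` type in this JAX install (an assumption, not confirmed here). Rewriting `norm1` with the equivalent `jax.numpy` functions avoids relying on those methods. This is not the notebook author's fix, and other helpers in the notebook that use the same style of array methods may fail the same way.

```python
import jax.numpy as jnp


def norm1(x):
    """Normalize to the unit sphere along the last axis.

    Same math as the notebook's version (x / sqrt(sum(x**2))), but written
    with jax.numpy functions instead of x.square() / .sqrt() array methods.
    """
    return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))


# Quick check on a 3-4-5 vector; the result should be close to [0.6, 0.8].
v = norm1(jnp.array([3.0, 4.0]))
print(v)
```

Running this in a cell after the notebook's original definition should shadow it, assuming `txt` looks `norm1` up as a module-level global when it is called. An equivalent one-liner is `x / jnp.linalg.norm(x, axis=-1, keepdims=True)`.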
