Hi nshepperd,
I just found this Colab notebook and was trying to test it, but the following error comes up during the run.
Is it due to a different version of JAX, Python, etc.?
I'm really not good at debugging these scripts, so I'd be very happy if you could help.
CELL
title = 'a beautiful fantasy land forest of glass, trending on ArtStation'
prompt = txt(title)
style_embed = norm1(jnp.array(cborfile('jax-guided-diffusion/data/openimages_512x_png_embed224.cbor'))) - norm1(jnp.array(cborfile('jax-guided-diffusion/data/imagenet_512x_jpg_embed224.cbor')))
batch_size = 1
clip_guidance_scale = 2000
style_guidance_scale = 300
tv_scale = 150
cutn = 32 # effective cutn is 4x this because we do 4 iterations in base_cond_fn
cut_pow = 0.5
style_cutn = 32
n_batches = 4
init_image = None
skip_timesteps = 0
seed = 5
ERROR
AttributeError Traceback (most recent call last)
in <cell line: 2>()
1 title = 'a beautiful fantasy land forest of glass, trending on ArtStation'
----> 2 prompt = txt(title)
3 style_embed = norm1(jnp.array(cborfile('jax-guided-diffusion/data/openimages_512x_png_embed224.cbor'))) - norm1(jnp.array(cborfile('jax-guided-diffusion/data/imagenet_512x_jpg_embed224.cbor')))
4 batch_size = 1
5 clip_guidance_scale = 2000
1 frames
in txt(prompt)
85 text = clip_jax.tokenize([prompt])
86 text_embed = text_fn(clip_params, text)
---> 87 return norm1(text_embed.reshape(512))
88
89 @jax.jit
in norm1(x)
68 def norm1(x):
69 """Normalize to the unit sphere."""
---> 70 return x / x.square().sum(axis=-1, keepdims=True).sqrt()
71
72 def spherical_dist_loss(x, y):
AttributeError: 'ArrayImpl' object has no attribute 'square'
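A possible workaround, offered as a sketch and not a confirmed fix: the traceback suggests the array type returned in this environment has no `.square()` method (and likely no `.sqrt()` either), so `norm1` could be rewritten with the equivalent `jax.numpy` functions. This assumes only `norm1` is affected; other helpers in the notebook may use the same method-style calls and need the same change.

```python
import jax.numpy as jnp

def norm1(x):
    """Normalize to the unit sphere along the last axis."""
    # Use jnp functions instead of array methods, which the
    # 'ArrayImpl' object in the traceback does not provide.
    return x / jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))
```

Redefining `norm1` this way in a cell before the failing one should give the same result as the original expression, without relying on methods being attached to the array object.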