Thanks for your suggestions. I downgraded JAX with `pip install jax==0.2.14 jaxlib==0.1.69+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html`. That worked, but now I get a new error when running Run_DMFold.py, as follows:
Code: Select all
I0117 07:23:01.527364 22723190105920 run_alphafold_mymsa.py:209] Running model model_1_multimer_v3_pred_0 on seq
I0117 07:23:01.527707 22723190105920 model.py:165] Running predict with shape(feat) = {'aatype': (1778,), 'residue_index': (1778,), 'seq_length': (), 'msa': (2701, 1778), 'num_alignments': (), 'template_aatype': (4, 1778), 'template_all_atom_mask': (4, 1778, 37), 'template_all_atom_positions': (4, 1778, 37, 3), 'asym_id': (1778,), 'sym_id': (1778,), 'entity_id': (1778,), 'deletion_matrix': (2701, 1778), 'deletion_mean': (1778,), 'all_atom_mask': (1778, 37), 'all_atom_positions': (1778, 37, 3), 'assembly_num_chains': (), 'entity_mask': (1778,), 'num_templates': (), 'cluster_bias_mask': (2701,), 'bert_mask': (2701, 1778), 'seq_mask': (1778,), 'msa_mask': (2701, 1778)}
Traceback (most recent call last):
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/run_alphafold_mymsa.py", line 514, in <module>
app.run(main)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/run_alphafold_mymsa.py", line 490, in main
predict_structure(
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/run_alphafold_mymsa.py", line 217, in predict_structure
prediction_result = model_runner.predict(processed_feature_dict,
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/model.py", line 167, in predict
result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/api.py", line 424, in cache_miss
out_flat = xla.xla_call(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/core.py", line 1560, in bind
return call_bind(self, fun, *args, **params)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/core.py", line 1563, in process
return trace.process_call(self, fun, tracers, params)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/xla.py", line 592, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/xla.py", line 667, in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/transform.py", line 125, in apply_fn
out, state = f.apply(params, {}, *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/transform.py", line 313, in apply_fn
out = f(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/model.py", line 77, in _forward_fn
return model(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 485, in __call__
prev, safe_key = hk.fori_loop(0, num_iter, recycle_body, (prev, safe_key))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 544, in fori_loop
state, val = jax.lax.fori_loop(lower, upper, pure_body_fun, init_val)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 212, in fori_loop
(_, result), _ = scan(_fori_scan_body_fun(body_fun), (lower_, init_val),
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1288, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1274, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 70, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 143, in scanned_fun
return (i + 1, body_fun(i, x)), None
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 524, in pure_body_fun
val = body_fun(i, val)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 482, in recycle_body
ret = apply_network(prev=prev, safe_key=safe_key2)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 449, in apply_network
return impl(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 321, in __call__
repr_shape = hk.eval_shape(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 665, in eval_shape
out_shape = jax.eval_shape(fun, *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/api.py", line 2403, in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 410, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 322, in <lambda>
lambda: embedding_module(batch, is_training))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 661, in __call__
template_act = template_module(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 841, in __call__
summed_template_embeddings, _ = hk.scan(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 504, in scan
(carry, state), ys = jax.lax.scan(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1288, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1274, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 70, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 487, in stateful_fun
carry, out = f(carry, x)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 837, in scan_fn
return carry + partial_template_embedder(*x), None
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 824, in partial_template_embedder
return template_embedder(query_embedding,
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 987, in __call__
act, safe_key = template_stack((act, safe_subkey))
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/layer_stack.py", line 265, in wrapped
ret = _LayerStackNoState(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/layer_stack.py", line 156, in __call__
carry, zs = hk.scan(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 504, in scan
(carry, state), ys = jax.lax.scan(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1288, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 1274, in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 76, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 186, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/util.py", line 179, in cached
return f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/_src/lax/control_flow.py", line 70, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1252, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/stateful.py", line 487, in stateful_fun
carry, out = f(carry, x)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/layer_stack.py", line 149, in layer
out_x, z = self._call_wrapped(carry.x, *scanned.args_ys)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/layer_stack.py", line 182, in _call_wrapped
ret = self._f(*args)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 973, in template_iteration_fn
act = template_iteration(
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules_multimer.py", line 1034, in __call__
act = dropout_wrapper_fn(
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules.py", line 76, in dropout_wrapper
residual = module(input_act, mask, is_training=is_training, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/Software/DMFold1.0/bin/alphafold_multimer/alphafold/model/modules.py", line 1283, in __call__
act = hk.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 428, in wrapped
out = f(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/module.py", line 279, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/layer_norm.py", line 128, in __call__
scale = hk.get_parameter("scale", param_shape, inputs.dtype,
File "/home/dsj/miniconda3/envs/alphafold2nondocker/lib/python3.8/site-packages/haiku/_src/base.py", line 293, in get_parameter
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Unable to retrieve parameter 'scale' for module 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input'. All parameters must be created as part of `init`.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.