Skip to content

Commit 890d9bb

Browse files
committed
Type checking improvements
* Ensure we do not accumulate previous information across clauses * Avoid union(a, not a) when building domains * Ensure union with an open map/tuples returns the open map/tuple * Remove deep traversal on bdd leaf intersection which was slowing compilation of projects like open_api_spex Closes #15285.
1 parent 20fef8c commit 890d9bb

File tree

7 files changed

+199
-118
lines changed

7 files changed

+199
-118
lines changed

lib/elixir/lib/module/types.ex

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -321,28 +321,41 @@ defmodule Module.Types do
321321
stack = stack |> fresh_stack(mode, fun_arity) |> with_file_meta(meta)
322322
base_info = {:def, kind, fun, expected}
323323

324-
{_, _, _, mapping, clauses_types, clauses_context} =
325-
Enum.reduce(clauses, {0, 0, Pattern.init_previous(), [], [], context}, fn
326-
{meta, args, guards, body}, {index, total, previous, mapping, inferred, acc_context} ->
324+
{_, _, _, domain, mapping, clauses_types, clauses_context} =
325+
Enum.reduce(clauses, {0, 0, Pattern.init_previous(), [], [], [], context}, fn
326+
{meta, args, guards, body},
327+
{index, total, previous, domain, mapping, inferred, acc_context} ->
327328
fresh_context = fresh_context(acc_context)
328329
info = {base_info, args, guards}
329330

330331
try do
331-
{trees, previous, context} =
332+
{trees, head_no_previous_args_types, previous, head_context} =
332333
Pattern.of_head(args, guards, expected, previous, info, meta, stack, fresh_context)
333334

334335
{return_type, context} =
335-
Expr.of_expr(body, Descr.term(), body, stack, context)
336+
Expr.of_expr(body, Descr.term(), body, stack, head_context)
336337

337338
args_types = Pattern.of_domain(trees, stack, context)
338339

339340
{type_index, inferred} =
340341
add_inferred(inferred, args_types, return_type, total - 1, [])
341342

343+
domain =
344+
case domain do
345+
[] ->
346+
args_types
347+
348+
_ ->
349+
head_args_types = Pattern.of_domain(trees, stack, head_context)
350+
compute_domain(args_types, head_args_types, head_no_previous_args_types, domain)
351+
end
352+
342353
if type_index == -1 do
343-
{index + 1, total + 1, previous, [{index, total} | mapping], inferred, context}
354+
mapping = [{index, total} | mapping]
355+
{index + 1, total + 1, previous, domain, mapping, inferred, context}
344356
else
345-
{index + 1, total, previous, [{index, type_index} | mapping], inferred, context}
357+
mapping = [{index, type_index} | mapping]
358+
{index + 1, total, previous, domain, mapping, inferred, context}
346359
end
347360
rescue
348361
e ->
@@ -352,19 +365,69 @@ defmodule Module.Types do
352365

353366
domain =
354367
case clauses_types do
355-
[_] ->
356-
nil
357-
358-
_ ->
359-
clauses_types
360-
|> Enum.map(fn {args, _} -> args end)
361-
|> Enum.zip_with(fn types -> Enum.reduce(types, &Descr.union/2) end)
368+
[_] -> nil
369+
_ -> domain
362370
end
363371

364372
inferred = {:infer, domain, Enum.reverse(clauses_types)}
365373
{inferred, mapping, restore_context(clauses_context, context)}
366374
end
367375

376+
defp compute_domain(
377+
[arg | args_types],
378+
[head_arg | head_args_types],
379+
[no_prev_arg | no_prev_args_types],
380+
[d | domain]
381+
) do
382+
[
383+
# This is an optimization that broadens the domain, but it is acceptable
384+
# because the domain is used for reverse arrows and not type checking.
385+
#
386+
# The overall idea is that, if we have a function with three clauses,
387+
# the domain is computed by unioning their inferred types. However, their
388+
# inferred types often have the different of the previous clauses:
389+
#
390+
# union(r3 ^ (c3 - c2 - c1), r2 ^ (c2 - c1), r1 ^ c1)
391+
#
392+
# Where `rN` represents the refinement in every function body.
393+
#
394+
# What this function does is, if the type of a given arg in a clause
395+
# before and after the body is the same (meaning r3 is term), then
396+
# we replace all of `(c3 - c2 - c1)` by just `c3`, which removes
397+
# many of the differences in the node. However, keep in mind that,
398+
# because `r2` may have refine `c2` in the previous clause, the domain
399+
# may end-up being broader. Take this example:
400+
#
401+
# % %{..., foo: integer()} -> binary()
402+
# def example(%{foo: var}), do: Integer.to_string(var)
403+
#
404+
# % %{...} and not %{..., foo: term()} -> :error
405+
# def example(%{}), do: :error
406+
#
407+
# The actual domain is:
408+
#
409+
# %{..., foo: not_set()} or %{..., foo: integer()}
410+
# #=> %{..., foo: if_set(integer())}
411+
#
412+
# But we will infer:
413+
#
414+
# %{...} or %{..., foo: integer()}
415+
# #=> %{...}
416+
#
417+
# We lose precision but this is exactly what we want: to have simpler types.
418+
# Furthermore, the signature used in type checking is not refined in any way,
419+
# so type checking is still sound.
420+
if arg == head_arg do
421+
Descr.union(Descr.upper_bound(no_prev_arg), d)
422+
else
423+
Descr.union(arg, d)
424+
end
425+
| compute_domain(args_types, head_args_types, no_prev_args_types, domain)
426+
]
427+
end
428+
429+
defp compute_domain([], [], [], []), do: []
430+
368431
# We check for term equality of types as an optimization
369432
# to reduce the amount of check we do at runtime.
370433
defp add_inferred([{args, existing_return} | tail], args, return, index, acc),

lib/elixir/lib/module/types/descr.ex

Lines changed: 59 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,12 @@ defmodule Module.Types.Descr do
28702870
end
28712871
end
28722872

2873+
defp map_union(bdd_leaf(:open, fields) = leaf, _) when is_fields_empty(fields),
2874+
do: leaf
2875+
2876+
defp map_union(_, bdd_leaf(:open, fields) = leaf) when is_fields_empty(fields),
2877+
do: leaf
2878+
28732879
defp map_union(bdd_leaf(tag1, fields1), bdd_leaf(tag2, fields2)) do
28742880
case maybe_optimize_map_union(tag1, fields1, tag2, fields2) do
28752881
{tag, fields} -> bdd_leaf(tag, fields)
@@ -3043,6 +3049,9 @@ defmodule Module.Types.Descr do
30433049
defp map_difference(_, bdd_leaf(:open, [])),
30443050
do: :bdd_bot
30453051

3052+
defp map_difference(bdd_leaf(:open, []), {_, _, _, _} = bdd2),
3053+
do: bdd_negation(bdd2)
3054+
30463055
defp map_difference(bdd1, bdd2),
30473056
do: bdd_difference(bdd1, bdd2, &map_leaf_difference/3)
30483057

@@ -4961,10 +4970,10 @@ defmodule Module.Types.Descr do
49614970
end
49624971
end
49634972

4964-
defp tuple_difference(bdd_leaf(:open, []), bdd_leaf(:open, [])),
4973+
defp tuple_difference(_, bdd_leaf(:open, [])),
49654974
do: :bdd_bot
49664975

4967-
defp tuple_difference(bdd_leaf(:open, []), bdd2),
4976+
defp tuple_difference(bdd_leaf(:open, []), {_, _, _, _} = bdd2),
49684977
do: bdd_negation(bdd2)
49694978

49704979
defp tuple_difference(bdd1, bdd2),
@@ -5163,6 +5172,12 @@ defmodule Module.Types.Descr do
51635172
end)
51645173
end
51655174

5175+
defp tuple_union(bdd_leaf(:open, fields) = leaf, _) when is_fields_empty(fields),
5176+
do: leaf
5177+
5178+
defp tuple_union(_, bdd_leaf(:open, fields) = leaf) when is_fields_empty(fields),
5179+
do: leaf
5180+
51665181
defp tuple_union(
51675182
bdd_leaf(tag1, elements1) = tuple1,
51685183
bdd_leaf(tag2, elements2) = tuple2
@@ -5845,8 +5860,8 @@ defmodule Module.Types.Descr do
58455860
defp bdd_difference_union(i, u1, u2),
58465861
do: bdd_difference(i, bdd_union(u1, u2))
58475862

5848-
# We avoid unions because they are lazy and we prune
5849-
# intersections more actively.
5863+
# We avoid bdd_negation(bdd_union(u1, u2)) because the negation
5864+
# would spread the unions across constrained and dual parts anyway.
58505865
defp bdd_negation_union(u1, u2) do
58515866
bdd_intersection(bdd_negation(u1), bdd_negation(u2))
58525867
end
@@ -6049,109 +6064,70 @@ defmodule Module.Types.Descr do
60496064

60506065
# Intersections are great because they allow us to cut down
60516066
# the number of nodes in the tree. So whenever we have a leaf,
6052-
# we actually propagate it throughout the whole tree, cutting
6053-
# down nodes.
6054-
defguardp is_not_open(leaf) when elem(leaf, 0) != :open
6055-
6067+
# we propagate it throughout the whole tree, cutting down nodes.
60566068
defp bdd_intersection(bdd_leaf(_, _) = leaf1, bdd_leaf(_, _) = leaf2, leaf_intersection) do
60576069
leaf_intersection.(leaf1, leaf2)
60586070
end
60596071

6060-
defp bdd_intersection(bdd_leaf(_, _) = leaf, bdd, leaf_intersection) when is_not_open(leaf) do
6061-
bdd_leaf_intersection(leaf, bdd, leaf_intersection)
6072+
defp bdd_intersection(bdd, bdd_leaf(tag, _) = leaf, leaf_intersection) when tag != :open do
6073+
bdd_non_open_leaf_intersection(leaf, bdd, leaf_intersection)
60626074
end
60636075

6064-
defp bdd_intersection(bdd, bdd_leaf(_, _) = leaf, leaf_intersection) when is_not_open(leaf) do
6065-
bdd_leaf_intersection(leaf, bdd, leaf_intersection)
6076+
defp bdd_intersection(bdd_leaf(tag, _) = leaf, bdd, leaf_intersection) when tag != :open do
6077+
bdd_non_open_leaf_intersection(leaf, bdd, leaf_intersection)
60666078
end
60676079

6068-
# Take two BDDs, B1 = {a1, C1, U2, D2} and B2.
6069-
#
6070-
# When C1 = :bdd_top, we have:
6080+
defp bdd_intersection(bdd1, bdd2, _leaf_intersection) do
6081+
bdd_intersection(bdd1, bdd2)
6082+
end
6083+
6084+
# Take two BDDs, B1 = {a1, C1, U1, D1} and B2 = a2.
60716085
#
6072-
# ((a1 and C1) or U2 or (not a1 and D2)) and B2
6073-
# (a1 and B2) or (B2 and U2) or (B2 and not a1 and D2)
6086+
# We have:
60746087
#
6075-
# When C1 = :bdd_bot, we have:
6088+
# ((a1 and C1) or U1 or (not a1 and D1)) and a2
6089+
# (a1 and a2 and C1) or (a2 and U1) or (a2 and not a1 and D1)
60766090
#
6077-
# (U2 or (not a1 and D2)) and B2
6078-
# (B2 and U2) or (B2 and not a1 and D2)
6079-
defp bdd_intersection({leaf, :bdd_top, u, d}, bdd, leaf_intersection) when is_not_open(leaf) do
6080-
bdd_leaf_intersection(leaf, bdd, leaf_intersection)
6081-
|> bdd_union(bdd_intersection(u, bdd, leaf_intersection))
6082-
|> case do
6083-
result when d == :bdd_bot -> result
6084-
result -> bdd_union(result, bdd_intersection(bdd, {leaf, :bdd_bot, :bdd_bot, d}))
6085-
end
6091+
# When C1 = :bdd_top, (a1 and a2) or (a2 and U2) or (a2 and not a1 and D2)
6092+
# When C2 = :bdd_bot, (a2 and U2) or (a2 and not a1 and D2)
6093+
defp bdd_non_open_leaf_intersection(leaf1, bdd_leaf(_, _) = leaf2, leaf_intersection) do
6094+
leaf_intersection.(leaf1, leaf2)
60866095
end
60876096

6088-
defp bdd_intersection(bdd, {leaf, :bdd_top, u, d}, leaf_intersection) when is_not_open(leaf) do
6089-
bdd_leaf_intersection(leaf, bdd, leaf_intersection)
6090-
|> bdd_union(bdd_intersection(u, bdd, leaf_intersection))
6097+
defp bdd_non_open_leaf_intersection(leaf, {a, :bdd_top, u, d}, leaf_intersection) do
6098+
leaf_intersection.(a, leaf)
6099+
|> bdd_union(bdd_non_open_leaf_intersection(leaf, u, leaf_intersection))
60916100
|> case do
6092-
result when d == :bdd_bot -> result
6093-
result -> bdd_union(result, bdd_intersection(bdd, {leaf, :bdd_bot, :bdd_bot, d}))
6094-
end
6095-
end
6101+
result when d == :bdd_bot ->
6102+
result
60966103

6097-
defp bdd_intersection({leaf, :bdd_bot, u, d}, bdd, leaf_intersection) when is_not_open(leaf) do
6098-
case bdd_intersection(u, bdd, leaf_intersection) do
6099-
result when d == :bdd_bot -> result
6100-
result -> bdd_union(result, bdd_intersection(bdd, {leaf, :bdd_bot, :bdd_bot, d}))
6104+
result ->
6105+
leaf
6106+
|> bdd_non_open_leaf_intersection(d, leaf_intersection)
6107+
|> bdd_difference(a)
6108+
|> bdd_union(result)
61016109
end
61026110
end
61036111

6104-
defp bdd_intersection(bdd, {leaf, :bdd_bot, u, d}, leaf_intersection) when is_not_open(leaf) do
6105-
case bdd_intersection(u, bdd, leaf_intersection) do
6106-
result when d == :bdd_bot -> result
6107-
result -> bdd_union(result, bdd_intersection(bdd, {leaf, :bdd_bot, :bdd_bot, d}))
6112+
defp bdd_non_open_leaf_intersection(leaf, {a, :bdd_bot, u, d}, leaf_intersection) do
6113+
case bdd_non_open_leaf_intersection(leaf, u, leaf_intersection) do
6114+
result when d == :bdd_bot ->
6115+
result
6116+
6117+
result ->
6118+
leaf
6119+
|> bdd_non_open_leaf_intersection(d, leaf_intersection)
6120+
|> bdd_difference(a)
6121+
|> bdd_union(result)
61086122
end
61096123
end
61106124

6111-
defp bdd_intersection(bdd1, bdd2, _leaf_intersection) do
6125+
defp bdd_non_open_leaf_intersection(bdd1, bdd2, _leaf_intersection) do
61126126
bdd_intersection(bdd1, bdd2)
61136127
end
61146128

6115-
defp bdd_leaf_intersection(leaf, bdd, intersection) do
6116-
case bdd do
6117-
:bdd_top ->
6118-
leaf
6119-
6120-
:bdd_bot ->
6121-
:bdd_bot
6122-
6123-
bdd_leaf(_, _) ->
6124-
intersection.(leaf, bdd)
6125-
6126-
{bdd_leaf(:open, _), _, _, _} ->
6127-
bdd_intersection(leaf, bdd)
6128-
6129-
{lit, c, u, _} when lit == leaf ->
6130-
case bdd_union(c, u) do
6131-
:bdd_bot -> :bdd_bot
6132-
cu -> {lit, cu, :bdd_bot, :bdd_bot}
6133-
end
6134-
6135-
{lit, c, u, d} ->
6136-
rest =
6137-
bdd_union(
6138-
bdd_leaf_intersection(leaf, u, intersection),
6139-
bdd_difference(bdd_leaf_intersection(leaf, d, intersection), lit)
6140-
)
6141-
6142-
if c == :bdd_bot do
6143-
rest
6144-
else
6145-
case intersection.(leaf, lit) do
6146-
:bdd_bot -> rest
6147-
new_leaf -> bdd_union(bdd_leaf_intersection(new_leaf, c, intersection), rest)
6148-
end
6149-
end
6150-
end
6151-
end
6152-
6153-
# {lit, c, u, d} = (lit and c) or u or (not lit and d), so
6154-
# its negation is ((lit and not c) or (not lit and not d)) and not u.
6129+
# {lit, c, u, d} = (lit and c) or u or (not lit and d),
6130+
# so its negation is ((lit and not c) or (not lit and not d)) and not u.
61556131
def bdd_negation(:bdd_top), do: :bdd_bot
61566132
def bdd_negation(:bdd_bot), do: :bdd_top
61576133
def bdd_negation({_, _} = pair), do: {pair, :bdd_bot, :bdd_bot, :bdd_top}

lib/elixir/lib/module/types/expr.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ defmodule Module.Types.Expr do
834834
{patterns, guards} = extract_head(head)
835835
info = {base_info, head}
836836

837-
{trees, previous, context} =
837+
{trees, _, previous, context} =
838838
Pattern.of_head(patterns, guards, domain, previous, info, meta, stack, context)
839839

840840
{result, context} = of_body.(trees, body, context)

lib/elixir/lib/module/types/pattern.ex

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ defmodule Module.Types.Pattern do
5050
end
5151

5252
defp previous_to_string({list, _}) do
53-
Enum.map_join(list, "\n ", fn types ->
53+
list
54+
|> Enum.reverse()
55+
|> Enum.map_join("\n ", fn types ->
5456
types
5557
|> Enum.map_join(", ", &(&1 |> upper_bound() |> to_quoted_string()))
5658
|> indent(4)
@@ -190,27 +192,28 @@ defmodule Module.Types.Pattern do
190192
# First we check if it fails without previous, if it doesn't, check if it is redundant.
191193
case of_precise_head(patterns, guards, expected, init_previous(), tag, stack, original) do
192194
{other_trees, _, _, _, %{failed: true} = other_context} ->
193-
{other_trees, previous, other_context}
195+
{other_trees, args_types, previous, other_context}
194196

195197
{other_trees, _, _, args_types, other_context} ->
196198
if previous_subtype?(args_types, previous) do
197199
warning = {:redundant, tag, expected, args_types, previous, other_context}
198-
{other_trees, previous, warn(__MODULE__, warning, meta, stack, other_context)}
200+
context = warn(__MODULE__, warning, meta, stack, other_context)
201+
{other_trees, args_types, previous, context}
199202
else
200-
{trees, previous, context}
203+
{trees, args_types, previous, context}
201204
end
202205
end
203206
else
204207
cond do
205208
check_previous? and previous_subtype?(args_types, previous) ->
206209
warning = {:redundant, tag, expected, args_types, previous, context}
207-
{trees, previous, warn(__MODULE__, warning, meta, stack, context)}
210+
{trees, args_types, previous, warn(__MODULE__, warning, meta, stack, context)}
208211

209212
precise? ->
210-
{trees, concat_previous(args_types, previous), context}
213+
{trees, args_types, concat_previous(args_types, previous), context}
211214

212215
true ->
213-
{trees, previous, context}
216+
{trees, args_types, previous, context}
214217
end
215218
end
216219
end
@@ -235,7 +238,7 @@ defmodule Module.Types.Pattern do
235238
of_pattern_intersect(trees, 0, [], pattern_info, tag, stack, context),
236239
# We compute the args types before we do the intersection with previous clauses
237240
args_types =
238-
(with [_ | _] <- previous,
241+
(with false <- empty_previous?(previous),
239242
{:ok, _types, context} <-
240243
of_pattern_refine(types, changed, pattern_info, tag, stack, context) do
241244
trees_to_args_types(trees, stack, context)

0 commit comments

Comments
 (0)