changed CHANGELOG.md
 
@@ -1,5 +1,19 @@
1
1
# Changelog
2
2
3
+ ## v0.6.1 (2023-09-12)
4
+
5
+ ### Enhancements
6
+
7
+ * [Nx] Add multivariate normal distribution
8
+ * [Nx.Serving] Automatically split exceeding batch sizes
9
+
10
+ ### Bug fixes
11
+
12
+ * [Nx] Fix `Nx.pad/3` with different backends
13
+ * [Nx] Fix `Nx.clip/3` with non-finite values
14
+ * [Nx.Serving] Emit batches as they arrive in `Nx.Serving.streaming/2`
15
+ * [Nx.Serving] Ensure batch key is preserved when a batch is split
16
+
3
17
## v0.6.0 (2023-08-15)
4
18
5
19
### Enhancements
changed hex_metadata.config
 
@@ -1,6 +1,6 @@
1
1
{<<"links">>,[{<<"GitHub">>,<<"https://github.com/elixir-nx/nx">>}]}.
2
2
{<<"name">>,<<"nx">>}.
3
- {<<"version">>,<<"0.6.0">>}.
3
+ {<<"version">>,<<"0.6.1">>}.
4
4
{<<"description">>,
5
5
<<"Multi-dimensional arrays (tensors) and numerical definitions for Elixir">>}.
6
6
{<<"elixir">>,<<"~> 1.14">>}.
changed lib/nx.ex
 
@@ -4062,7 +4062,7 @@ defmodule Nx do
4062
4062
shape = Nx.Shape.pad(tensor.shape, padding_config)
4063
4063
4064
4064
out = %{tensor | type: output_type, shape: shape}
4065
- impl!(tensor).pad(out, tensor, pad_value, padding_config)
4065
+ impl!(tensor, pad_value).pad(out, tensor, pad_value, padding_config)
4066
4066
end)
4067
4067
end
4068
4068
 
@@ -7704,10 +7704,10 @@ defmodule Nx do
7704
7704
[0, 2, 4]
7705
7705
>
7706
7706
7707
- iex> Nx.indexed_put(Nx.tensor([0, 0, 0]), Nx.tensor([[1], [2], [1]]), Nx.tensor([3, 4, 2]))
7707
+ iex> Nx.indexed_put(Nx.tensor([0, 0, 0]), Nx.tensor([[1], [2]]), Nx.tensor([3, 4]))
7708
7708
#Nx.Tensor<
7709
7709
s64[3]
7710
- [0, 2, 4]
7710
+ [0, 3, 4]
7711
7711
>
7712
7712
7713
7713
iex> t = Nx.iota({1, 2, 3})
changed lib/nx/backend.ex
 
@@ -178,31 +178,46 @@ defmodule Nx.Backend do
178
178
Note the `binary` may have fewer elements than the
179
179
tensor size but, in such cases, it must strictly have
180
180
more elements than `inspect_opts.limit`
181
+
182
+ ## Options
183
+
184
+ The following must be passed through `Inspect` `:custom_options`
185
+
186
+ * `:nx_precision` - Configures the floating-point number printing precision.
187
+ If set, will print floating-point numbers in scientific notation using the
188
+ specified number of significant digits. Otherwise, default Elixir printing
189
+ rules are applied.
181
190
"""
182
191
def inspect(%{shape: shape, type: type}, binary, inspect_opts) do
183
192
open = IA.color("[", :list, inspect_opts)
184
193
sep = IA.color(",", :list, inspect_opts)
185
194
close = IA.color("]", :list, inspect_opts)
186
195
196
+ # TO-DO: This is a palliative accessibility-related solution
197
+ precision = inspect_opts.custom_options[:nx_precision]
198
+
187
199
dims = Tuple.to_list(shape)
188
- {data, _rest, _limit} = chunk(dims, binary, type, inspect_opts.limit, {open, sep, close})
200
+
201
+ {data, _rest, _limit} =
202
+ chunk(dims, binary, type, inspect_opts.limit, precision, {open, sep, close})
203
+
189
204
data
190
205
end
191
206
192
- defp chunk([], data, type, limit, _docs) do
207
+ defp chunk([], data, type, limit, precision, _docs) do
193
208
{doc, tail} =
194
209
Nx.Shared.match_types [type] do
195
210
<<match!(head, 0), tail::binary>> = data
196
- {inspect_value(read!(head, 0)), tail}
211
+ {inspect_value(read!(head, 0), precision), tail}
197
212
end
198
213
199
214
if limit == :infinity, do: {doc, tail, limit}, else: {doc, tail, limit - 1}
200
215
end
201
216
202
- defp chunk([dim | dims], data, type, limit, {open, sep, close} = docs) do
217
+ defp chunk([dim | dims], data, type, limit, precision, {open, sep, close} = docs) do
203
218
{acc, rest, limit} =
204
219
chunk_each(dim, data, [], limit, fn chunk, limit ->
205
- chunk(dims, chunk, type, limit, docs)
220
+ chunk(dims, chunk, type, limit, precision, docs)
206
221
end)
207
222
208
223
{open, sep, close, nest} =
 
@@ -234,10 +249,50 @@ defmodule Nx.Backend do
234
249
chunk_each(dim - 1, rest, [doc | acc], limit, fun)
235
250
end
236
251
237
- defp inspect_value(%Complex{} = val), do: Complex.to_string(val)
238
- defp inspect_value(integer) when is_integer(integer), do: Integer.to_string(integer)
239
- defp inspect_value(float) when is_float(float), do: Float.to_string(float)
240
- defp inspect_value(:neg_infinity), do: "-Inf"
241
- defp inspect_value(:infinity), do: "Inf"
242
- defp inspect_value(:nan), do: "NaN"
252
+ defp inspect_value(integer, _) when is_integer(integer), do: Integer.to_string(integer)
253
+ defp inspect_value(:neg_infinity, _), do: "-Inf"
254
+ defp inspect_value(:infinity, _), do: "Inf"
255
+ defp inspect_value(:nan, _), do: "NaN"
256
+ defp inspect_value(%Complex{} = val, precision), do: complex_to_string(val, precision)
257
+
258
+ defp inspect_value(float, precision), do: float_to_string(float, precision)
259
+
260
+ defp float_to_string(float, precision) do
261
+ [integer_part, decimal_part, exponent_part] =
262
+ case String.split(Float.to_string(float), [".", "e"], parts: 3) do
263
+ [i, d] -> [i, d, ""]
264
+ [i, d, e] -> [i, d, "e" <> e]
265
+ end
266
+
267
+ # We'll now prune decimal_part to ensure we have at most `precision`
268
+ # digits there.
269
+
270
+ decimal_part =
271
+ decimal_part
272
+ |> binary_part(0, min(byte_size(decimal_part), precision))
273
+
274
+ # We also prune trailing zeros. Only for more than 1 digit because that single
275
+ # digit always needs to stay put.
276
+ decimal_part =
277
+ if byte_size(decimal_part) > 1 do
278
+ String.trim_trailing(decimal_part, "0")
279
+ else
280
+ decimal_part
281
+ end
282
+
283
+ integer_part <> "." <> decimal_part <> exponent_part
284
+ end
285
+
286
+ def complex_to_string(%Complex{re: re, im: im}, precision) do
287
+ re_str = inspect_value(re, precision)
288
+ im_str = inspect_value(im, precision)
289
+
290
+ im_str =
291
+ case im_str do
292
+ "-" <> _ -> im_str
293
+ s -> "+" <> s
294
+ end
295
+
296
+ re_str <> im_str <> "i"
297
+ end
243
298
end
changed lib/nx/batch.ex
 
@@ -12,20 +12,22 @@ defmodule Nx.Batch do
12
12
13
13
The `:size` field is public.
14
14
"""
15
+ @enforce_keys [:key]
15
16
@derive {Inspect, only: [:size, :pad]}
16
- defstruct stack: [], size: 0, template: nil, pad: 0, key: :default
17
+ defstruct [:key, stack: [], size: 0, template: nil, pad: 0]
17
18
18
19
@type t :: %Nx.Batch{
19
20
stack: list(),
20
21
size: non_neg_integer(),
21
22
template: Nx.Container.t() | Nx.Tensor.t() | nil,
22
- pad: non_neg_integer()
23
+ pad: non_neg_integer(),
24
+ key: term()
23
25
}
24
26
25
27
@doc """
26
28
Returns a new empty batch.
27
29
"""
28
- def new, do: %Nx.Batch{}
30
+ def new, do: %Nx.Batch{key: :default}
29
31
30
32
@doc """
31
33
Sets the batch key for the given batch.
 
@@ -121,17 +123,17 @@ defmodule Nx.Batch do
121
123
>
122
124
"""
123
125
def split(%Nx.Batch{} = batch, n) when is_integer(n) and n > 0 do
124
- %{template: template, stack: stack, pad: pad, size: size} = batch
126
+ %{template: template, stack: stack, pad: pad, size: size, key: key} = batch
125
127
126
128
if n < size do
127
129
{left, right} = drop_split(stack, size - n, [])
128
130
129
131
{%{batch | stack: left, size: n, pad: 0},
130
- %Nx.Batch{template: template, pad: pad, size: size - n, stack: right}}
132
+ %Nx.Batch{template: template, pad: pad, size: size - n, stack: right, key: key}}
131
133
else
132
134
right_pad = max(size + pad - n, 0)
133
135
left_pad = pad - right_pad
134
- {%{batch | pad: left_pad}, %Nx.Batch{template: template, pad: right_pad}}
136
+ {%{batch | pad: left_pad}, %Nx.Batch{template: template, pad: right_pad, key: key}}
135
137
end
136
138
end
changed lib/nx/binary_backend.ex
 
@@ -1885,7 +1885,23 @@ defmodule Nx.BinaryBackend do
1885
1885
in_data = to_binary(tensor)
1886
1886
min = binary_to_number(to_binary(min), min.type)
1887
1887
max = binary_to_number(to_binary(max), max.type)
1888
- out_data = binary_to_binary(in_data, tensor.type, out.type, &min(max(&1, min), max))
1888
+
1889
+ comparison_fn = fn x ->
1890
+ clipped_min =
1891
+ if element_greater(nil, x, min) == 1 do
1892
+ x
1893
+ else
1894
+ min
1895
+ end
1896
+
1897
+ if element_less(nil, clipped_min, max) == 1 do
1898
+ clipped_min
1899
+ else
1900
+ max
1901
+ end
1902
+ end
1903
+
1904
+ out_data = binary_to_binary(in_data, tensor.type, out.type, comparison_fn)
1889
1905
from_binary(out, out_data)
1890
1906
end
changed lib/nx/lin_alg.ex
 
@@ -32,7 +32,7 @@ defmodule Nx.LinAlg do
32
32
#Nx.Tensor<
33
33
c64[2][2]
34
34
[
35
- [1.0+0.0i, 3.0+0.0i],
35
+ [1.0-0.0i, 3.0-0.0i],
36
36
[0.0-2.0i, 0.0+4.0i]
37
37
]
38
38
>
 
@@ -1945,7 +1945,7 @@ defmodule Nx.LinAlg do
1945
1945
iex> Nx.LinAlg.determinant(t)
1946
1946
#Nx.Tensor<
1947
1947
c64
1948
- 0.0-6.0i
1948
+ -0.0-6.0i
1949
1949
>
1950
1950
1951
1951
"""
changed lib/nx/random.ex
 
@@ -551,6 +551,186 @@ defmodule Nx.Random do
551
551
|> stop_grad()
552
552
end
553
553
554
+ @doc """
555
+ Returns a sample from a multivariate normal distribution with given `mean` and `covariance` (matrix).
556
+ The function assumes that the covariance is a positive semi-definite matrix.
557
+ Otherwise, the result will not be normally distributed.
558
+
559
+ ## Options
560
+
561
+ * `:type` - a float type for the returned tensor
562
+
563
+ * `:shape` - batch shape of the returned tensor, i.e. the prefix of the result shape excluding the last axis
564
+
565
+ * `:names` - the names of the returned tensor
566
+
567
+ * `:method` - a decomposition method used for the covariance. Must be one of :svd, :eigh, and :cholesky.
568
+ Defaults to :cholesky. For singular covariance matrices, use :svd or :eigh.
569
+
570
+ ## Examples
571
+
572
+ iex> key = Nx.Random.key(12)
573
+ iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]))
574
+ iex> multivariate_normal
575
+ #Nx.Tensor<
576
+ f32[1]
577
+ [0.735927939414978]
578
+ >
579
+
580
+ iex> key = Nx.Random.key(12)
581
+ iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]))
582
+ iex> multivariate_normal
583
+ #Nx.Tensor<
584
+ f32[2]
585
+ [-1.3425945043563843, -0.40812060236930847]
586
+ >
587
+
588
+ iex> key = Nx.Random.key(12)
589
+ iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0]), Nx.tensor([[1]]), shape: {3, 2}, type: :f16)
590
+ iex> multivariate_normal
591
+ #Nx.Tensor<
592
+ f16[3][2][1]
593
+ [
594
+ [
595
+ [0.326904296875],
596
+ [0.2176513671875]
597
+ ],
598
+ [
599
+ [0.316650390625],
600
+ [0.1109619140625]
601
+ ],
602
+ [
603
+ [0.53955078125],
604
+ [-0.8857421875]
605
+ ]
606
+ ]
607
+ >
608
+
609
+ iex> key = Nx.Random.key(12)
610
+ iex> {multivariate_normal, _new_key} = Nx.Random.multivariate_normal(key, Nx.tensor([0, 0]), Nx.tensor([[1, 0], [0, 1]]), shape: {3, 2})
611
+ iex> multivariate_normal
612
+ #Nx.Tensor<
613
+ f32[3][2][2]
614
+ [
615
+ [
616
+ [0.9891449809074402, 1.0795185565948486],
617
+ [-0.9467806220054626, 1.47813880443573]
618
+ ],
619
+ [
620
+ [2.2095863819122314, -1.529456377029419],
621
+ [-0.7933920621871948, 1.121195673942566]
622
+ ],
623
+ [
624
+ [0.10976295918226242, -0.9959557056427002],
625
+ [0.4754556119441986, 1.1413804292678833]
626
+ ]
627
+ ]
628
+ >
629
+ """
630
+ defn multivariate_normal(key, mean, covariance, opts \\ []) do
631
+ keys = split(key)
632
+ {multivariate_normal_split(keys[1], mean, covariance, opts), keys[0]}
633
+ end
634
+
635
+ @doc """
636
+ Same as `multivariate_normal/4` but assumes the key has already been split.
637
+ """
638
+ defn multivariate_normal_split(key, mean, covariance, opts \\ []) do
639
+ assert_key!(key)
640
+ opts = keyword!(opts, [:names, :type, :shape, method: :cholesky])
641
+ {type, shape} = validate_multivariate_normal_args(mean, covariance, opts)
642
+ mean = Nx.as_type(mean, type)
643
+ covariance = Nx.as_type(covariance, type)
644
+
645
+ a =
646
+ case opts[:method] do
647
+ :svd ->
648
+ {u, s, _} = Nx.LinAlg.svd(covariance)
649
+ u * Nx.sqrt(s)
650
+
651
+ :eigh ->
652
+ {s, u} = Nx.LinAlg.eigh(covariance)
653
+ u * Nx.sqrt(s)
654
+
655
+ :cholesky ->
656
+ Nx.LinAlg.cholesky(covariance)
657
+ end
658
+
659
+ z =
660
+ normal_split(key, 0.0, 1.0, shape: shape, type: type)
661
+ |> Nx.reshape({:auto, Nx.size(mean)})
662
+
663
+ # x = mean + z * a^T
664
+ (mean + Nx.dot(z, [1], a, [1]))
665
+ |> Nx.reshape(shape, names: opts[:names])
666
+ |> stop_grad()
667
+ end
668
+
669
+ deftransformp validate_multivariate_normal_args(mean, covariance, opts) do
670
+ if opts[:method] not in [:svd, :eigh, :cholesky] do
671
+ raise ArgumentError,
672
+ """
673
+ method must be one of :svd, :eigh, and :cholesky
674
+ """
675
+ end
676
+
677
+ type = infer_float_type(mean, covariance, opts)
678
+
679
+ case type do
680
+ {:f, _} -> nil
681
+ {:bf, _} -> nil
682
+ {:c, _} -> Nx.Shared.raise_complex_not_supported(:multivariate_normal_split, 4)
683
+ _ -> raise ArgumentError, "expected float or complex type, got type #{inspect(type)}"
684
+ end
685
+
686
+ if Nx.rank(mean) != 1 do
687
+ raise ArgumentError,
688
+ """
689
+ expected mean to have rank 1, got tensor with rank #{Nx.rank(mean)}
690
+ """
691
+ end
692
+
693
+ dim = Nx.size(mean)
694
+
695
+ if Nx.rank(covariance) != 2 do
696
+ raise ArgumentError,
697
+ """
698
+ expected covariance to have rank 2, got tensor with rank #{Nx.rank(covariance)}
699
+ """
700
+ end
701
+
702
+ if Nx.axis_size(covariance, 0) != Nx.axis_size(covariance, 1) do
703
+ raise ArgumentError,
704
+ """
705
+ expected covariance to be a square matrix, got tensor with shape #{Nx.shape(covariance)}
706
+ """
707
+ end
708
+
709
+ if Nx.axis_size(covariance, 0) != dim do
710
+ raise ArgumentError,
711
+ """
712
+ expected mean and covariance to have the same dimensions, got #{dim} and #{Nx.axis_size(covariance, 0)}
713
+ """
714
+ end
715
+
716
+ shape =
717
+ case opts[:shape] do
718
+ nil ->
719
+ {dim}
720
+
721
+ dims when is_tuple(dims) ->
722
+ Tuple.append(dims, dim)
723
+
724
+ _ ->
725
+ raise ArgumentError,
726
+ """
727
+ shape must be a tuple of integers
728
+ """
729
+ end
730
+
731
+ {type, shape}
732
+ end
733
+
554
734
@doc """
555
735
Randomly shuffles tensor elements along an axis.
556
736
 
@@ -780,7 +960,7 @@ defmodule Nx.Random do
780
960
781
961
case Nx.rank(p) do
782
962
1 -> :ok
783
- r -> raise ArgumentError, "propability tensor must have rank 1, got: #{r}"
963
+ r -> raise ArgumentError, "probability tensor must have rank 1, got: #{r}"
784
964
end
785
965
786
966
case {Nx.size(p), Nx.axis_size(tensor, axis)} do
changed lib/nx/serving.ex
 
@@ -519,6 +519,9 @@ defmodule Nx.Serving do
519
519
return a stream. The stream must be consumed in the same
520
520
process that calls `run/2` or `batched_run/2`.
521
521
522
+ Batches will be streamed as they arrive. You may also opt-in
523
+ to stream `Nx.Defn` hooks.
524
+
522
525
## Options
523
526
524
527
* `:hooks` - a list of hook names that will become streaming events
 
@@ -533,17 +536,22 @@ defmodule Nx.Serving do
533
536
534
537
{hook_name, term()}
535
538
536
- Once the stream is done, it will emit `{:done, output, metadata}`.
537
- The client postprocessing is often expected to call
538
- `Stream.transform/3` to process those events into something usable
539
- by callers.
539
+ The stream will also receive events in the shape of
540
+ `{:batch, output, metadata}` as batches are processed by the
541
+ serving. The client postprocessing is often expected to call
542
+ `Stream.transform/3` to process those events into something
543
+ usable by callers.
540
544
541
- ### Batch breaking
545
+ If the `:hooks` option is given, only a single `:batch` event
546
+ is emitted, at the end, as detailed next.
542
547
543
- Another consequence of streaming is that a serving server can
544
- no longer break a batch. For example, imagine you have a
545
- `batch_size` of 3 and you push three batches of two elements
546
- (AA, BB, and CC). Without streaming, the batches will be consumed as:
548
+ ### Batch limits
549
+
550
+ If you are streaming hooks, the serving server can no longer break
551
+ a batch and you are unable to push a payload bigger than `:batch_size`.
552
+ For example, imagine you have a `batch_size` of 3 and you push three
553
+ batches of two elements (AA, BB, and CC). Without hooks, the batches
554
+ will be consumed as:
547
555
548
556
AAB -> BCC
549
557
 
@@ -598,7 +606,7 @@ defmodule Nx.Serving do
598
606
streaming: streaming
599
607
} = serving
600
608
601
- {ref, defn_options} = run_streaming_hooks(streaming, defn_options)
609
+ {ref, defn_options} = run_streaming(streaming, defn_options)
602
610
{%{size: size, key: key} = batch, info} = handle_preprocessing(preprocessing, input)
603
611
{:ok, state} = handle_init(module, :inline, arg, [[batch_keys: [key]] ++ defn_options])
604
612
{:execute, function, _} = handle_batch(module, batch, 0, state)
 
@@ -621,23 +629,23 @@ defmodule Nx.Serving do
621
629
handle_postprocessing(postprocessing, execution_result, info)
622
630
end
623
631
624
- defp run_streaming_hooks(nil, defn_options), do: {nil, defn_options}
632
+ defp run_streaming(nil, defn_options), do: {nil, defn_options}
625
633
626
- defp run_streaming_hooks(%{hooks: hooks}, defn_options) do
634
+ defp run_streaming(%{hooks: hooks}, defn_options) do
627
635
parent = self()
628
636
ref = make_ref()
629
637
630
638
defn_options =
631
639
update_in(defn_options[:hooks], fn acc ->
632
640
Enum.reduce(hooks, acc || %{}, fn hook, acc ->
633
- Map.put(acc, hook, &run_streaming_hook(parent, ref, hook, &1))
641
+ Map.put(acc, hook, &run_hook(parent, ref, hook, &1))
634
642
end)
635
643
end)
636
644
637
645
{ref, defn_options}
638
646
end
639
647
640
- defp run_streaming_hook(pid, ref, hook, result) do
648
+ defp run_hook(pid, ref, hook, result) do
641
649
send(pid, {ref, {hook, 0, result}})
642
650
end
643
651
 
@@ -823,7 +831,7 @@ defmodule Nx.Serving do
823
831
preprocessing: preprocessing,
824
832
postprocessing: postprocessing,
825
833
limit: limit,
826
- streaming?: streaming?,
834
+ mode: mode,
827
835
batch_keys: batch_keys
828
836
} =
829
837
:persistent_term.get(persistent_key(name), nil) ||
 
@@ -835,9 +843,9 @@ defmodule Nx.Serving do
835
843
836
844
{batch, info} = handle_preprocessing(preprocessing, input)
837
845
838
- if batch.size > limit do
846
+ if mode == :hooks and batch.size > limit do
839
847
raise ArgumentError,
840
- "batch size (#{batch.size}) cannot exceed Nx.Serving server batch size of #{limit}"
848
+ "batch size (#{batch.size}) cannot exceed Nx.Serving server batch size of #{limit} when streaming hooks"
841
849
end
842
850
843
851
unless is_map_key(batch_keys, batch.key) do
 
@@ -849,17 +857,19 @@ defmodule Nx.Serving do
849
857
ref = :erlang.monitor(:process, pid, alias: :demonitor)
850
858
Process.send(pid, {__MODULE__, :batched_run, ref, batch}, [:noconnect])
851
859
852
- if streaming? do
853
- stream = receive_stream("batched_run/2", ref, batch.size)
854
- {:ok, handle_postprocessing(postprocessing, stream, info)}
855
- else
856
- case receive_batched(ref, batch.size, 0, [], nil) do
857
- {:done, tensor, metadata} ->
858
- {:ok, handle_postprocessing(postprocessing, {tensor, metadata}, info)}
860
+ case mode do
861
+ :execute ->
862
+ case receive_execute(ref, batch.size, 0, [], nil) do
863
+ {:ok, tensor, metadata} ->
864
+ {:ok, handle_postprocessing(postprocessing, {tensor, metadata}, info)}
859
865
860
- {:DOWN, reason} ->
861
- {:DOWN, reason}
862
- end
866
+ {:DOWN, reason} ->
867
+ {:DOWN, reason}
868
+ end
869
+
870
+ _ ->
871
+ stream = receive_stream("batched_run/2", ref, batch.size)
872
+ {:ok, handle_postprocessing(postprocessing, stream, info)}
863
873
end
864
874
end
865
875
 
@@ -885,7 +895,7 @@ defmodule Nx.Serving do
885
895
Node.spawn_monitor(node(pid), __MODULE__, :__distributed_batched_run__, args)
886
896
887
897
receive do
888
- {^ref, :stream} ->
898
+ {^ref, :hooks} ->
889
899
owner = self()
890
900
891
901
Stream.resource(
 
@@ -901,7 +911,7 @@ defmodule Nx.Serving do
901
911
{^ref, event} ->
902
912
{[event], :ok}
903
913
904
- {:DOWN, ^monitor_ref, _, _, {^ref, :stream}} ->
914
+ {:DOWN, ^monitor_ref, _, _, {^ref, :hooks}} ->
905
915
{:halt, :ok}
906
916
907
917
{:DOWN, ^monitor_ref, _, _, reason} ->
 
@@ -930,13 +940,13 @@ defmodule Nx.Serving do
930
940
931
941
case local_batched_run(pid, name, input) do
932
942
{:ok, result} ->
933
- %{streaming?: streaming?, distributed_postprocessing: dist_post} =
943
+ %{mode: mode, distributed_postprocessing: dist_post} =
934
944
:persistent_term.get(persistent_key(name))
935
945
936
- if streaming? do
937
- send(client_pid, {ref, :stream})
946
+ if mode == :hooks do
947
+ send(client_pid, {ref, :hooks})
938
948
Enum.each(dist_post.(result), &send(client_pid, {ref, &1}))
939
- exit({ref, :stream})
949
+ exit({ref, :hooks})
940
950
else
941
951
exit({ref, dist_post.(result)})
942
952
end
 
@@ -957,26 +967,31 @@ defmodule Nx.Serving do
957
967
raise "the stream returned from Nx.Serving.#{fun} must be consumed in the same process"
958
968
end
959
969
960
- {0, [], nil}
970
+ 0
961
971
end,
962
972
fn
963
- {index, acc, template} ->
964
- case receive_batched(ref, size, index, acc, template) do
965
- {:done, _tensor, _metadata} = result -> {[result], :done}
966
- {:hook, name, value, index_acc_template} -> {[{name, value}], index_acc_template}
967
- {:DOWN, reason} -> exit({reason, {Nx.Serving, :streaming, []}})
968
- end
969
-
970
- :done ->
973
+ ^size ->
971
974
{:halt, :done}
975
+
976
+ index ->
977
+ case receive_each(ref, size, index) do
978
+ {:hook, {hook, start, output}} ->
979
+ value = remove_maybe_padded(output, start, size)
980
+ {[{hook, value}], index}
981
+
982
+ {:batch, {output_start, output_size, output, metadata}} ->
983
+ value = remove_maybe_padded(output, output_start, output_size)
984
+ {[{:batch, value, metadata}], index + output_size}
985
+
986
+ {:DOWN, reason} ->
987
+ exit({reason, {Nx.Serving, :streaming, []}})
988
+ end
972
989
end,
973
990
fn _ -> :ok end
974
991
)
975
992
end
976
993
977
- defp receive_batched(ref, size, size, acc, {template, metadata}) do
978
- Process.demonitor(ref, [:flush])
979
-
994
+ defp receive_execute(_ref, size, size, acc, {template, metadata}) do
980
995
tensors =
981
996
acc
982
997
|> Enum.reverse()
 
@@ -987,21 +1002,16 @@ defmodule Nx.Serving do
987
1002
{tensor, tensors}
988
1003
end)
989
1004
990
- {:done, output, metadata}
1005
+ {:ok, output, metadata}
991
1006
end
992
1007
993
- defp receive_batched(ref, size, index, acc, template_metadata) do
994
- receive do
995
- {^ref, {hook, start, output}} ->
996
- output = remove_maybe_padded(output, start, size)
997
- {:hook, hook, output, {index, acc, template_metadata}}
998
-
999
- {^ref, {output_start, output_size, output, metadata}} ->
1008
+ defp receive_execute(ref, size, index, acc, _template_metadata) do
1009
+ case receive_each(ref, size, index) do
1010
+ {:batch, {output_start, output_size, output, metadata}} ->
1000
1011
# If we have a single response, slice and return immediately.
1001
1012
# Otherwise we collect their contents and build the concatenated result later.
1002
1013
if acc == [] and output_size == size - index do
1003
- Process.demonitor(ref, [:flush])
1004
- {:done, remove_maybe_padded(output, output_start, output_size), metadata}
1014
+ {:ok, remove_maybe_padded(output, output_start, output_size), metadata}
1005
1015
else
1006
1016
funs =
1007
1017
output
 
@@ -1011,9 +1021,26 @@ defmodule Nx.Serving do
1011
1021
)
1012
1022
|> Enum.reverse()
1013
1023
1014
- receive_batched(ref, size, index + output_size, [funs | acc], {output, metadata})
1024
+ receive_execute(ref, size, index + output_size, [funs | acc], {output, metadata})
1015
1025
end
1016
1026
1027
+ {:DOWN, reason} ->
1028
+ {:DOWN, reason}
1029
+ end
1030
+ end
1031
+
1032
+ defp receive_each(ref, size, index) do
1033
+ receive do
1034
+ {^ref, {_hook, _start, _output} = payload} ->
1035
+ {:hook, payload}
1036
+
1037
+ {^ref, {_output_start, output_size, _output, _metadata} = payload} ->
1038
+ if output_size == size - index do
1039
+ Process.demonitor(ref, [:flush])
1040
+ end
1041
+
1042
+ {:batch, payload}
1043
+
1017
1044
{:DOWN, ^ref, _, _, reason} ->
1018
1045
# We fake monitor messages, so still demonitor and flush.
1019
1046
Process.demonitor(ref, [:flush])
 
@@ -1035,7 +1062,7 @@ defmodule Nx.Serving do
1035
1062
Process.flag(:trap_exit, true)
1036
1063
partitions_opts = serving_partitions(serving, partitions?)
1037
1064
partitions_count = length(partitions_opts)
1038
- {partitions_opts, streaming_table} = serving_streaming(serving, partitions_opts)
1065
+ {mode, partitions_opts, hooks_table} = serving_streaming(serving, partitions_opts)
1039
1066
partitions_opts = Enum.map(partitions_opts, &Keyword.put(&1, :batch_keys, batch_keys))
1040
1067
{:ok, module_state} = handle_init(serving.module, :process, serving.arg, partitions_opts)
1041
1068
 
@@ -1046,7 +1073,7 @@ defmodule Nx.Serving do
1046
1073
preprocessing: serving.client_preprocessing,
1047
1074
postprocessing: serving.client_postprocessing,
1048
1075
distributed_postprocessing: serving.distributed_postprocessing,
1049
- streaming?: serving.streaming != nil,
1076
+ mode: mode,
1050
1077
batch_keys: Map.from_keys(batch_keys, [])
1051
1078
}
1052
1079
)
 
@@ -1069,7 +1096,7 @@ defmodule Nx.Serving do
1069
1096
tasks: [],
1070
1097
pending_batches: Map.from_keys(batch_keys, @empty_queue),
1071
1098
task_supervisor: task_supervisor,
1072
- streaming_table: streaming_table
1099
+ hooks_table: hooks_table
1073
1100
}
1074
1101
1075
1102
{:ok, state}
 
@@ -1085,7 +1112,11 @@ defmodule Nx.Serving do
1085
1112
end
1086
1113
1087
1114
defp serving_streaming(%Nx.Serving{streaming: nil}, partitions) do
1088
- {partitions, nil}
1115
+ {:execute, partitions, nil}
1116
+ end
1117
+
1118
+ defp serving_streaming(%Nx.Serving{streaming: %{hooks: []}}, partitions) do
1119
+ {:batches, partitions, nil}
1089
1120
end
1090
1121
1091
1122
defp serving_streaming(%Nx.Serving{streaming: %{hooks: hooks}}, partitions) do
 
@@ -1095,15 +1126,15 @@ defmodule Nx.Serving do
1095
1126
Enum.with_index(partitions, fn defn_options, index ->
1096
1127
update_in(defn_options[:hooks], fn acc ->
1097
1128
Enum.reduce(hooks, acc || %{}, fn hook, acc ->
1098
- Map.put(acc, hook, &server_streaming_hook(ets, index, hook, &1))
1129
+ Map.put(acc, hook, &server_hook(ets, index, hook, &1))
1099
1130
end)
1100
1131
end)
1101
1132
end)
1102
1133
1103
- {partitions, ets}
1134
+ {:hooks, partitions, ets}
1104
1135
end
1105
1136
1106
- defp server_streaming_hook(ets, index, hook, result) do
1137
+ defp server_hook(ets, index, hook, result) do
1107
1138
for {ref, start, _size} <- :ets.lookup_element(ets, index, 2) do
1108
1139
send(ref, {ref, {hook, start, result}})
1109
1140
end
 
@@ -1124,34 +1155,15 @@ defmodule Nx.Serving do
1124
1155
|> server_stack(key, ref, batch, :skip_timer)
1125
1156
|> server_execute(key)
1126
1157
1127
- # First entry in batch.
1128
- count == 0 ->
1129
- server_stack(state, key, ref, batch, :start_timer)
1130
-
1131
- # We don't exceed the limit.
1132
- batch.size + count < limit ->
1133
- server_stack(state, key, ref, batch, :skip_timer)
1134
-
1135
- # We go over the limit, but if streaming, we can't split.
1136
- batch.size + count > limit and state.streaming_table != nil ->
1158
+ # We go over the limit, but if using hooks, we can't split.
1159
+ batch.size + count > limit and state.hooks_table != nil ->
1137
1160
state
1138
1161
|> server_execute(key)
1139
- |> server_stack(key, ref, batch, :start_timer)
1162
+ |> server_stack(key, ref, batch, :set_timer)
1140
1163
1141
- # We go over the limit, split it across runs.
1142
- batch.size + count > limit ->
1143
- {current, next} = Nx.Batch.split(batch, limit - count)
1144
-
1145
- state
1146
- |> server_stack(key, ref, current, :skip_timer)
1147
- |> server_execute(key)
1148
- |> server_stack(key, ref, next, :start_timer)
1149
-
1150
- # Exact match.
1164
+ # Split as necessary.
1151
1165
true ->
1152
- state
1153
- |> server_stack(key, ref, batch, :skip_timer)
1154
- |> server_execute(key)
1166
+ server_stack_and_execute_loop(state, batch, count, key, ref)
1155
1167
end
1156
1168
1157
1169
{:noreply, state}
 
@@ -1175,7 +1187,7 @@ defmodule Nx.Serving do
1175
1187
case Enum.split_with(tasks, &(elem(&1, 0).ref == ref)) do
1176
1188
{[{_task, partition, _ref_sizes}], tasks} ->
1177
1189
Process.demonitor(ref, [:flush])
1178
- {:noreply, server_task_done(state, tasks, partition)}
1190
+ noreply_task_done_and_continue(state, tasks, partition)
1179
1191
1180
1192
_ ->
1181
1193
{:noreply, state}
 
@@ -1186,7 +1198,7 @@ defmodule Nx.Serving do
1186
1198
case Enum.split_with(tasks, &(elem(&1, 0).ref == ref)) do
1187
1199
{[{_task, partition, ref_sizes}], tasks} ->
1188
1200
server_reply_down(reason, ref_sizes)
1189
- {:noreply, server_task_done(state, tasks, partition)}
1201
+ noreply_task_done_and_continue(state, tasks, partition)
1190
1202
1191
1203
_ ->
1192
1204
{:noreply, state}
 
@@ -1198,6 +1210,11 @@ defmodule Nx.Serving do
1198
1210
{:noreply, state}
1199
1211
end
1200
1212
1213
+ @impl true
1214
+ def handle_continue(:maybe_task, state) do
1215
+ {:noreply, server_maybe_task(state)}
1216
+ end
1217
+
1201
1218
@impl true
1202
1219
def terminate(_reason, %{tasks: tasks, pending_batches: pending_batches}) do
1203
1220
for {batch_key, queue} <- pending_batches do
 
@@ -1223,21 +1240,50 @@ defmodule Nx.Serving do
1223
1240
:ok
1224
1241
end
1225
1242
1243
+ # We don't spawn the task here because, if it crashes,
1244
+ # we want a checked-in version of the state that knows
1245
+ # the current task has finished.
1246
+ defp noreply_task_done_and_continue(%{out_queue: out_queue} = state, tasks, partition) do
1247
+ out_queue = :queue.in(partition, out_queue)
1248
+ {:noreply, %{state | tasks: tasks, out_queue: out_queue}, {:continue, :maybe_task}}
1249
+ end
1250
+
1226
1251
defp server_reply_down(reason, ref_sizes) do
1227
1252
for {ref, _start, _size} <- ref_sizes do
1228
1253
send(ref, {:DOWN, ref, :process, self(), reason})
1229
1254
end
1230
1255
end
1231
1256
1257
+ defp server_stack_and_execute_loop(state, batch, count, key, ref) do
1258
+ %{limit: limit} = state
1259
+ %{size: size} = batch
1260
+
1261
+ cond do
1262
+ size + count < limit ->
1263
+ server_stack(state, key, ref, batch, :set_timer)
1264
+
1265
+ size + count > limit ->
1266
+ {current, batch} = Nx.Batch.split(batch, limit - count)
1267
+
1268
+ state
1269
+ |> server_stack(key, ref, current, :skip_timer)
1270
+ |> server_execute(key)
1271
+ |> server_stack_and_execute_loop(batch, 0, key, ref)
1272
+
1273
+ true ->
1274
+ state
1275
+ |> server_stack(key, ref, batch, :skip_timer)
1276
+ |> server_execute(key)
1277
+ end
1278
+ end
1279
+
1232
1280
defp server_stack(%{limit: limit} = state, key, ref, batch, timer_mode) do
1233
1281
stack_update(key, fn {stack, count, timer} when batch.size + count <= limit ->
1234
1282
timer =
1235
- case timer_mode do
1236
- :start_timer when timer == :none ->
1237
- Process.send_after(self(), {@timeout_message, key}, state.timeout)
1238
-
1239
- :skip_timer ->
1240
- timer
1283
+ if timer == :none and timer_mode == :set_timer do
1284
+ Process.send_after(self(), {@timeout_message, key}, state.timeout)
1285
+ else
1286
+ timer
1241
1287
end
1242
1288
1243
1289
{[{ref, batch} | stack], count + batch.size, timer}
 
@@ -1280,13 +1326,13 @@ defmodule Nx.Serving do
1280
1326
{batch_refs, Map.put(pending_batches, key, queue)}
1281
1327
end
1282
1328
1283
- %{module: module, module_state: module_state, streaming_table: streaming_table} = state
1329
+ %{module: module, module_state: module_state, hooks_table: hooks_table} = state
1284
1330
{:execute, function, module_state} = handle_batch(module, batch, partition, module_state)
1285
1331
1286
1332
wrapped_function = fn ->
1287
1333
:telemetry.span([:nx, :serving, :execute], %{module: module}, fn ->
1288
- if streaming_table do
1289
- :ets.insert(streaming_table, {partition, ref_sizes})
1334
+ if hooks_table do
1335
+ :ets.insert(hooks_table, {partition, ref_sizes})
1290
1336
end
1291
1337
1292
1338
{output, metadata} = function.()
 
@@ -1315,11 +1361,6 @@ defmodule Nx.Serving do
1315
1361
end
1316
1362
end
1317
1363
1318
- defp server_task_done(%{out_queue: out_queue} = state, tasks, partition) do
1319
- out_queue = :queue.in(partition, out_queue)
1320
- server_maybe_task(%{state | tasks: tasks, out_queue: out_queue})
1321
- end
1322
-
1323
1364
## Stack management
1324
1365
#
1325
1366
# The stack is stored in the process dictionary for performance
changed mix.exs
 
@@ -6,7 +6,7 @@ defmodule Nx.MixProject do
6
6
use Mix.Project
7
7
8
8
@source_url "https://github.com/elixir-nx/nx"
9
- @version "0.6.0"
9
+ @version "0.6.1"
10
10
11
11
def project do
12
12
[
 
@@ -41,7 +41,7 @@ defmodule Nx.MixProject do
41
41
[
42
42
{:complex, "~> 0.5"},
43
43
{:telemetry, "~> 0.4.0 or ~> 1.0"},
44
- {:ex_doc, "~> 0.29.0", only: :docs}
44
+ {:ex_doc, "~> 0.29", only: :docs}
45
45
]
46
46
end