I sometimes just write `f64::ln(num)` though. Bit verbose with the type all the time, but I don't think it's too bad.
Also seems to work with ONNX (github.com/pymc-devs/nu...)
But for some reason I can't find any references to it in the JAX docs. I'm really confused by this, by the way, and maybe I just misunderstand something...
One thing that has always bugged me about JAX is that I can't find a way to use multiple CUDA streams. I think at least part of the NUTS overhead would go away if different chains ran in different streams, so that the GPU doesn't sit idle while another chain could be running.
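For context, the workaround I'm aware of (not a substitute for real streams) is to batch the chains with `jax.vmap`, so one fused kernel steps all chains at once. A minimal sketch, with a made-up toy `step` function standing in for an actual NUTS transition — the catch being that `vmap`'d chains run in lockstep, which is exactly where the overhead comes from when trajectory lengths differ between chains:

```python
import jax
import jax.numpy as jnp


def step(key, position):
    # Toy stand-in for one sampler transition: a small random-walk move.
    key, sub = jax.random.split(key)
    proposal = position + 0.1 * jax.random.normal(sub, position.shape)
    return key, proposal


# vmap over the chain axis: all chains advance in one batched kernel
# instead of launching many small per-chain kernels.
batched_step = jax.vmap(step)

n_chains, dim = 4, 2
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)
positions = jnp.zeros((n_chains, dim))

keys, positions = batched_step(keys, positions)
print(positions.shape)  # one (n_chains, dim) array, all chains stepped together
```

With real per-chain streams, a chain that finishes its trajectory early could launch its next kernel immediately; with `vmap`, every chain waits for the slowest one in each batched step.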