Skip to content

Commit da3e865

Browse files
authored
Use a mutable copy of input if inplace scaling is required (#243)
* Use a mutable copy of input if inplace scaling is required * Convert to Array instead of using similar * Add test * Fix type-signature of _plan_mul
1 parent 2827377 commit da3e865

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

src/chebyshevtransform.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ end
4949

5050

5151
# convert x if necessary
52-
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x)
53-
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x))
52+
_maybemutablecopy(x::StridedArray{T}, ::Type{T}) where {T} = x
53+
_maybemutablecopy(x, T) = Array{T}(x)
54+
@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, _maybemutablecopy(x, T))
5455

5556

5657
for op in (:ldiv, :lmul)
@@ -309,7 +310,8 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,2,K,false,N},
309310
_icheb2_rescale!(P.plan.region, y)
310311
end
311312

312-
*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} = mul!(similar(x), P, x)
313+
*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} =
314+
mul!(similar(x), P, _maybemutablecopy(x, T))
313315
ichebyshevtransform!(x::AbstractArray, dims...; kwds...) = plan_ichebyshevtransform!(x, dims...; kwds...)*x
314316
ichebyshevtransform(x, dims...; kwds...) = plan_ichebyshevtransform(x, dims...; kwds...)*x
315317

test/chebyshevtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ using FastTransforms, Test
451451
@testset "immutable vectors" begin
452452
F = plan_chebyshevtransform([1.,2,3])
453453
@test chebyshevtransform(1.0:3) == F * (1:3)
454+
@test ichebyshevtransform(1.0:3) == ichebyshevtransform([1.0:3;])
454455
end
455456

456457
@testset "inv" begin

0 commit comments

Comments
 (0)