Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing KeyIterator used for JuMPArray to work also when index sets are non indexable. #836

Merged
merged 8 commits into from
Sep 27, 2016
1 change: 1 addition & 0 deletions src/JuMPArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
end

Base.getindex(d::JuMPArray, ::Colon) = d.innerArray[:]

@generated function Base.getindex{T,N,NT<:NTuple}(d::JuMPArray{T,N,NT}, idx...)
if N != length(idx)
error("Indexed into a JuMPArray with $(length(idx)) indices (expected $N indices)")
Expand Down
81 changes: 75 additions & 6 deletions src/JuMPContainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,10 @@ Base.ndims{T,N}(x::JuMPDict{T,N}) = N
Base.abs(x::JuMPDict) = map(abs, x)
# avoid dangerous behavior with "end" (#730)
Base.endof(x::JuMPArray) = error("endof() (and \"end\" syntax) not implemented for JuMPArray objects.")
Base.size(x::JuMPArray) = error("size (and \"end\" syntax) not implemented for JuMPArray objects. Use JuMP.size if you want to access the dimensions.")
Base.size(x::JuMPArray,k) = error("size (and \"end\" syntax) not implemented for JuMPArray objects. Use JuMP.size if you want to access the dimensions.")
Base.size(x::JuMPArray) = error(string("size (and \"end\" syntax) not implemented for JuMPArray objects.",
"Use JuMP.size if you want to access the dimensions."))
Base.size(x::JuMPArray,k) = error(string("size (and \"end\" syntax) not implemented for JuMPArray objects.",
" Use JuMP.size if you want to access the dimensions."))
size(x::JuMPArray) = size(x.innerArray)
size(x::JuMPArray,k) = size(x.innerArray,k)
# for uses of size() within JuMP
Expand Down Expand Up @@ -183,13 +185,80 @@ Base.length(it::ValueIterator) = length(it.x)

type KeyIterator{JA<:JuMPArray}
x::JA
dim::Int
next_k_cache::Array{Any,1}
function KeyIterator(d)
n = ndims(d.innerArray)
new(d, n, Array(Any, n+1))
end
end

KeyIterator{JA}(d::JA) = KeyIterator{JA}(d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method necessary with the inner constructor above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deleting line 196 gives this error

ERROR: LoadError: MethodError: Cannot convert an object of type JuMP.JuMPArray{Float64,2,Tuple{UnitRange{Int32},UnitRange{Int32}}} to an object of type JuMP.KeyIterator{JA<:JuMP.JuMPArray}
This may have arisen from a call to the constructor JuMP.KeyIterator{JA<:JuMP.JuMPArray}(...),

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, strange. OK, thanks.


function indexability(x::JuMPArray)
for i in 1:length(x.indexsets)
if !method_exists(getindex, (typeof(x.indexsets[i]),))
return false
end
end

return true
end

function Base.start(it::KeyIterator)
if indexability(it.x)
return start(it.x.innerArray)
else
return notindexable_start(it.x)
end
end

@generated function notindexable_start{T,N,NT}(x::JuMPArray{T,N,NT})
quote
$(Expr(:tuple, 0, [:(start(x.indexsets[$i])) for i in 1:N]...))
end
end
Base.start(it::KeyIterator) = start(it.x.innerArray)
@generated __next{T,N,NT}(x::JuMPArray{T,N,NT}, k) =

@generated function _next{T,N,NT}(x::JuMPArray{T,N,NT}, k::Tuple)
quote
$(Expr(:tuple, [:(next(x.indexsets[$i], k[$i+1])[1]) for i in 1:N]...))
end
end

function Base.next(it::KeyIterator, k::Tuple)
cartesian_key = _next(it.x, k)
pos = -1
for i in 1:it.dim
if !done(it.x.indexsets[i], next(it.x.indexsets[i], k[i+1])[2] )
pos = i
break
end
end
if pos == - 1
it.next_k_cache[1] = 1
return cartesian_key, tuple(it.next_k_cache...)
end
it.next_k_cache[1] = 0
for i in 1:it.dim
if i < pos
it.next_k_cache[i+1] = start(it.x.indexsets[i])
elseif i == pos
it.next_k_cache[i+1] = next(it.x.indexsets[i], k[i+1])[2]
else
it.next_k_cache[i+1] = k[i+1]
end
end
cartesian_key, tuple(it.next_k_cache...)
end

Base.done(it::KeyIterator, k::Tuple) = (k[1] == 1)

@generated __next{T,N,NT}(x::JuMPArray{T,N,NT}, k::Integer) =
quote
subidx = ind2sub(size(x),k)
$(Expr(:tuple, [:(x.indexsets[$i][subidx[$i]]) for i in 1:N]...)), next(x.innerArray,k)[2]
end
Base.next(it::KeyIterator, k) = __next(it.x,k)
Base.done(it::KeyIterator, k) = done(it.x.innerArray, k)
Base.next(it::KeyIterator, k) = __next(it.x,k::Integer)
Base.done(it::KeyIterator, k) = done(it.x.innerArray, k::Integer)

Base.length(it::KeyIterator) = length(it.x.innerArray)
21 changes: 21 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,3 +925,24 @@ facts("[model] Nonliteral exponents in @constraint") do
@fact m.quadconstr[3].terms --> x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 + x^2 - 1
@fact m.quadconstr[4].terms --> QuadExpr(x + x + x - 1)
end

facts("[model] sets used as indexsets in JuMPArray") do
set = IntSet()
for i in 4:5
push!(set, i)
end
set2 = IntSet()
for i in 21:23
push!(set2, i)
end
m = Model()
@variable(m, x[set, set2], Bin)
@objective(m , Max, sum{sum{x[e,p], e in set}, p in set2})
solve(m)
sol = getvalue(x)
checked_objval = 0
for i in keys(sol)
checked_objval += sol[i...]
end
@fact checked_objval --> 6
end