Skip to content

Commit

Permalink
Add spin_pols_iter to iterate over a process' spin/pol combinations (#…
Browse files Browse the repository at this point in the history
…118)

As the title says, this adds an iterator to yield all possible
combinations of spins and polarizations allowed by a process' set
`spin_pols()`. For example:

```Julia
julia> using QEDbase; using QEDcore; using QEDprocesses;

julia> proc = ScatteringProcess((Photon(), Photon(), Photon(), Electron()), (Photon(), Electron()), (SyncedPolarization(1), SyncedPolarization(2), SyncedPolarization(1), SpinUp()), (SyncedPolarization(2), AllSpin()))
generic QED process
    incoming: photon (synced polarization 1), photon (synced polarization 2), photon (synced polarization 1), electron (spin up)
    outgoing: photon (synced polarization 2), electron (all spins)


julia> for sp_combo in spin_pols_iter(proc) println(sp_combo) end
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin up))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin up))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin up))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin up))
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin down))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin down))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin down))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin down))

julia> length(spin_pols_iter(proc))
8
```
The above is also a `jldoctest`.

As a side-note I also added an alias of `SyncedPol` to
`SyncedPolarization`.

The code is not incredibly concise and also not incredibly fast, but for
the reasonable cases that I tested `@benchmark` reports well under 1ms.
Since I don't think this iterator would be the critical path of anything
this should be fine.
The only problem I could see is that due to everything using `Tuple`s in
its arguments, the compile time is relatively large. If this becomes a
problem we could change it to using `Vector`s instead, likely trading
some runtime for much better compile time.

Fixes #107
  • Loading branch information
AntonReinhard authored Sep 12, 2024
1 parent c544490 commit 25ce9bd
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
QEDbase = "10e22c08-3ccb-4172-bfcf-7d7aa3d04d93"
QEDcore = "35dc0263-cb5f-4c33-a114-1d7f54ab753e"
QEDprocesses = "46de9c38-1bb3-4547-a1ec-da24d767fdad"
4 changes: 3 additions & 1 deletion src/QEDbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ export AbstractDefinitePolarization, AbstractIndefinitePolarization
export PolarizationX, PolX, PolarizationY, PolY, AllPolarization, AllPol
export AbstractDefiniteSpin, AbstractIndefiniteSpin
export SpinUp, SpinDown, AllSpin
export SyncedSpin, SyncedPolarization
export SyncedSpin, SyncedPolarization, SyncedPol
export spin_pols_iter

# probabilities
export differential_probability, unsafe_differential_probability
Expand Down Expand Up @@ -114,6 +115,7 @@ include("interfaces/phase_space_point.jl")
include("implementations/process/momenta.jl")
include("implementations/process/particles.jl")
include("implementations/process/spin_pols.jl")
include("implementations/process/spin_pol_iterator.jl")

include("implementations/cross_section/diff_probability.jl")
include("implementations/cross_section/diff_cross_section.jl")
Expand Down
132 changes: 132 additions & 0 deletions src/implementations/process/spin_pol_iterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
SpinPolIter
An iterator type to iterate over spin and polarization combinations. Should be used through [`spin_pols_iter`](@ref).
"""
struct SpinPolIter{I,O}
# product iterator doing the actual iterating
product_iter::Base.Iterators.ProductIterator
# lookup table for which indices go where, translating the base iterator to the actual result
indexing_lut::Tuple{NTuple{I,Int},NTuple{O,Int}}
end

"""
all_spin_pols(process::AbstractProcessDefinition)
This function returns an iterator, yielding every fully definite combination of spins and polarizations allowed by the
process' [`spin_pols`](@ref). Each returned element is a Tuple of the incoming and the outgoing spins and polarizations,
in the order of the process' own spins and polarizations.
This works together with the definite spins and polarizations, [`AllSpin`](@ref), [`AllPolarization`](@ref), and the synced versions
[`SyncedPolarization`](@ref) and [`SyncedSpin`](@ref).
```jldoctest
julia> using QEDbase; using QEDcore; using QEDprocesses;
julia> proc = ScatteringProcess((Photon(), Photon(), Photon(), Electron()), (Photon(), Electron()), (SyncedPolarization(1), SyncedPolarization(2), SyncedPolarization(1), SpinUp()), (SyncedPolarization(2), AllSpin()))
generic QED process
incoming: photon (synced polarization 1), photon (synced polarization 2), photon (synced polarization 1), electron (spin up)
outgoing: photon (synced polarization 2), electron (all spins)
julia> for sp_combo in spin_pols_iter(proc) println(sp_combo) end
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin up))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin up))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin up))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin up))
((x-polarized, x-polarized, x-polarized, spin up), (x-polarized, spin down))
((y-polarized, x-polarized, y-polarized, spin up), (x-polarized, spin down))
((x-polarized, y-polarized, x-polarized, spin up), (y-polarized, spin down))
((y-polarized, y-polarized, y-polarized, spin up), (y-polarized, spin down))
julia> length(spin_pols_iter(proc))
8
```
"""
function spin_pols_iter(process::AbstractProcessDefinition)
DEF_SPINS = (SpinUp(), SpinDown())
DEF_POLS = (PolX(), PolY())

in_sp = incoming_spin_pols(process)
I = length(in_sp)
out_sp = outgoing_spin_pols(process)
O = length(out_sp)

# concatenate for now for easier indices, split again later
sps = (in_sp..., out_sp...)

# keep indices of first seen SyncedSpins or SyncedPols
synced_seen = Dict{AbstractSpinOrPolarization,Int}()
index = 0
for sp in sps
index += 1
if !(sp isa SyncedSpin || sp isa SyncedPolarization)
continue
end
if !haskey(synced_seen, sp)
synced_seen[sp] = index
end
end

# keep indices of the synced spins/pols in the iterator (not necessarily the same as synced_seen)
synced_indices = Dict{AbstractSpinOrPolarization,Int}()

iter_tuples = Vector()
lut = Vector{Int}()
index = 0
for sp in sps
index += 1
if sp isa AbstractDefiniteSpin || sp isa AbstractDefinitePolarization
push!(iter_tuples, (sp,))
push!(lut, length(iter_tuples))
elseif sp isa SyncedSpin
# check if it's the first synced
if index == synced_seen[sp]
push!(iter_tuples, DEF_SPINS)
synced_indices[sp] = length(iter_tuples)
end
push!(lut, synced_indices[sp])
elseif sp isa SyncedPolarization
if index == synced_seen[sp]
push!(iter_tuples, DEF_POLS)
synced_indices[sp] = length(iter_tuples)
end
push!(lut, synced_indices[sp])
elseif sp isa AllSpin
push!(iter_tuples, DEF_SPINS)
push!(lut, length(iter_tuples))
elseif sp isa AllPol
push!(iter_tuples, DEF_POLS)
push!(lut, length(iter_tuples))
end
end

return SpinPolIter(
Iterators.product(iter_tuples...),
(tuple(lut[begin:I]...), tuple(lut[(I + 1):end]...)),
)
end

function Base.iterate(iterator::SpinPolIter, state=nothing)
local prod_iter_res
if isnothing(state)
prod_iter_res = iterate(iterator.product_iter)
else
prod_iter_res = iterate(iterator.product_iter, state)
end

if isnothing(prod_iter_res)
return nothing
end
prod_iter_res, state = prod_iter_res

# translate prod_iter_res into actual result
in_t = ((prod_iter_res[i] for i in iterator.indexing_lut[1])...,)
out_t = ((prod_iter_res[i] for i in iterator.indexing_lut[2])...,)

return (in_t, out_t), state
end

function Base.length(iterator::SpinPolIter)
return length(iterator.product_iter)
end
3 changes: 3 additions & 0 deletions src/interfaces/particles/spin_pol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ struct SyncedPolarization{N} <: AbstractIndefinitePolarization
return new{N}()
end
end
const SyncedPol = SyncedPolarization
Base.show(io::IO, ::SyncedPolarization{N}) where {N} = print(io, "synced polarization $N")

"""
SyncedSpin{N::Int} <: AbstractIndefiniteSpin
Expand All @@ -213,3 +215,4 @@ struct SyncedSpin{N} <: AbstractIndefiniteSpin
return new{N}()
end
end
Base.show(io::IO, ::SyncedSpin{N}) where {N} = print(io, "synced spin $N")
41 changes: 41 additions & 0 deletions test/particle_properties.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
using QEDbase
using QEDcore
using StaticArrays
using Random

include("test_implementation/TestImplementation.jl")
using .TestImplementation: TestParticleBoson, TestParticleFermion

# test function to test scalar broadcasting
test_broadcast(x::AbstractParticle) = x
test_broadcast(x::ParticleDirection) = x
Expand Down Expand Up @@ -29,3 +33,40 @@ test_broadcast(x::AbstractSpinOrPolarization) = x
end
end
end

TESTPROCS = (
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(AllPol(), AllSpin()),
(AllPol(), AllSpin()),
),
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(SyncedPol(1), SyncedPol(1), AllSpin()),
(AllPol(), AllSpin()),
),
TestImplementation.TestProcessSP(
(TestParticleBoson(), TestParticleBoson(), TestParticleFermion()),
(TestParticleBoson(), TestParticleFermion()),
(SyncedPol(1), SyncedPol(2), SyncedSpin(2)),
(SyncedPol(2), SyncedSpin(2)),
),
)

@testset "spin_pol iterator ($proc)" for proc in TESTPROCS
@test length(spin_pols_iter(proc)) == multiplicity(proc)

for combinations in spin_pols_iter(proc)
@test length(combinations) == 2
in_comb, out_comb = combinations

@test length(in_comb) == length(incoming_particles(proc))
@test length(out_comb) == length(outgoing_particles(proc))

for sp in Iterators.flatten((in_comb, out_comb))
@test sp isa AbstractDefiniteSpin || sp isa AbstractDefinitePolarization
end
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using QEDbase
using Test
using SafeTestsets

Expand Down

0 comments on commit 25ce9bd

Please sign in to comment.