Implement vectorization (UNTESTED)

This commit is contained in:
JCRaymond
2019-12-06 21:30:25 -05:00
parent 81f6a77adb
commit 04ecc147fb

View File

@@ -1,5 +1,6 @@
#include <cstdlib>
#include <immintrin.h>
#include "aligned_allocator.cpp"
#include <iostream>
#include <vector>
@@ -162,10 +163,10 @@ Coxeter E8() {
struct RelTable {
std::vector<int> coset_poss;
std::vector<Cos> init_cosets;
std::vector<Cos> start_cosets;
std::vector<Cos> end_cosets;
std::vector<Ind> start_inds;
std::vector<Ind> end_inds;
std::vector<Cos,aligned_allocator<Cos,ALIGN_SIZE>> start_cosets;
std::vector<Cos,aligned_allocator<Cos,ALIGN_SIZE>> end_cosets;
std::vector<Ind,aligned_allocator<Cos,ALIGN_SIZE>> start_inds;
std::vector<Ind,aligned_allocator<Cos,ALIGN_SIZE>> end_inds;
int num_rows;
Gen gen[2];
Ind end_ind;
@@ -270,9 +271,6 @@ int add_coset(const Coxeter &cox,
return -1;
}
//#define __AVX2__
#ifndef __AVX2__
/**
@@ -354,6 +352,224 @@ void learn(const Coxeter &cox, CosetTable &cosets,
#else
// All-lanes constant 1; AND-ing an index vector with it keeps only the lowest bit.
const auto mask = _mm256_set1_epi32(1);

/**
 * Performs eight table lookups at once using AVX2.
 *
 * For each of the eight lanes k this stores
 *   tar[k] = base[coss[k] * step[k] + gens[inds[k] & 1]]
 *
 * inds, coss and tar are accessed with aligned 256-bit loads/stores, so
 * they must point at 32-byte-aligned storage holding eight 32-bit values.
 * Both gathers use a scale of 4 bytes per element.
 * NOTE(review): this assumes Ind, Cos and Gen are 32-bit integers -- confirm
 * against their typedefs.
 *
 * @param inds  eight indices; only their lowest bit is used, to pick gens[0] or gens[1]
 * @param coss  eight row numbers into base
 * @param gens  two-entry array of column offsets
 * @param base  start of the flat table that is read
 * @param step  per-lane row stride (callers in this file broadcast one value)
 * @param tar   receives the eight gathered values
 */
inline void compute_lookups(Ind* inds, Cos* coss, Gen* gens, Cos* base, __m256i step, Cos* tar) {
  // Which of the two generators applies in each lane.
  const __m256i index_vec = _mm256_load_si256((__m256i*)inds);
  const __m256i parity = _mm256_and_si256(index_vec, mask);
  const __m256i gen_vec = _mm256_i32gather_epi32(gens, parity, 4);

  // Flat offset of the wanted entry: row * stride + column.
  const __m256i coset_vec = _mm256_load_si256((__m256i*)coss);
  const __m256i row_start = _mm256_mullo_epi32(coset_vec, step);
  const __m256i offsets = _mm256_add_epi32(row_start, gen_vec);

  // Fetch the eight entries and hand them back.
  const __m256i found = _mm256_i32gather_epi32((int*)base, offsets, 4);
  _mm256_store_si256((__m256i*)tar, found);
}
/**
* learn until it can't
*/
void learn(const Coxeter &cox, CosetTable &cosets,
std::vector<RelTable> &reltables) {
const int nrels = cox.nrels;
const int ngens = cox.ngens;
const auto step = _mm256_set1_epi32(ngens);
while (true) {
bool complete = true;
alignas(32) Gen gens[2];
alignas(32) Cos lookups[8];
alignas(32) Cos init_cosets[8];
alignas(32) Ind start_inds[8];
alignas(32) Ind end_inds[8];
alignas(32) Cos start_cosets[8];
alignas(32) Cos end_cosets[8];
#pragma omp parallel for schedule(static, 1) reduction(&:complete) private(gens, lookups, init_cosets, start_inds, end_inds, start_cosets, end_cosets)
for (unsigned int r = 0; r < nrels; ++r) {
auto &table = reltables[r];
gens[0] = table.gen[0];
gens[1] = table.gen[1];
unsigned int c;
bool redo_cval;
for (c = 0; c < ((table.num_rows>>3)<<3); c+=8) {
redo_cval = false;
for (int c_ = 0; c_ < 8; c_++) {
init_cosets[c_] = table.init_cosets[c+c_];
start_inds[c_] = table.start_inds[c+c_];
end_inds[c_] = table.end_inds[c+c_];
start_cosets[c_] = table.start_cosets[c+c_];
end_cosets[c_] = table.end_cosets[c+c_];
}
bool startdone = false;
bool need_reload = false;
int reload_idx;
int idx, lookup;
while (!startdone) {
startdone = true;
if (need_reload and idx < table.num_rows) {
init_cosets[reload_idx] = table.init_cosets[idx];
start_inds[reload_idx] = table.start_inds[idx];
end_inds[reload_idx] = table.end_inds[idx];
start_cosets[reload_idx] = table.start_cosets[idx];
end_cosets[reload_idx] = table.end_cosets[idx];
}
compute_lookups(start_inds, start_cosets, gens, &(cosets[0]), step, lookups);
for (int c_ = 0; c_ < 8; c_++) {
lookup = lookups[c_];
if (start_inds[c_] < end_inds[c_] and lookup >= 0) {
start_inds[c_]++;
start_cosets[c_] = lookups[c_];
startdone = false;
if (lookup > init_cosets[c_]) {
int idx = table.coset_poss[lookup];
if (idx >= 0) {
table.rem_row(idx);
if ( (idx>>3)<<3 == c ) {
need_reload = true;
reload_idx = (idx & 8);
break;
}
}
}
}
}
}
bool enddone = false;
while (!enddone) {
enddone = true;
if (need_reload and idx < table.num_rows) {
init_cosets[reload_idx] = table.init_cosets[idx];
start_inds[reload_idx] = table.start_inds[idx];
end_inds[reload_idx] = table.end_inds[idx];
start_cosets[reload_idx] = table.start_cosets[idx];
end_cosets[reload_idx] = table.end_cosets[idx];
}
compute_lookups(end_inds, end_cosets, gens, &(cosets[0]), step, lookups);
for (int c_ = 0; c_ < 8; c_++) {
lookup = lookups[c_];
if (start_inds[c_] < end_inds[c_] and lookup >= 0) {
end_inds[c_]--;
end_cosets[c_] = lookups[c_];
enddone = false;
if (lookup > init_cosets[c_]) {
int idx = table.coset_poss[lookup];
if (idx >= 0) {
table.rem_row(idx);
if ( (idx>>3)<<3 == c ) {
need_reload = true;
reload_idx = (idx & 8);
redo_cval = true;
break;
}
}
}
}
}
}
for (int c_ = 0; c_ < 8 and c+c_ < table.num_rows; c_++) {
table.start_inds[c+c_] = start_inds[c_];
table.end_inds[c+c_] = end_inds[c_];
table.start_cosets[c+c_] = start_cosets[c_];
table.end_cosets[c+c_] = end_cosets[c_];
}
for (int c_ = 8; c_ >= 0; c_--) {
if (c+c_ < table.num_rows)
continue;
Ind s_i = start_inds[c_];
if (start_inds[s_i] == end_inds[c_]) {
complete = false;
const Gen gen = gens[s_i&1];
Cos s_c = start_cosets[c_];
Cos e_c = end_cosets[c_];
cosets[s_c*ngens + gen] = e_c;
cosets[e_c*ngens + gen] = s_c;
table.rem_row(c+c_);
redo_cval = true;
}
}
if (redo_cval)
c-=8;
}
for (; c < table.num_rows; c++) {
auto s_i = table.start_inds[c];
auto e_i = table.end_inds[c];
auto s_c = table.start_cosets[c];
auto e_c = table.end_cosets[c];
auto i_c = table.init_cosets[c];
while (s_i < e_i) {
const Cos lookup = cosets[s_c*ngens + gens[s_i&1]];
if (lookup < 0) break;
s_i++;
s_c = lookup;
if (s_c > i_c) {
int idx = table.coset_poss[s_c];
if (idx >= 0)
table.rem_row(idx);
}
}
table.start_inds[c] = s_i;
table.start_cosets[c] = s_c;
while (s_i < e_i) {
const Cos lookup = cosets[e_c*ngens + gens[e_i&1]];
if (lookup < 0) break;
e_i--;
e_c = lookup;
if (e_c > i_c) {
int idx = table.coset_poss[e_c];
if (idx >= 0)
table.rem_row(idx);
}
}
table.end_inds[c] = e_i;
table.end_cosets[c] = e_c;
if (s_i == e_i) {
complete = false;
const Gen gen = gens[s_i&1];
cosets[s_c*ngens + gen] = e_c;
cosets[e_c*ngens + gen] = s_c;
table.rem_row(c);
c--;
}
}
}
if (complete) break;
}
}
#endif
CosetTable solve_tc(const Coxeter &cox, const Gens &subgens) {