Move nrels and ngens to __constant__ memory

This commit is contained in:
2019-12-10 20:56:46 -05:00
parent 2e24bb6bb2
commit 7297bc1cea

View File

@@ -9,7 +9,9 @@
#include "util.h"
#include "groups.h"
__constant__ Rel rels[128];
__constant__ Rel c_rels[128];
__constant__ int c_nrels[1];
__constant__ int c_ngens[1];
struct Row {
int rel;
@@ -38,13 +40,10 @@ std::ostream &operator<<(std::ostream &o, const Row &r) {
// this performs a pass on one relation table row, applying learned data to the coset table.
struct Solver {
int ngens;
int *cosets;
Solver(int ngens,
thrust::device_vector<int> &cosets)
: ngens(ngens),
cosets(thrust::raw_pointer_cast(cosets.data())) {
Solver(thrust::device_vector<int> &cosets)
: cosets(thrust::raw_pointer_cast(cosets.data())) {
}
__device__
@@ -54,25 +53,25 @@ struct Solver {
}
while (r.r - r.l > 0) {
int gen = rels[r.rel].gens[r.l & 1];
int next = cosets[r.from * ngens + gen];
int gen = c_rels[r.rel].gens[r.l & 1];
int next = cosets[r.from * *c_ngens + gen];
if (next < 0) break;
r.l++;
r.from = next;
}
while (r.r - r.l > 0) {
int gen = rels[r.rel].gens[r.r & 1];
int next = cosets[r.to * ngens + gen];
int gen = c_rels[r.rel].gens[r.r & 1];
int next = cosets[r.to * *c_ngens + gen];
if (next < 0) break;
r.r--;
r.to = next;
}
if (r.r - r.l <= 0) {
int gen = rels[r.rel].gens[r.l & 1];
cosets[r.from * ngens + gen] = r.to;
cosets[r.to * ngens + gen] = r.from;
int gen = c_rels[r.rel].gens[r.l & 1];
cosets[r.from * *c_ngens + gen] = r.to;
cosets[r.to * *c_ngens + gen] = r.from;
return;
}
}
@@ -102,7 +101,7 @@ struct RowGen {
__device__
Row operator()(int rel) {
return Row(rel, coset, rels[rel].mul * 2);
return Row(rel, coset, c_rels[rel].mul * 2);
}
};
@@ -189,7 +188,7 @@ thrust::device_vector<int> solve(
// will break out later
while (true) {
// create a solver and apply it until nothing is being learned
Solver solve(ngens, cosets);
Solver solve(cosets);
thrust::for_each(
thrust::device,
rows.begin(), rows.end(),
@@ -221,11 +220,15 @@ int main(int argc, const char* argv[]) {
Coxeter cox;
cox = proc_args(argc, argv);
std::vector<int> subs = {};
int nrels = cox.rels.size();
int ngens = cox.ngens;
cudaMemcpyToSymbol(rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
cudaMemcpyToSymbol(c_ngens, &ngens, sizeof(int));
cudaMemcpyToSymbol(c_nrels, &nrels, sizeof(int));
cudaMemcpyToSymbol(c_rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
auto s = std::chrono::system_clock::now();
thrust::host_vector<int> cosets = solve(cox.ngens, cox.rels.size(), subs);
thrust::host_vector<int> cosets = solve(cox.ngens, nrels, subs);
auto e = std::chrono::system_clock::now();
std::chrono::duration<float> diff = e - s;