nrels and ngens to __constant__

This commit is contained in:
2019-12-10 20:56:46 -05:00
parent 2e24bb6bb2
commit 7297bc1cea

View File

@@ -9,7 +9,9 @@
#include "util.h" #include "util.h"
#include "groups.h" #include "groups.h"
__constant__ Rel rels[128]; __constant__ Rel c_rels[128];
__constant__ int c_nrels[1];
__constant__ int c_ngens[1];
struct Row { struct Row {
int rel; int rel;
@@ -38,13 +40,10 @@ std::ostream &operator<<(std::ostream &o, const Row &r) {
// this performs a pass on one relation table row, applying learned data to the coset table. // this performs a pass on one relation table row, applying learned data to the coset table.
struct Solver { struct Solver {
int ngens;
int *cosets; int *cosets;
Solver(int ngens, Solver(thrust::device_vector<int> &cosets)
thrust::device_vector<int> &cosets) : cosets(thrust::raw_pointer_cast(cosets.data())) {
: ngens(ngens),
cosets(thrust::raw_pointer_cast(cosets.data())) {
} }
__device__ __device__
@@ -54,25 +53,25 @@ struct Solver {
} }
while (r.r - r.l > 0) { while (r.r - r.l > 0) {
int gen = rels[r.rel].gens[r.l & 1]; int gen = c_rels[r.rel].gens[r.l & 1];
int next = cosets[r.from * ngens + gen]; int next = cosets[r.from * *c_ngens + gen];
if (next < 0) break; if (next < 0) break;
r.l++; r.l++;
r.from = next; r.from = next;
} }
while (r.r - r.l > 0) { while (r.r - r.l > 0) {
int gen = rels[r.rel].gens[r.r & 1]; int gen = c_rels[r.rel].gens[r.r & 1];
int next = cosets[r.to * ngens + gen]; int next = cosets[r.to * *c_ngens + gen];
if (next < 0) break; if (next < 0) break;
r.r--; r.r--;
r.to = next; r.to = next;
} }
if (r.r - r.l <= 0) { if (r.r - r.l <= 0) {
int gen = rels[r.rel].gens[r.l & 1]; int gen = c_rels[r.rel].gens[r.l & 1];
cosets[r.from * ngens + gen] = r.to; cosets[r.from * *c_ngens + gen] = r.to;
cosets[r.to * ngens + gen] = r.from; cosets[r.to * *c_ngens + gen] = r.from;
return; return;
} }
} }
@@ -102,7 +101,7 @@ struct RowGen {
__device__ __device__
Row operator()(int rel) { Row operator()(int rel) {
return Row(rel, coset, rels[rel].mul * 2); return Row(rel, coset, c_rels[rel].mul * 2);
} }
}; };
@@ -189,7 +188,7 @@ thrust::device_vector<int> solve(
// will break out later // will break out later
while (true) { while (true) {
// create a solver and apply it until nothing is being learned // create a solver and apply it until nothing is being learned
Solver solve(ngens, cosets); Solver solve(cosets);
thrust::for_each( thrust::for_each(
thrust::device, thrust::device,
rows.begin(), rows.end(), rows.begin(), rows.end(),
@@ -221,11 +220,15 @@ int main(int argc, const char* argv[]) {
Coxeter cox; Coxeter cox;
cox = proc_args(argc, argv); cox = proc_args(argc, argv);
std::vector<int> subs = {}; std::vector<int> subs = {};
int nrels = cox.rels.size();
int ngens = cox.ngens;
cudaMemcpyToSymbol(rels, cox.rels.data(), cox.rels.size() * sizeof(Rel)); cudaMemcpyToSymbol(c_ngens, &ngens, sizeof(int));
cudaMemcpyToSymbol(c_nrels, &nrels, sizeof(int));
cudaMemcpyToSymbol(c_rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
auto s = std::chrono::system_clock::now(); auto s = std::chrono::system_clock::now();
thrust::host_vector<int> cosets = solve(cox.ngens, cox.rels.size(), subs); thrust::host_vector<int> cosets = solve(cox.ngens, nrels, subs);
auto e = std::chrono::system_clock::now(); auto e = std::chrono::system_clock::now();
std::chrono::duration<float> diff = e - s; std::chrono::duration<float> diff = e - s;