nrels and ngens to __constant__
This commit is contained in:
@@ -9,7 +9,9 @@
|
||||
#include "util.h"
|
||||
#include "groups.h"
|
||||
|
||||
__constant__ Rel rels[128];
|
||||
__constant__ Rel c_rels[128];
|
||||
__constant__ int c_nrels[1];
|
||||
__constant__ int c_ngens[1];
|
||||
|
||||
struct Row {
|
||||
int rel;
|
||||
@@ -38,13 +40,10 @@ std::ostream &operator<<(std::ostream &o, const Row &r) {
|
||||
|
||||
// this performs a pass on one relation table row, applying learned data to the coset table.
|
||||
struct Solver {
|
||||
int ngens;
|
||||
int *cosets;
|
||||
|
||||
Solver(int ngens,
|
||||
thrust::device_vector<int> &cosets)
|
||||
: ngens(ngens),
|
||||
cosets(thrust::raw_pointer_cast(cosets.data())) {
|
||||
Solver(thrust::device_vector<int> &cosets)
|
||||
: cosets(thrust::raw_pointer_cast(cosets.data())) {
|
||||
}
|
||||
|
||||
__device__
|
||||
@@ -54,25 +53,25 @@ struct Solver {
|
||||
}
|
||||
|
||||
while (r.r - r.l > 0) {
|
||||
int gen = rels[r.rel].gens[r.l & 1];
|
||||
int next = cosets[r.from * ngens + gen];
|
||||
int gen = c_rels[r.rel].gens[r.l & 1];
|
||||
int next = cosets[r.from * *c_ngens + gen];
|
||||
if (next < 0) break;
|
||||
r.l++;
|
||||
r.from = next;
|
||||
}
|
||||
|
||||
while (r.r - r.l > 0) {
|
||||
int gen = rels[r.rel].gens[r.r & 1];
|
||||
int next = cosets[r.to * ngens + gen];
|
||||
int gen = c_rels[r.rel].gens[r.r & 1];
|
||||
int next = cosets[r.to * *c_ngens + gen];
|
||||
if (next < 0) break;
|
||||
r.r--;
|
||||
r.to = next;
|
||||
}
|
||||
|
||||
if (r.r - r.l <= 0) {
|
||||
int gen = rels[r.rel].gens[r.l & 1];
|
||||
cosets[r.from * ngens + gen] = r.to;
|
||||
cosets[r.to * ngens + gen] = r.from;
|
||||
int gen = c_rels[r.rel].gens[r.l & 1];
|
||||
cosets[r.from * *c_ngens + gen] = r.to;
|
||||
cosets[r.to * *c_ngens + gen] = r.from;
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -102,7 +101,7 @@ struct RowGen {
|
||||
|
||||
__device__
|
||||
Row operator()(int rel) {
|
||||
return Row(rel, coset, rels[rel].mul * 2);
|
||||
return Row(rel, coset, c_rels[rel].mul * 2);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -189,7 +188,7 @@ thrust::device_vector<int> solve(
|
||||
// will break out later
|
||||
while (true) {
|
||||
// create a solver and apply it until nothing is being learned
|
||||
Solver solve(ngens, cosets);
|
||||
Solver solve(cosets);
|
||||
thrust::for_each(
|
||||
thrust::device,
|
||||
rows.begin(), rows.end(),
|
||||
@@ -221,11 +220,15 @@ int main(int argc, const char* argv[]) {
|
||||
Coxeter cox;
|
||||
cox = proc_args(argc, argv);
|
||||
std::vector<int> subs = {};
|
||||
int nrels = cox.rels.size();
|
||||
int ngens = cox.ngens;
|
||||
|
||||
cudaMemcpyToSymbol(rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
|
||||
cudaMemcpyToSymbol(c_ngens, &ngens, sizeof(int));
|
||||
cudaMemcpyToSymbol(c_nrels, &nrels, sizeof(int));
|
||||
cudaMemcpyToSymbol(c_rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
|
||||
|
||||
auto s = std::chrono::system_clock::now();
|
||||
thrust::host_vector<int> cosets = solve(cox.ngens, cox.rels.size(), subs);
|
||||
thrust::host_vector<int> cosets = solve(cox.ngens, nrels, subs);
|
||||
auto e = std::chrono::system_clock::now();
|
||||
|
||||
std::chrono::duration<float> diff = e - s;
|
||||
|
||||
Reference in New Issue
Block a user