nrels and ngens to __constant__
This commit is contained in:
@@ -9,7 +9,9 @@
|
|||||||
#include "util.h"
|
#include "util.h"
|
||||||
#include "groups.h"
|
#include "groups.h"
|
||||||
|
|
||||||
__constant__ Rel rels[128];
|
__constant__ Rel c_rels[128];
|
||||||
|
__constant__ int c_nrels[1];
|
||||||
|
__constant__ int c_ngens[1];
|
||||||
|
|
||||||
struct Row {
|
struct Row {
|
||||||
int rel;
|
int rel;
|
||||||
@@ -38,13 +40,10 @@ std::ostream &operator<<(std::ostream &o, const Row &r) {
|
|||||||
|
|
||||||
// this performs a pass on one relation table row, applying learned data to the coset table.
|
// this performs a pass on one relation table row, applying learned data to the coset table.
|
||||||
struct Solver {
|
struct Solver {
|
||||||
int ngens;
|
|
||||||
int *cosets;
|
int *cosets;
|
||||||
|
|
||||||
Solver(int ngens,
|
Solver(thrust::device_vector<int> &cosets)
|
||||||
thrust::device_vector<int> &cosets)
|
: cosets(thrust::raw_pointer_cast(cosets.data())) {
|
||||||
: ngens(ngens),
|
|
||||||
cosets(thrust::raw_pointer_cast(cosets.data())) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__
|
__device__
|
||||||
@@ -54,25 +53,25 @@ struct Solver {
|
|||||||
}
|
}
|
||||||
|
|
||||||
while (r.r - r.l > 0) {
|
while (r.r - r.l > 0) {
|
||||||
int gen = rels[r.rel].gens[r.l & 1];
|
int gen = c_rels[r.rel].gens[r.l & 1];
|
||||||
int next = cosets[r.from * ngens + gen];
|
int next = cosets[r.from * *c_ngens + gen];
|
||||||
if (next < 0) break;
|
if (next < 0) break;
|
||||||
r.l++;
|
r.l++;
|
||||||
r.from = next;
|
r.from = next;
|
||||||
}
|
}
|
||||||
|
|
||||||
while (r.r - r.l > 0) {
|
while (r.r - r.l > 0) {
|
||||||
int gen = rels[r.rel].gens[r.r & 1];
|
int gen = c_rels[r.rel].gens[r.r & 1];
|
||||||
int next = cosets[r.to * ngens + gen];
|
int next = cosets[r.to * *c_ngens + gen];
|
||||||
if (next < 0) break;
|
if (next < 0) break;
|
||||||
r.r--;
|
r.r--;
|
||||||
r.to = next;
|
r.to = next;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (r.r - r.l <= 0) {
|
if (r.r - r.l <= 0) {
|
||||||
int gen = rels[r.rel].gens[r.l & 1];
|
int gen = c_rels[r.rel].gens[r.l & 1];
|
||||||
cosets[r.from * ngens + gen] = r.to;
|
cosets[r.from * *c_ngens + gen] = r.to;
|
||||||
cosets[r.to * ngens + gen] = r.from;
|
cosets[r.to * *c_ngens + gen] = r.from;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,7 +101,7 @@ struct RowGen {
|
|||||||
|
|
||||||
__device__
|
__device__
|
||||||
Row operator()(int rel) {
|
Row operator()(int rel) {
|
||||||
return Row(rel, coset, rels[rel].mul * 2);
|
return Row(rel, coset, c_rels[rel].mul * 2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -189,7 +188,7 @@ thrust::device_vector<int> solve(
|
|||||||
// will break out later
|
// will break out later
|
||||||
while (true) {
|
while (true) {
|
||||||
// create a solver and apply it until nothing is being learned
|
// create a solver and apply it until nothing is being learned
|
||||||
Solver solve(ngens, cosets);
|
Solver solve(cosets);
|
||||||
thrust::for_each(
|
thrust::for_each(
|
||||||
thrust::device,
|
thrust::device,
|
||||||
rows.begin(), rows.end(),
|
rows.begin(), rows.end(),
|
||||||
@@ -221,11 +220,15 @@ int main(int argc, const char* argv[]) {
|
|||||||
Coxeter cox;
|
Coxeter cox;
|
||||||
cox = proc_args(argc, argv);
|
cox = proc_args(argc, argv);
|
||||||
std::vector<int> subs = {};
|
std::vector<int> subs = {};
|
||||||
|
int nrels = cox.rels.size();
|
||||||
|
int ngens = cox.ngens;
|
||||||
|
|
||||||
cudaMemcpyToSymbol(rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
|
cudaMemcpyToSymbol(c_ngens, &ngens, sizeof(int));
|
||||||
|
cudaMemcpyToSymbol(c_nrels, &nrels, sizeof(int));
|
||||||
|
cudaMemcpyToSymbol(c_rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
|
||||||
|
|
||||||
auto s = std::chrono::system_clock::now();
|
auto s = std::chrono::system_clock::now();
|
||||||
thrust::host_vector<int> cosets = solve(cox.ngens, cox.rels.size(), subs);
|
thrust::host_vector<int> cosets = solve(cox.ngens, nrels, subs);
|
||||||
auto e = std::chrono::system_clock::now();
|
auto e = std::chrono::system_clock::now();
|
||||||
|
|
||||||
std::chrono::duration<float> diff = e - s;
|
std::chrono::duration<float> diff = e - s;
|
||||||
|
|||||||
Reference in New Issue
Block a user