add optimizations
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
main : main.cu util.h
|
||||
nvcc -o main -std=c++11 main.cu
|
||||
nvcc -o main -std=c++11 -O3 main.cu
|
||||
|
||||
clean :
|
||||
rm main
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "util.h"
|
||||
#include "groups.h"
|
||||
|
||||
__constant__ Rel rels[128];
|
||||
|
||||
struct Row {
|
||||
int rel;
|
||||
|
||||
@@ -38,14 +40,11 @@ std::ostream &operator<<(std::ostream &o, const Row &r) {
|
||||
struct Solver {
|
||||
int ngens;
|
||||
int *cosets;
|
||||
Rel *rels;
|
||||
|
||||
Solver(int ngens,
|
||||
thrust::device_vector<int> &cosets,
|
||||
thrust::device_vector<Rel> &rels)
|
||||
thrust::device_vector<int> &cosets)
|
||||
: ngens(ngens),
|
||||
cosets(thrust::raw_pointer_cast(cosets.data())),
|
||||
rels(thrust::raw_pointer_cast(rels.data())) {
|
||||
cosets(thrust::raw_pointer_cast(cosets.data())) {
|
||||
}
|
||||
|
||||
__device__
|
||||
@@ -95,13 +94,11 @@ struct CosetInitializer {
|
||||
|
||||
// this creates rows for cosets by index of each relation table
|
||||
struct RowGen {
|
||||
Rel *rels;
|
||||
|
||||
int coset;
|
||||
|
||||
RowGen(int coset, thrust::device_vector<Rel> &rels)
|
||||
: coset(coset),
|
||||
rels(thrust::raw_pointer_cast(rels.data())) {}
|
||||
RowGen(int coset)
|
||||
: coset(coset) {
|
||||
}
|
||||
|
||||
__device__
|
||||
Row operator()(int rel) {
|
||||
@@ -130,7 +127,8 @@ bool add_coset(
|
||||
int ngens,
|
||||
int *coset,
|
||||
int *hint,
|
||||
thrust::device_vector<int> &cosets) {
|
||||
thrust::device_vector<int> &dcosets) {
|
||||
thrust::host_vector<int> cosets = dcosets;
|
||||
*coset = cosets.size() / ngens;
|
||||
|
||||
// todo: this part especially.
|
||||
@@ -142,10 +140,10 @@ bool add_coset(
|
||||
int from = *hint / ngens;
|
||||
int gen = *hint % ngens;
|
||||
|
||||
add_row(ngens, cosets);
|
||||
add_row(ngens, dcosets);
|
||||
|
||||
cosets[*hint] = *coset;
|
||||
cosets[*coset * ngens + gen] = from;
|
||||
dcosets[*hint] = *coset;
|
||||
dcosets[*coset * ngens + gen] = from;
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -153,23 +151,23 @@ bool add_coset(
|
||||
// add a row for each relation table for some coset
|
||||
void gen_rows(
|
||||
int coset,
|
||||
thrust::device_vector<Rel> &rels,
|
||||
int nrels,
|
||||
thrust::device_vector<Row> &rows) {
|
||||
rows.resize(rows.size() + rels.size());
|
||||
rows.resize(rows.size() + nrels);
|
||||
|
||||
thrust::counting_iterator<int> counter(0);
|
||||
thrust::transform(
|
||||
thrust::device,
|
||||
counter, counter + rels.size(),
|
||||
rows.end() - rels.size(),
|
||||
RowGen(coset, rels));
|
||||
counter, counter + nrels,
|
||||
rows.end() - nrels,
|
||||
RowGen(coset));
|
||||
}
|
||||
|
||||
// do everything. data is implicitly passed to the device via device_vector.
|
||||
thrust::device_vector<int> solve(
|
||||
int ngens,
|
||||
thrust::device_vector<int> subs,
|
||||
thrust::device_vector<Rel> rels) {
|
||||
int nrels,
|
||||
thrust::device_vector<int> subs) {
|
||||
|
||||
thrust::device_vector<int> cosets;
|
||||
thrust::device_vector<Row> rows;
|
||||
@@ -182,7 +180,7 @@ thrust::device_vector<int> solve(
|
||||
CosetInitializer(cosets));
|
||||
|
||||
// generate initial relation table rows for coset 0
|
||||
gen_rows(0, rels, rows);
|
||||
gen_rows(0, nrels, rows);
|
||||
|
||||
// these keep track of what progress has been made
|
||||
int coset = 0;
|
||||
@@ -191,7 +189,7 @@ thrust::device_vector<int> solve(
|
||||
// will break out later
|
||||
while (true) {
|
||||
// create a solver and apply it until nothing is being learned
|
||||
Solver solve(ngens, cosets, rels);
|
||||
Solver solve(ngens, cosets);
|
||||
thrust::for_each(
|
||||
thrust::device,
|
||||
rows.begin(), rows.end(),
|
||||
@@ -205,7 +203,7 @@ thrust::device_vector<int> solve(
|
||||
if (done) break;
|
||||
|
||||
// generate relation table rows for new coset
|
||||
gen_rows(coset, rels, rows);
|
||||
gen_rows(coset, nrels, rows);
|
||||
|
||||
// move completed rows to the end of the list and remove.
|
||||
auto cut = thrust::partition(
|
||||
@@ -224,8 +222,10 @@ int main(int argc, const char* argv[]) {
|
||||
cox = proc_args(argc, argv);
|
||||
std::vector<int> subs = {};
|
||||
|
||||
cudaMemcpyToSymbol(rels, cox.rels.data(), cox.rels.size() * sizeof(Rel));
|
||||
|
||||
auto s = std::chrono::system_clock::now();
|
||||
thrust::host_vector<int> cosets = solve(cox.ngens, subs, cox.rels);
|
||||
thrust::host_vector<int> cosets = solve(cox.ngens, cox.rels.size(), subs);
|
||||
auto e = std::chrono::system_clock::now();
|
||||
|
||||
std::chrono::duration<float> diff = e - s;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
main : main.cu util.h
|
||||
nvcc -o main -std=c++11 main.cu
|
||||
nvcc -o main -std=c++11 -O3 main.cu
|
||||
|
||||
clean :
|
||||
rm main
|
||||
|
||||
Reference in New Issue
Block a user