#include <algorithm>
#include <cctype>
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <osl/osl.h>

/*
 * Perforate a statement domain along iterator `i`.
 *
 * For every constraint row, the coefficient of output dimension `i` is
 * multiplied by 2. In the OpenScop matrix layout that coefficient sits in
 * column i + 1, because column 0 is the equality/inequality flag. Doubling
 * the coefficient of `i` in every linear constraint is the substitution
 * i -> 2*i.
 *
 * Every part of a relation union (linked through `next`) is processed.
 * A NULL relation is a no-op. A part for which `i` is not a valid output
 * dimension is left untouched rather than corrupting another column.
 */
void
perforate_domain(osl_relation_p rel, int i) {
    for(osl_relation_p part = rel; part; part = part->next) {
        /* Guard: i < 0 would hit the eq/ineq flag, i too large another column. */
        if(i < 0 || i >= part->nb_output_dims)
            continue;
        for(int r = 0; r < part->nb_rows; r++) {
            auto row = part->m[r];
            osl_int_mul_si(part->precision, &row[i + 1], row[i + 1], 2);
        }
    }
}

int
main(int argc, char *argv[]) {
    if(argc != 5) {
        std::cerr << "usage: <INPUT SCOP> <OUTPUT SCOP> <STATEMENT> <ITERATOR>\n";
        return -1;
    }

    int statement = atoi(argv[3]);
    int iterator = atoi(argv[4]);

    auto *fp = fopen(argv[1], "r");
    if(!fp) {
        std::cerr << "openscop read failed\n";
        return -1;
    }
    auto scop = osl_scop_read(fp);
    fclose(fp);
    /* FIXME: missing: change domain N->N/2 */
    
    /* find statement */
    int i = statement;
    auto s = scop->statement;
    for(; s && i; s = s->next)
        i--;

    if(i || !s) {
        std::cerr << "statement not found\n";
        return -1;
    }

    while(s) {

        perforate_domain(s->domain, iterator);
        auto body_interface = osl_interface_lookup(s->extension->interface, "body");
        if(body_interface) {
            auto f = (osl_body_p)s->extension->data;
            std::string body(osl_strings_sprint(f->expression));
            std::string iter(osl_strings_sprint(f->iterators));

            /* select iterator according to args */
            /* FIXME: iterator can also be no char but string, d'oh! */
            char it = 0;
            i = iterator;
            for(auto cc : iter) {
                if(std::isspace(cc))
                    continue;
                if(!i) {
                    it = cc;
                    break;
                }
                i--;
            }

            if(i || !it) {
                std::cerr << "iterator not found\n";
                return -1;
            }

            /* replace all occurences of "it" with "it*2" */
            std::string s(1, it);
            std::string r(s + "*2");

            size_t pos = body.find(s);
            while(pos != std::string::npos) {
                body.replace(pos, s.size(), r);
                pos = body.find(s, pos + r.size());
            }

            /* use vector to get non const array of char for osl_strings_sread() */
            std::vector<char> char_body(body.begin(), body.end());
            char *c = &char_body[0];
            f->expression = osl_strings_sread(&c);
        }

        /* FIXME: some thing is wrong here, modifying j instead of i, tried to fix, need to verify */
        auto x = s->access;
        while(x) {
            if(x->elt->nb_rows >= 2) { /* >= 2 is array access */
                int col = x->elt->nb_columns - x->elt->nb_parameters - 3;
                for(int i = 0; i < x->elt->nb_rows; i++) {
                    osl_int_mul_si(x->elt->precision, &(x->elt->m[i][col]), x->elt->m[i][col], 2);
                }
                osl_relation_dump(stdout, x->elt);
            }
            x = x->next;
        }
        break; /* we only perforate one statement */
        s = s->next;
    }

    /* save output */
    fp = fopen(argv[2], "w");
    if(!fp) {
        std::cerr << "openscop write failed\n";
        return -1;
    }
    osl_scop_print(fp, scop);
    fclose(fp);

    osl_scop_free(scop);
    return 0;
}