【模板】线段树

目录

题目

题目传送门: 【P3373】【模板】线段树 2 – 洛谷

如题,已知一个数列,你需要进行下面三种操作: 1. 将某区间每一个数乘上 $x$。 2. 将某区间每一个数加上 $x$。 3. 求出某区间每一个数的和。

AC 代码

#include <cstdio>

const int MAXN = 1e5 + 10;
template<typename T, int MAXN = (int)1e5 + 10>
class SegmentTree {
    struct Node {
        T val, sum, mul;
        bool leaf;
    } tree[MAXN << 2];      // 四倍空间
    int n, ha;
    inline int lc(int const &p) { return p << 1; }
    inline int rc(int const &p) { return p << 1 | 1; }
    inline int mid(int const &l, int const &r) { return 1ll * (l + r) >> 1; }
    inline void push_up(int const &p) {
        if (tree[p].leaf) return;
        tree[p].val = (tree[lc(p)].val + tree[rc(p)].val)
    }
    inline void push_down(int const &p, int const &l, int const &r) {
        if (tree[p].leaf) return;                       // 如果是叶节点就直接 return
        if (tree[p].mul != 1) {
            (tree[lc(p)].val *= tree[p].mul)
            (tree[rc(p)].val *= tree[p].mul)
            (tree[lc(p)].mul *= tree[p].mul)
            (tree[rc(p)].mul *= tree[p].mul)
            (tree[lc(p)].sum *= tree[p].mul)
            (tree[rc(p)].sum *= tree[p].mul)
            tree[p].mul = 1;                            // 清空乘法懒标记
        }
        if (tree[p].sum) {
            (tree[lc(p)].val += tree[p].sum * (mid(l, r) - l + 1))
            (tree[rc(p)].val += tree[p].sum * (r - mid(l, r)))
            (tree[lc(p)].sum += tree[p].sum)
            (tree[rc(p)].sum += tree[p].sum)
            tree[p].sum = 0;                                                // 清空加法懒标记
        }
    }
    void _build(T const *a, int const p, int const l, int const r) {
        if (l == r) {                       // 该节点直接表示一个区间上的元素
            tree[p].val = a[l];
            tree[p].leaf = 1;
            return;
        }
        tree[p].mul = 1, tree[p].sum = 0;   // 在这里初始化而不用构造函数以节省空间
        _build(a, lc(p), l, mid(l, r));     // 递归建树
        _build(a, rc(p), mid(l, r) + 1, r);
        push_up(p);                         // 回溯更新答案
    }
public:
    SegmentTree() {}
    SegmentTree(T const *a, int const &len, int const &mod) {
        n = len, ha = mod;
        _build(a, 1, 1, n);
    }
    void build(T const *a, int const &len, int const &mod) {
        n = len, ha = mod;
        _build(a, 1, 1, n);
    }
    void update_mul(int const cl, int const cr, int const k, int p = 1, int l = 1, int r = -1) {    // cl、cr 为待修改区间,下同
        if (r == -1) r = this->n;
        if (cr < l || cl > r) return;   // 如果当前节点表示的区间和待修改区间无交集则直接 return
        if (cl <= l && r <= cr) {       // 如果当前节点表示的区间包含于待修改区间则直接修改值
            (tree[p].val *= k)
            (tree[p].mul *= k)
            (tree[p].sum *= k)
            return;
        }
        push_down(p, l, r);             // 下放标记
        update_mul(cl, cr, k, lc(p), l, mid(l, r));     // 递归修改
        update_mul(cl, cr, k, rc(p), mid(l, r) + 1, r);
        push_up(p);
    }
    void update_sum(int const cl, int const cr, int const k, int p = 1, int l = 1, int r = -1) {
        if (r == -1) r = this->n;
        if (cr < l || cl > r) return; 
        if (cl <= l && r <= cr) {
            (tree[p].val += k * (r - l + 1))
            (tree[p].sum += k)
            return;
        }
        push_down(p, l, r);
        update_sum(cl, cr, k, lc(p), l, mid(l, r));
        update_sum(cl, cr, k, rc(p), mid(l, r) + 1, r);
        push_up(p);
    }
    T query(int const ql, int const qr, int p = 1, int l = 1, int r = -1) {     // ql、qr 为待查区间
        if (r == -1) r = this->n;
        if (qr < l || ql > r) return 0;     // 如果待查区间和当前节点表示的区间没有交集则直接返回 0
        if (ql <= l && qr >= r)             // 如果前节点表示的区间包含于待查区间,则返回当前节点的值
            return tree[p].val;
        push_down(p, l, r);             // 下放标记
        return (query(ql, qr, lc(p), l, mid(l, r)) + query(ql, qr, rc(p), mid(l, r) + 1, r))
    }
};

int n, m, p;
long long a[MAXN], k;
SegmentTree<long long> segt;
int main(int argc, char const *argv[]) {
    scanf(
    for (register int i = 1; i <= n; i++) scanf(
    segt.build(a, n, p);
    for (int o, x, y;m--;) {
        scanf(
        if (o == 1) {
            scanf(
            segt.update_mul(x, y, k);
        } else 
        if (o == 2) {
            scanf(
            segt.update_sum(x, y, k);
        } else 
        if (o == 3) {
            printf(
        }
    }
    return 0;
}

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注