/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.hash;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.hash.HashedFeatureMap;
import org.tribuo.hash.Hasher;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * A {@link Trainer} wrapper that hashes the features of the training data before
 * delegating to an inner trainer.
 * <p>
 * The supplied {@link Dataset} is converted with {@link ImmutableDataset#hashFeatureMap}
 * using the configured {@link Hasher}, and the inner trainer is trained on that hashed
 * dataset. Invocation counts are read from and written to the inner trainer directly.
 *
 * @param <T> the output type of the models produced by this trainer
 */
public final class HashingTrainer<T extends Output<T>>
implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(HashingTrainer.class.getName());

    @Config(mandatory=true, description="Trainer to use.")
    private Trainer<T> innerTrainer;

    @Config(mandatory=true, description="Feature hashing function to use.")
    private Hasher hasher;

    /**
     * Empty no-argument constructor. Presumably used by the OLCUT configuration
     * system, which populates the {@code @Config} fields reflectively — TODO confirm.
     */
    private HashingTrainer() {
    }

    /**
     * Creates a trainer that hashes features with {@code hasher} and then trains
     * {@code trainer} on the hashed data.
     *
     * @param trainer the trainer run on the hashed dataset
     * @param hasher the hasher passed to {@link ImmutableDataset#hashFeatureMap}
     * @throws NullPointerException if {@code trainer} or {@code hasher} is null
     */
    public HashingTrainer(Trainer<T> trainer, Hasher hasher) {
        // Fail fast: a null here would otherwise only surface later as an NPE inside train().
        this.innerTrainer = Objects.requireNonNull(trainer, "trainer must not be null");
        this.hasher = Objects.requireNonNull(hasher, "hasher must not be null");
    }

    /**
     * Hashes the dataset's features and trains the inner trainer on the result,
     * passing {@code -1} as the invocation count.
     *
     * @param dataset the training data
     * @param instanceProvenance provenance passed through to the inner trainer
     * @return the model produced by the inner trainer
     * @throws IllegalStateException if the inner trainer's model does not retain a
     *     {@link HashedFeatureMap}
     */
    @Override
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> instanceProvenance) {
        // NOTE(review): -1 is presumably Trainer.INCREMENT_INVOCATION_COUNT, i.e. "let the
        // inner trainer use and advance its own invocation counter" — confirm.
        return this.train(dataset, instanceProvenance, -1);
    }

    /**
     * Hashes the dataset's features and trains the inner trainer on the result.
     *
     * @param dataset the training data; its feature map is hashed with the configured hasher
     * @param instanceProvenance provenance passed through to the inner trainer
     * @param invocationCount invocation count passed through unchanged to the inner trainer
     * @return the model produced by the inner trainer
     * @throws IllegalStateException if the returned model's feature ID map is not a
     *     {@link HashedFeatureMap}
     */
    @Override
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> instanceProvenance, int invocationCount) {
        // Suppliers defer the string concatenation until INFO logging is actually enabled.
        logger.log(Level.INFO, () -> "Before hashing, had " + dataset.getFeatureMap().size() + " features.");
        ImmutableDataset<T> hashedData = ImmutableDataset.hashFeatureMap(dataset, this.hasher);
        logger.log(Level.INFO, () -> "After hashing, had " + hashedData.getFeatureMap().size() + " features.");
        Model<T> model = this.innerTrainer.train(hashedData, instanceProvenance, invocationCount);
        // Hashing only carries through to the model if it kept the hashed feature map.
        // Presumably an inner trainer that builds its own feature map would silently drop
        // the hashing, so such models are rejected rather than returned.
        if (!(model.getFeatureIDMap() instanceof HashedFeatureMap)) {
            throw new IllegalStateException("Trainer " + this.innerTrainer.getClass().getName() + " does not support hashing.");
        }
        return model;
    }

    /**
     * Returns the inner trainer's invocation count.
     *
     * @return the inner trainer's invocation count
     */
    @Override
    public int getInvocationCount() {
        return this.innerTrainer.getInvocationCount();
    }

    /**
     * Sets the inner trainer's invocation count.
     *
     * @param invocationCount the new invocation count for the inner trainer
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount) {
        this.innerTrainer.setInvocationCount(invocationCount);
    }

    /**
     * Builds a new provenance object describing this trainer.
     *
     * @return a fresh {@link TrainerProvenanceImpl} for this trainer
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

