Distributed Deep Learning with Keras & Spark

Overview

Elephas: Distributed Deep Learning with Keras & Spark

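Elephas is an extension of Keras that lets you run distributed deep learning models at scale with Spark. Elephas currently supports a number of applications, including data-parallel training of deep learning models, distributed hyper-parameter optimization, and distributed training of ensemble models.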

Schematically, elephas works as follows.

[Figure: schematic overview of Elephas]

Introduction

Elephas brings deep learning with Keras to Spark. Elephas intends to keep the simplicity and high usability of Keras, thereby allowing for fast prototyping of distributed models, which can be run on massive data sets. For an introductory example, see the following iPython notebook.

ἐλέφας is Greek for ivory, and Elephas is an accompanying project to κέρας (Keras), meaning horn. If this seems a weird thing to mention, like a bad dream, you can confirm that it actually is at the Keras documentation. Elephas also means elephant, as in stuffed yellow elephant.

Elephas implements a class of data-parallel algorithms on top of Keras, using Spark's RDDs and data frames. Keras models are initialized on the driver, then serialized and shipped to workers, along with data and broadcast model parameters. Spark workers deserialize the model, train on their chunk of data, and send their gradients back to the driver. The "master" model on the driver is updated by an optimizer, which takes gradients either synchronously or asynchronously.
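The data-parallel scheme described above can be illustrated without Spark or Keras. The following is a minimal sketch in plain Python, assuming a one-parameter linear model and a synchronous update in which the driver averages the weight changes reported by the workers. The names worker_delta and train_synchronously are hypothetical and are not part of Elephas.

```python
def worker_delta(weight, chunk, lr=0.1):
    """Train a local copy of the model on one data chunk and return the weight change."""
    local = weight
    for x, y in chunk:
        grad = 2.0 * (local * x - y) * x   # gradient of the squared error (w*x - y)^2
        local -= lr * grad
    return local - weight


def train_synchronously(data, num_workers=2, rounds=50):
    """Driver loop: share the weight, collect one delta per worker, average, apply."""
    chunks = [data[i::num_workers] for i in range(num_workers)]
    weight = 0.0
    for _ in range(rounds):
        # On a cluster, each call below would run on a different worker.
        deltas = [worker_delta(weight, chunk) for chunk in chunks]
        weight += sum(deltas) / len(deltas)
    return weight


# Toy data generated from y = 3x, so the learned weight should approach 3.
data = [(x / 10.0, 3.0 * x / 10.0) for x in range(1, 11)]
learned_weight = train_synchronously(data)
```

An asynchronous variant would apply each delta as soon as it arrives instead of waiting for all workers to finish a round.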

Getting started

Just install elephas from PyPI with the command below; Spark will be installed through pyspark for you.

pip install elephas

That's it, you should now be able to run Elephas examples.

Basic Spark integration

After installing Elephas, you can train a model as follows. First, create a local pyspark context

from pyspark import SparkContext, SparkConf
conf = SparkConf().setAppName('Elephas_App').setMaster('local[8]')
sc = SparkContext(conf=conf)

Next, you define and compile a Keras model

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.optimizers import SGD
model = Sequential()
model.add(Dense(128, input_dim=784))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(128))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=SGD())

and create an RDD from numpy arrays (or however you want to create an RDD)

from elephas.utils.rdd_utils import to_simple_rdd
rdd = to_simple_rdd(sc, x_train, y_train)
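To make the shape of the data clearer, here is a small sketch in plain Python. It assumes that each element of the resulting RDD is a (features, label) pair, which is what the "simple" in to_simple_rdd suggests. The helpers to_pairs and split_into_partitions are hypothetical stand-ins; in a real job Spark does the partitioning itself.

```python
def to_pairs(features, labels):
    """Pair each feature row with its label (the assumed shape of one RDD element)."""
    if len(features) != len(labels):
        raise ValueError("features and labels must have the same length")
    return list(zip(features, labels))


def split_into_partitions(pairs, num_partitions):
    """Round-robin split, a simple stand-in for the partitioning Spark performs itself."""
    return [pairs[i::num_partitions] for i in range(num_partitions)]


features = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.2, 0.8]]
labels = [[1, 0], [0, 1], [1, 0], [0, 1]]
partitions = split_into_partitions(to_pairs(features, labels), num_partitions=2)
```

With a real RDD, calling rdd.take(1) is a quick way to inspect one element before training.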

The basic model in Elephas is the SparkModel. You initialize a SparkModel by passing in a compiled Keras model, an update frequency and a parallelization mode. After that you can simply fit the model on your RDD. Elephas fit has the same options as a Keras model, so you can pass epochs, batch_size etc. as you're used to from tensorflow.keras.

from elephas.spark_model import SparkModel

spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')
spark_model.fit(rdd, epochs=20, batch_size=32, verbose=0, validation_split=0.1)
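The frequency argument affects how much traffic flows between workers and driver. The arithmetic below is a rough sketch, assuming that frequency='epoch' means one update per worker per epoch and frequency='batch' means one update per mini-batch, as the names suggest. The function updates_per_worker and the sample count are hypothetical.

```python
import math


def updates_per_worker(frequency, samples_per_worker, epochs, batch_size):
    """Number of times one worker reports back, under the assumption stated above."""
    if frequency == 'epoch':
        return epochs
    if frequency == 'batch':
        return epochs * math.ceil(samples_per_worker / batch_size)
    raise ValueError("frequency must be 'epoch' or 'batch'")


# Hypothetical setup: 7,500 training samples on each worker.
per_epoch = updates_per_worker('epoch', samples_per_worker=7500, epochs=20, batch_size=32)
per_batch = updates_per_worker('batch', samples_per_worker=7500, epochs=20, batch_size=32)
print(per_epoch, per_batch)
```

Under these assumptions, batch-level updates keep the master model fresher at the cost of far more communication.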

Your script can now be run using spark-submit

spark-submit --driver-memory 1G ./your_script.py

Increasing the driver memory even further may be necessary, as the set of parameters in a network may be very large and collecting them on the driver eats up a lot of resources. See the examples folder for a few working examples.
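For a rough feeling of how large a set of parameters is, the following sketch counts the weights and biases of the fully connected model defined above. It assumes 32-bit floats; the helper dense_param_count is hypothetical and only covers Dense layers.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases of a stack of fully connected layers."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


layer_sizes = [784, 128, 128, 10]   # input, two hidden layers, output
num_params = dense_param_count(layer_sizes)
bytes_per_copy = num_params * 4     # assuming 32-bit floats
print(num_params, bytes_per_copy)
```

This small network needs well under one megabyte per copy, but the driver may hold one copy per worker at a time, and larger networks grow quickly.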

Distributed Inference / Evaluation

The SparkModel can also be used for distributed inference (prediction) and evaluation. Similar to the fit method, the predict and evaluate methods conform to the Keras Model API.

from elephas.spark_model import SparkModel

# create/train the model, similar to the previous section (Basic Spark Integration)
model = ...
spark_model = SparkModel(model, ...)
spark_model.fit(...)

x_test, y_test = ... # load test data

predictions = spark_model.predict(x_test) # perform inference
evaluation = spark_model.evaluate(x_test, y_test) # perform evaluation/scoring

The paradigm is identical to the data parallelism in training, as the model is serialized and shipped to the workers and used to evaluate a chunk of the testing data. The predict method will take either a numpy array or an RDD.
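The idea of evaluating chunks of test data independently can be sketched in plain Python. This is only an illustration of the paradigm, not of how Elephas implements it; predict_in_chunks and toy_model are hypothetical names.

```python
def predict_in_chunks(predict_fn, samples, num_chunks):
    """Split samples into contiguous chunks, predict each chunk, keep the original order."""
    if not samples:
        return []
    chunk_size = -(-len(samples) // num_chunks)   # ceiling division
    chunks = [samples[i:i + chunk_size] for i in range(0, len(samples), chunk_size)]
    predictions = []
    for chunk in chunks:   # on a cluster, each chunk would be handled by a worker
        predictions.extend(predict_fn(chunk))
    return predictions


def toy_model(batch):
    """Hypothetical stand-in for a trained model: predicts the sum of each feature row."""
    return [sum(row) for row in batch]


x_test = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
chunked = predict_in_chunks(toy_model, x_test, num_chunks=2)
```

Because every sample is predicted independently, the chunked result matches what the model returns for the whole data set at once.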

Spark MLlib integration

Following up on the last example, to use Spark's MLlib library with Elephas, you create an RDD of LabeledPoints for supervised training as follows

from elephas.utils.rdd_utils import to_labeled_point
lp_rdd = to_labeled_point(sc, x_train, y_train, categorical=True)
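A LabeledPoint carries a single numeric label, so one-hot encoded labels have to be reduced to a class index first. The sketch below assumes that this is what categorical=True is for; the helper one_hot_to_index is hypothetical.

```python
def one_hot_to_index(one_hot_row):
    """Return the position of the largest entry, i.e. the class index of a one-hot label."""
    return max(range(len(one_hot_row)), key=lambda i: one_hot_row[i])


y_train_one_hot = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
class_indices = [one_hot_to_index(row) for row in y_train_one_hot]
```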

Training a given LabeledPoint-RDD is very similar to what we've seen already

from elephas.spark_model import SparkMLlibModel
spark_model = SparkMLlibModel(model, frequency='batch', mode='hogwild')
spark_model.train(lp_rdd, epochs=20, batch_size=32, verbose=0, validation_split=0.1, 
                  categorical=True, nb_classes=nb_classes)

Spark ML integration

To train a model with a SparkML estimator on a data frame, use the following syntax.

from elephas.ml.adapter import to_data_frame
from elephas.ml_model import ElephasEstimator

df = to_data_frame(sc, x_train, y_train, categorical=True)
test_df = to_data_frame(sc, x_test, y_test, categorical=True)

estimator = ElephasEstimator(model, epochs=epochs, batch_size=batch_size, frequency='batch', mode='asynchronous',
                             categorical=True, nb_classes=nb_classes)
fitted_model = estimator.fit(df)

Fitting an estimator results in a SparkML transformer, which we can use for predictions and other evaluations by calling the transform method on it.

prediction = fitted_model.transform(test_df)
pnl = prediction.select("label", "prediction")
pnl.show(100)

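from pyspark.mllib.evaluation import MulticlassMetrics

# MulticlassMetrics expects (prediction, label) pairs
prediction_and_label = pnl.rdd.map(lambda row: (row.prediction, row.label))
metrics = MulticlassMetrics(prediction_and_label)
# Weighted averages over all classes; per-class values need metrics.precision(label)
print(metrics.weightedPrecision)
print(metrics.weightedRecall)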

If the model uses a custom activation function, layer, or loss function, these need to be supplied using the set_custom_objects method:

from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Sequential

def custom_activation(x):
    ...
class CustomLayer(Layer):
    ...
model = Sequential()
model.add(CustomLayer(...))

estimator = ElephasEstimator(model, epochs=epochs, batch_size=batch_size)
estimator.set_custom_objects({'custom_activation': custom_activation, 'CustomLayer': CustomLayer})

Distributed hyper-parameter optimization

Hyper-parameter optimization with elephas is based on hyperas, a convenience wrapper for hyperopt and keras. Each Spark worker executes a number of trials, the results are collected, and the best model is returned. As the distributed mode in hyperopt (using MongoDB) is somewhat difficult to configure and error-prone at the time of writing, we chose to implement parallelization ourselves. Right now, the only available optimization algorithm is random search.
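The worker and driver roles in such a random search can be sketched in plain Python. This is an illustration under simplifying assumptions, not the Elephas implementation: the workers run in a loop instead of in parallel, and the search space and toy_objective are hypothetical.

```python
import random


def run_trials(objective, sample_params, num_trials, seed):
    """One worker: evaluate num_trials random configurations, return (loss, params) pairs."""
    rng = random.Random(seed)
    results = []
    for _ in range(num_trials):
        params = sample_params(rng)
        results.append((objective(params), params))
    return results


def distributed_random_search(objective, sample_params, num_workers, max_evals):
    """Driver: collect the trials of all workers, return the pair with the lowest loss."""
    all_results = []
    for worker_id in range(num_workers):   # on a cluster these would run in parallel
        all_results.extend(run_trials(objective, sample_params, max_evals, seed=worker_id))
    return min(all_results, key=lambda result: result[0])


def sample_params(rng):
    """Hypothetical search space: a dropout rate and a layer width."""
    return {"dropout": rng.uniform(0, 1), "units": rng.choice([256, 512, 1024])}


def toy_objective(params):
    """Hypothetical loss that is smallest for dropout 0.3 and 512 units."""
    return (params["dropout"] - 0.3) ** 2 + abs(params["units"] - 512) / 1024


best_loss, best_params = distributed_random_search(
    toy_objective, sample_params, num_workers=4, max_evals=5)
```

With 4 workers and max_evals=5, a total of 20 configurations is evaluated, and each worker draws its samples independently.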

The first part of this example is taken more or less directly from the hyperas documentation. We define data and model as functions; hyper-parameter ranges are defined through double braces. See the hyperas documentation for more on how this works.

from hyperopt import STATUS_OK
from hyperas.distributions import choice, uniform

def data():
    from tensorflow.keras.datasets import mnist
    from tensorflow.keras.utils import to_categorical
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(60000, 784)
    x_test = x_test.reshape(10000, 784)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    nb_classes = 10
    y_train = to_categorical(y_train, nb_classes)
    y_test = to_categorical(y_test, nb_classes)
    return x_train, y_train, x_test, y_test


def model(x_train, y_train, x_test, y_test):
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Dropout, Activation
    from tensorflow.keras.optimizers import RMSprop

    model = Sequential()
    model.add(Dense(512, input_shape=(784,)))
    model.add(Activation('relu'))
    model.add(Dropout({{uniform(0, 1)}}))
    model.add(Dense({{choice([256, 512, 1024])}}))
    model.add(Activation('relu'))
    model.add(Dropout({{uniform(0, 1)}}))
    model.add(Dense(10))
    model.add(Activation('softmax'))

    rms = RMSprop()
    # Track accuracy so that evaluate() returns both loss and accuracy
    model.compile(loss='categorical_crossentropy', optimizer=rms, metrics=['accuracy'])

    model.fit(x_train, y_train,
              batch_size={{choice([64, 128])}},
              epochs=1,
              verbose=2,
              validation_data=(x_test, y_test))
    score, acc = model.evaluate(x_test, y_test, verbose=0)
    print('Test accuracy:', acc)
    return {'loss': -acc, 'status': STATUS_OK, 'model': model.to_yaml()}

Once the basic setup is defined, running the minimization is done in just a few lines of code:

from elephas.hyperparam import HyperParamModel
from pyspark import SparkContext, SparkConf

# Create Spark context
conf = SparkConf().setAppName('Elephas_Hyperparameter_Optimization').setMaster('local[8]')
sc = SparkContext(conf=conf)

# Define hyper-parameter model and run optimization
hyperparam_model = HyperParamModel(sc)
hyperparam_model.minimize(model=model, data=data, max_evals=5)

Distributed training of ensemble models

Building on the last section, it is possible to train ensemble models with elephas by running hyper-parameter optimization on large search spaces and defining a voting classifier on the top-n performing models. With data and model defined as above, this is as simple as running

result = hyperparam_model.best_ensemble(nb_ensemble_models=10, model=model, data=data, max_evals=5)

In this example an ensemble of 10 models is built, based on optimization of at most 5 runs on each of the Spark workers.
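The idea of a voting classifier over the top-n models can be sketched as follows. The source does not state which voting rule is used, so simple majority voting is assumed here; top_n, majority_vote, and the toy models are hypothetical.

```python
from collections import Counter


def top_n(results, n):
    """Keep the n (loss, model) pairs with the lowest loss."""
    return sorted(results, key=lambda result: result[0])[:n]


def majority_vote(models, sample):
    """Let every model predict a class for the sample and return the most common answer."""
    votes = Counter(model(sample) for model in models)
    return votes.most_common(1)[0][0]


# Hypothetical (loss, model) pairs; each "model" is a function returning a class label.
results = [
    (0.10, lambda x: int(x > 0.5)),
    (0.40, lambda x: 0),
    (0.20, lambda x: int(x > 0.4)),
    (0.15, lambda x: int(x > 0.6)),
]
ensemble = [model for _, model in top_n(results, n=3)]
prediction = majority_vote(ensemble, 0.55)
```

Using an odd number of ensemble members avoids ties in a two-class problem.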

Discussion

Premature parallelization may not be the root of all evil, but it is not always the best idea either. Keep in mind that more workers mean less data per worker, and parallelizing a model is no substitute for actual learning. So, if your data fits comfortably into memory and you're happy with the training speed of the model, consider just using Keras.

One exception to this rule may be that you're already working within the Spark ecosystem and want to leverage what's there. The above SparkML example shows how to use evaluation modules from Spark, and maybe you wish to further process the outcome of an elephas model down the road. In this case, we recommend using elephas as a simple wrapper by setting num_workers=1.

Note that right now elephas restricts itself to data-parallel algorithms for two reasons. First, Spark simply makes it very easy to distribute data. Second, neither Spark nor Theano makes it particularly easy to split up the actual model into parts, thus making model-parallelism practically impossible to realize.

Having said all that, we hope you learn to appreciate elephas as a playground for data-parallel deep-learning algorithms that is pretty easy to set up and use.

Maintainers / Contributions

This great project was started by Max Pumperla and is currently maintained by Daniel Cahall (https://github.com/danielenricocahall). If you have any questions, please feel free to open an issue or email the maintainer. If you want to contribute, feel free to submit a PR or start a conversation about how we can go about implementing something.

Comments
  • Elephas slice_X error

    Elephas encounters this error with the latest Theano. Can you help me fix this?

    Traceback (most recent call last):
      File "/home/zeus/workspace/./KerasOnSparkElephas.py", line 28, in <module>
        from elephas.spark_model import SparkModel
      File "/home/zeus/anaconda2/lib/python2.7/site-packages/elephas/spark_model.py", line 18, in <module>
        from keras.models import model_from_yaml, slice_X
    ImportError: cannot import name slice_X
    
    opened by ducminhnguyen 24
  • Massive dataset + data_generator

    Hello guys,

    Before I begin, I want to thank you for this amazing tool. It's truly awesome and enables true distributed deep learning out of the box. Now for some questions.

    1. From what I've seen I need to load the dataset in memory before I send it to elephas for distributed processing. When the dataset is massive, as in multiple times my ram, how can I use an hdf5 so that each worker can load parts of the dataset from disk and do the processing? Please let me know.
    2. Is there any support for keras "data_generator" functionality? That enables real time data augmentation?

    Thank you for your support.

    opened by AntreasAntoniou 20
  • Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

    Could you please help me? I am giving my model and the error that I am seeing here.

    model.compile(loss=lossFunc, optimizer=gradDiscent, metrics=['accuracy']);
    ##############################START: DISTRIBUTED MODEL################
    from pyspark import SparkContext, SparkConf
    #Create spark context
    conf = SparkConf().setAppName('NSL-KDD-DISTRIBUTED').setMaster('local[8]');
    sc = SparkContext(conf=conf);
    
    from elephas.utils.rdd_utils import to_simple_rdd
    #Build RDD (Resilient Distributed Dataset) from numpy features and labels
    rdd = to_simple_rdd(sc, trainX, trainY);
    
    from elephas.spark_model import SparkModel
    from elephas import optimizers as elephas_optimizers
    #Initialize SparkModel from Keras model and Spark Context
    elphOptimizer = elephas_optimizers.Adagrad();
    sparkModel = SparkModel(sc, model, optimizer=elphOptimizer, frequency='epoch', model='asynchronous', num_workers=1);
    #Train Spark Model
    sparkModel.train(rdd, nb_epoch=epochs, batch_size=batchSize, verbose=2);
    
    #Evaluate Spark Model
    score = sparkModel.master_network.evaluate(testX, testY, verbose=2);
    print(score);
    

    ####################ERROR########################
    Traceback (most recent call last):
      File "C:\PythonWorks\mine\dm-dist.py", line 230, in <module>
        sparkModel.train(rdd, nb_epoch=epochs, batch_size=batchSize, verbose=2);
      File "C:\Miniconda3\lib\site-packages\elephas\spark_model.py", line 194, in train
        self._train(rdd, nb_epoch, batch_size, verbose, validation_split, master_url)
      File "C:\Miniconda3\lib\site-packages\elephas\spark_model.py", line 205, in _train
        self.start_server()
      File "C:\Miniconda3\lib\site-packages\elephas\spark_model.py", line 125, in start_server
        self.server.start()
      File "C:\Miniconda3\lib\multiprocessing\process.py", line 105, in start
        self._popen = self._Popen(self)
      File "C:\Miniconda3\lib\multiprocessing\context.py", line 223, in _Popen
        return _default_context.get_context().Process._Popen(process_obj)
      File "C:\Miniconda3\lib\multiprocessing\context.py", line 322, in _Popen
        return Popen(process_obj)
      File "C:\Miniconda3\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
        reduction.dump(process_obj, to_child)
      File "C:\Miniconda3\lib\multiprocessing\reduction.py", line 60, in dump
        ForkingPickler(file, protocol).dump(obj)
      File "C:\Miniconda3\lib\site-packages\pyspark\context.py", line 306, in __getnewargs__
        "It appears that you are attempting to reference SparkContext from a broadcast "
    Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

    opened by mohaimenz 18
  • [BUG] ElephasTransformer:get_config has error

    Hi @danielenricocahall I think there's a bug in this line https://github.com/maxpumperla/elephas/blob/master/elephas/ml_model.py#L171

    getattr(self, 'weights', []) most likely returns a Broadcast object, so it needs .value to extract the actual content.

    Also, each weight is already an np.ndarray, so I think the .numpy() call is weird?

    This is what I ended up using, but I'm not sure if it's a perfect fix: 'weights': [weight.tolist() for weight in self.weights.value],

    A follow-up q: do we need to make weight broadcasted before initializing ElephasTransformer?

    opened by OscarDPan 16
  • Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.

    Hi, I'm trying to use elephas for my deep learning models on spark but so far I couldn't even get anything to work on 3 different machines and on multiple notebooks.

    • "ml_pipeline_otto.py" crashes on the load_data_frame function, more specifically on return sqlContext.createDataFrame(data, ['features', 'category']) with the error : Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.

    • "mnist_mlp_spark.py" crashes on the spark_model.fit method with the error : TypeError: can't pickle _thread.RLock objects

    • "My Own Pipeline" crashes right after fitting (it actually trains it) the model with this error : Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.

    I'm running tensorflow 2.1.0, pyspark 3.0.2, jdk-8u281 and python 3.7 and elephas 1.4.2

    opened by diogoribeiro09 13
  • Not working with regression problems

    Hello and thank you for this package!

    It seems that this package is not useful when the problem is regression rather than classification. After getting errors on my own data when using this package, I realized it always gives an error for regression problems. As an example, I slightly modified https://github.com/maxpumperla/elephas/blob/master/examples/ml_pipeline_otto.py to make it a regression problem from line 61 onwards:

    model = Sequential()
    model.add(Dense(512, input_shape=(input_dim,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(512))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('linear'))
    
    model.compile(optimizer='adam', loss='mean_absolute_error')
    
    sgd = optimizers.SGD(lr=0.01)
    sgd_conf = optimizers.serialize(sgd)
    
    # Initialize Elephas Spark ML Estimator
    estimator = ElephasEstimator()
    estimator.set_keras_model_config(model.to_yaml())
    estimator.set_optimizer_config(sgd_conf)
    estimator.set_mode("synchronous")
    estimator.set_loss("mean_absolute_error")
    estimator.set_metrics(['mae'])
    estimator.setFeaturesCol("scaled_features")
    estimator.setLabelCol("index_category")
    estimator.set_epochs(10)
    estimator.set_batch_size(128)
    estimator.set_num_workers(1)
    estimator.set_verbosity(0)
    estimator.set_validation_split(0.15)
    estimator.set_categorical_labels(False)
    # estimator.set_nb_classes(nb_classes)
    
    # Fitting a model returns a Transformer
    pipeline = Pipeline(stages=[string_indexer, scaler, estimator])
    fitted_pipeline = pipeline.fit(train_df)
    
    # Evaluate Spark model
    prediction = fitted_pipeline.transform(train_df)
    pnl = prediction.select("index_category", "prediction")
    pnl.show(2)
    

    Unfortunately, it gives this error:

    19/04/10 05:32:34 WARN TaskSetManager: Lost task 0.0 in stage 459.0 (TID 6450, local[0], executor 103): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/usr/lib/spark/python/pyspark/worker.py", line 372, in main
        process()
      File "/usr/lib/spark/python/pyspark/worker.py", line 367, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/usr/lib/spark/python/pyspark/serializers.py", line 390, in dump_stream
        vs = list(itertools.islice(iterator, batch))
      File "/usr/lib/spark/python/pyspark/util.py", line 99, in wrapper
        return f(*args, **kwargs)
      File "/usr/lib/spark/python/pyspark/sql/session.py", line 730, in prepare
        verify_func(obj)
      File "/usr/lib/spark/python/pyspark/sql/types.py", line 1389, in verify
        verify_value(obj)
      File "/usr/lib/spark/python/pyspark/sql/types.py", line 1368, in verify_struct
        "length of fields (%d)" % (len(obj), len(verifiers))))
    ValueError: Length of object (7) does not match with length of fields (5)
    
            at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:452)
            at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:588)
            at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:571)
            at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:406)
            at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
            at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
            at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
            at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
            at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
            at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
            at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$11$$anon$1.hasNext(WholeStageCodegenExec.scala:619)
            at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:255)
            at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
            at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
            at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$24.apply(RDD.scala:836)
            at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
            at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
            at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
            at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
            at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
            at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
            at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
            at org.apache.spark.scheduler.Task.run(Task.scala:121)
            at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:402)
            at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
            at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:408)
            at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
            at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
            at java.lang.Thread.run(Thread.java:748)
    
    Driver stacktrace:
            at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:2039)
            at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2027)
            at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:2026)
            at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
            at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
            at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2026)
            at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
            at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:966)
            at scala.Option.foreach(Option.scala:257)
            at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:966)
            at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2260)
            at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2209)
            at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2198)
            at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
            at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:777)
            at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
            at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
            at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
            at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
            at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
            at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
            at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
            at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
            at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
            at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
            at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
            at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
            at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
            at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
            at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
            at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
            at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
            at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
            at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
            at java.lang.reflect.Method.invoke(Method.java:498)
            at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
            at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
            at py4j.Gateway.invoke(Gateway.java:282)
            at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
            at py4j.commands.CallCommand.execute(CallCommand.java:79)
            at py4j.GatewayConnection.run(GatewayConnection.java:238)
            at java.lang.Thread.run(Thread.java:748)
    Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/usr/lib/spark/python/pyspark/worker.py", line 372, in main
        process()
      File "/usr/lib/spark/python/pyspark/worker.py", line 367, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/usr/lib/spark/python/pyspark/serializers.py", line 390, in dump_stream
        vs = list(itertools.islice(iterator, batch))
      File "/usr/lib/spark/python/pyspark/util.py", line 99, in wrapper
        return f(*args, **kwargs)
      File "/usr/lib/spark/python/pyspark/sql/session.py", line 730, in prepare
        verify_func(obj)
      File "/usr/lib/spark/python/pyspark/sql/types.py", line 1389, in verify
        verify_value(obj)
      File "/usr/lib/spark/python/pyspark/sql/types.py", line 1368, in verify_struct
        "length of fields (%d)" % (len(obj), len(verifiers))))
    ValueError: Length of object (7) does not match with length of fields (5)
    
    opened by mostafam 13
  • Integrate callback functionality into elephas (History, checkpoints, etc.)

    https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L341

    The fit function in Keras returns a graph that can be used to determine if the model is overfitting or not. This would be very useful from Elephas.

    opened by sd12832 13
  • pickle.PicklingError: Could not serialize object: TypeError: can't pickle _thread.lock objects

    I get the error pickle.PicklingError: Could not serialize object: TypeError: can't pickle _thread.lock objects when I run it on a Spark cluster. How can I solve it?

    opened by gzgywh 12
  • Custom layers

    It seems that custom layers are not recognized in the train phase of spark model, specifically when loaded from yaml (model_from_yaml). I have tried with a custom activation function and it didn't work, the get_from_module function raised an exception "Invalid activation function: ...".

    opened by lenlen 12
  • TypeError: can't pickle _thread.lock objects

    Similar to the issue here but different I think: https://github.com/maxpumperla/elephas/issues/82

    This concerns the MNIST example, not the Jupyter example, and occurs with the code below:

    
    import sys
    import os
    os.environ['JAVA_HOME'] = os.getenv("JAVA_HOME")
    print(os.getenv("JAVA_HOME"))
    
    import findspark
    findspark.init()
    
    #import pyspark
    from pyspark import SparkContext, SparkConf
    from pyspark.sql import SQLContext
    from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, StandardScaler, VectorAssembler
    from pyspark.ml import Pipeline
    from pyspark.sql.functions import rand
    from pyspark.mllib.evaluation import MulticlassMetrics
    
    # Keras / Deep Learning
    from keras.models import Sequential
    from keras.layers.core import Dense, Dropout, Activation
    from keras import optimizers, regularizers
    from keras.optimizers import Adam
    
    
    if __name__ == '__main__':
    
    	from pyspark import SparkContext, SparkConf
    	conf = SparkConf().setAppName('MNIST').setMaster('local[8]')
    
    
    	from keras.datasets import mnist
    	from keras.utils import np_utils
    
    	(x_train, y_train), (x_test, y_test) = mnist.load_data()
    	x_train = x_train.reshape(60000, 784)
    	x_test = x_test.reshape(10000, 784)
    	x_train = x_train.astype('float32')
    	x_test = x_test.astype('float32')
    	x_train /= 255
    	x_test /= 255
    	nb_classes = 10
    	y_train = np_utils.to_categorical(y_train, nb_classes)
    
    
    	from keras.models import Sequential
    	from keras.layers.core import Dense, Dropout, Activation
    	from keras.optimizers import SGD
    
    	model = Sequential()
    	model.add(Dense(128, input_dim=784))
    	model.add(Activation('relu'))
    	model.add(Dropout(0.2))
    	model.add(Dense(128))
    	model.add(Activation('relu'))
    	model.add(Dropout(0.2))
    	model.add(Dense(10))
    	model.add(Activation('softmax'))
    	model.compile(loss='categorical_crossentropy', optimizer=SGD())
    
    	from elephas.utils.rdd_utils import to_simple_rdd
    	rdd = to_simple_rdd(sc, x_train, y_train)
    
    	from elephas.spark_model import SparkModel
    
    	spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')
    	spark_model.fit(rdd, epochs=20, batch_size=32, verbose=0, validation_split=0.1)
    
    

    I get an analogous error:

    
    C:\Java\Java8
    Using TensorFlow backend.
    19/10/09 15:27:04 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    WARNING
    2019-10-09 15:27:12.161649: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
    >>> Fit model
    Traceback (most recent call last):
      File "pyspark.py", line 105, in <module>
        spark_model.fit(rdd, epochs=20, batch_size=32, verbose=0, validation_split=0.1)
      File "C:\Users\\Anaconda3\lib\site-packages\elephas\spark_model.py", line 151, in fit
        self._fit(rdd, epochs, batch_size, verbose, validation_split)
      File "C:\Users\\Anaconda3\lib\site-packages\elephas\spark_model.py", line 163, in _fit
        self.start_server()
      File "C:\Users\\Anaconda3\lib\site-packages\elephas\spark_model.py", line 118, in start_server
        self.parameter_server.start()
      File "C:\Users\\Anaconda3\lib\site-packages\elephas\parameter\server.py", line 85, in start
        self.server.start()
      File "C:\Users\\Anaconda3\lib\multiprocessing\process.py", line 105, in start
        self._popen = self._Popen(self)
      File "C:\Users\\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
        return _default_context.get_context().Process._Popen(process_obj)
      File "C:\Users\\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
        return Popen(process_obj)
      File "C:\Users\\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 65, in __init__
        reduction.dump(process_obj, to_child)
      File "C:\Users\\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
        ForkingPickler(file, protocol).dump(obj)
    TypeError: can't pickle _thread.lock objects
    C:\Java\Java8
    SUCCESS: The process with PID 8104 (child process of PID 11096) has been terminated.
    SUCCESS: The process with PID 11096 (child process of PID 14268) has been terminated.
    SUCCESS: The process with PID 14268 (child process of PID 3964) has been terminated.
    

    And afterwards:

    
    Traceback (most recent call last):
      File "<string>", line 1, in <module>
      File "C:\Users\\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
        exitcode = _main(fd)
      File "C:\Users\\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
        self = reduction.pickle.load(from_parent)
    EOFError: Ran out of input
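
The `TypeError` in the first traceback can be reproduced without Elephas or Spark. On Windows, `multiprocessing` starts child processes with the spawn method, which pickles the process object together with everything it references, and a `threading.Lock` cannot be pickled. The following is a minimal sketch of that mechanism only; the `Holder` class is hypothetical and merely stands in for whatever object reachable from the parameter server process keeps a lock.

```python
import pickle
import threading


class Holder:
    """Hypothetical stand-in for an object that keeps a lock as an attribute."""

    def __init__(self):
        self.lock = threading.Lock()
        self.weights = [0.1, 0.2, 0.3]


def try_pickle(obj):
    """Return None if obj pickles cleanly, otherwise the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


# The exact wording of the message differs between Python versions.
error = try_pickle(Holder())
print(error)
```

The second traceback (`EOFError: Ran out of input`) is what the child process reports when the parent aborts while writing the pickled data.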
    
    opened by datistiquo 11
  • Asynchronous mode not working

    Asynchronous mode not working

    I am trying to run a Keras autoencoder model in asynchronous mode, and it gives the following error.

    Code:

    # Create Spark Model
    spark_model = SparkModel(model, mode='asynchronous', frequency='epoch')

    # Fit the Spark model
    start = time.time()
    spark_model.fit(train_rdd, epochs=hyperparameters['epochs'], batch_size=1, verbose=2, validation_split=0.1)
    print('Time Taken : ', time.time() - start)

    Error:

     * Serving Flask app "elephas.parameter.server" (lazy loading)
     * Environment: production
       WARNING: Do not use the development server in a production environment.
       Use a production WSGI server instead.
     * Debug mode: on
     * Running on http://0.0.0.0:4000/ (Press CTRL+C to quit)
     * Restarting with stat

    Py4JJavaError                             Traceback (most recent call last)
    in ()
          4 # Fit the Spark model
          5 start= time.time()
    ----> 6 spark_model.fit(train_rdd, epochs=hyperparameters['epochs'], batch_size=1, verbose=2, validation_split=0.1)
          7 print('Time Taken : ', time.time() - start)

    ~/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/spark_model.py in fit(self, rdd, epochs, batch_size, verbose, validation_split)
        143
        144 if self.mode in ['asynchronous', 'synchronous', 'hogwild']:
    --> 145     self._fit(rdd, epochs, batch_size, verbose, validation_split)
        146 else:
        147     raise ValueError("Choose from one of the modes: asynchronous, synchronous or hogwild")

    ~/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/spark_model.py in _fit(self, rdd, epochs, batch_size, verbose, validation_split)
        169 if self.mode in ['asynchronous', 'hogwild']:
        170     worker = AsynchronousSparkWorker(yaml, parameters, mode, train_config, freq, optimizer, loss, metrics, custom)
    --> 171     rdd.mapPartitions(worker.train).collect()
        172     new_parameters = self.client.get_parameters()
        173 elif self.mode == 'synchronous':

    ~/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/rdd.py in collect(self)
        832 """
        833 with SCCallSiteSync(self.context) as css:
    --> 834     sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
        835 return list(_load_from_socket(sock_info, self._jrdd_deserializer))
        836

    ~/anaconda3/envs/thesis/lib/python3.5/site-packages/py4j/java_gateway.py in __call__(self, *args)
       1255 answer = self.gateway_client.send_command(command)
       1256 return_value = get_return_value(
    -> 1257     answer, self.gateway_client, self.target_id, self.name)
       1258
       1259 for temp_arg in temp_args:

    ~/anaconda3/envs/thesis/lib/python3.5/site-packages/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
        326 raise Py4JJavaError(
        327     "An error occurred while calling {0}{1}{2}.\n".
    --> 328     format(target_id, ".", name), value)
        329 else:
        330     raise Py4JError(

    Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
    : org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1254, in do_open
        h.request(req.get_method(), req.selector, req.data, headers)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1107, in request
        self._send_request(method, url, body, headers)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1152, in _send_request
        self.endheaders(body)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1103, in endheaders
        self._send_output(message_body)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 934, in _send_output
        self.send(msg)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 877, in send
        self.connect()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 849, in connect
        (self.host,self.port), self.timeout, self.source_address)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/socket.py", line 712, in create_connection
        raise err
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/socket.py", line 703, in create_connection
        sock.connect(sa)
    ConnectionRefusedError: [Errno 111] Connection refused

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 230, in main
        process()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 225, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 372, in dump_stream
        vs = list(itertools.islice(iterator, batch))
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/worker.py", line 98, in train
        weights_before_training = self.client.get_parameters()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/parameter/client.py", line 55, in get_parameters
        pickled_weights = urllib2.urlopen(request).read()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 163, in urlopen
        return opener.open(url, data, timeout)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 466, in open
        response = self._open(req, data)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 484, in _open
        '_open', req)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 444, in _call_chain
        result = func(*args)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1282, in http_open
        return self.do_open(http.client.HTTPConnection, req)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1256, in do_open
        raise URLError(err)
    urllib.error.URLError: <urlopen error [Errno 111] Connection refused>

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:298)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:438)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:421)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:252)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$class.foreach(Iterator.scala:893)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
    at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
    at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
    at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
    

    Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1602)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1590)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1589)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1589)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1823)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1772)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1761)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1.apply(RDD.scala:939)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:938)
    at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:162)
    at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
    Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1254, in do_open
        h.request(req.get_method(), req.selector, req.data, headers)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1107, in request
        self._send_request(method, url, body, headers)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1152, in _send_request
        self.endheaders(body)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 1103, in endheaders
        self._send_output(message_body)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 934, in _send_output
        self.send(msg)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 877, in send
        self.connect()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/http/client.py", line 849, in connect
        (self.host,self.port), self.timeout, self.source_address)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/socket.py", line 712, in create_connection
        raise err
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/socket.py", line 703, in create_connection
        sock.connect(sa)
    ConnectionRefusedError: [Errno 111] Connection refused

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 230, in main
        process()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 225, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 372, in dump_stream
        vs = list(itertools.islice(iterator, batch))
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/worker.py", line 98, in train
        weights_before_training = self.client.get_parameters()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/site-packages/elephas/parameter/client.py", line 55, in get_parameters
        pickled_weights = urllib2.urlopen(request).read()
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 163, in urlopen
        return opener.open(url, data, timeout)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 466, in open
        response = self._open(req, data)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 484, in _open
        '_open', req)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 444, in _call_chain
        result = func(*args)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1282, in http_open
        return self.do_open(http.client.HTTPConnection, req)
      File "/home/dexter/anaconda3/envs/thesis/lib/python3.5/urllib/request.py", line 1256, in do_open
        raise URLError(err)
    urllib.error.URLError: <urlopen error [Errno 111] Connection refused>

    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:298)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:438)
    at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:421)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:252)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$class.foreach(Iterator.scala:893)
    at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
    at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
    at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
    at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
    at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
    at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
    at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
    at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
    at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
    at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
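
The root cause in the trace is `ConnectionRefusedError`: the worker could not open a TCP connection to the parameter server when it asked for the weights. A quick way to narrow this down is to check, from the machine that runs the workers, whether the server port accepts connections at all. The helper below is only a diagnostic sketch using the standard library; `wait_for_port` is a hypothetical name and not part of Elephas, and port 4000 is simply the port shown in the Flask log above.

```python
import socket
import time


def wait_for_port(host, port, timeout=10.0, interval=0.2):
    """Return True once a TCP connection to (host, port) succeeds, False on timeout."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # A successful connect means something is listening on the port.
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            # Covers refused connections and connect timeouts.
            if time.monotonic() >= deadline:
                return False
            time.sleep(interval)
```

Calling `wait_for_port("127.0.0.1", 4000)` right after the server is started shows whether it is reachable before any Spark task tries to contact it. If it returns False, the server process either did not start or is listening on a different interface or port.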
    
    opened by iam-armanahmed 11
  • java.lang.NullPointerException while showing predicted output.

    java.lang.NullPointerException while showing predicted output.

    Dear @maxpumperla / other authors of this repo. A big thanks for developing this library. I have been successful in running it on some datasets, but I am facing an issue with the current one. Please help me through it.

    Dataset: https://www.kaggle.com/datasets/janiobachmann/bank-marketing-dataset

    Dataset looks like this: df.show() image

    Schema: df.printSchema() [deposit being the target variable] image

    Dataset doesn't have any null values image

    After Converting categorical columns. image

    After Converting numerical columns via VectorAssembler -> StandardScaler. image

    I thought of converting the Vector created into individual columns hence exploded the Vector column. image

    Then I converted all the features into a single Vector to create 'features' column. image

    But you can see that some rows are SparseVector and some are DenseVector. Normally this does not affect PySpark's functionality, but since the issue I was facing was not getting resolved, I forcibly converted each SparseVector to a DenseVector. image
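
For readers unfamiliar with the two representations: a sparse vector stores only its size plus the indices and values of the non-zero entries, while a dense vector stores every entry. The conversion described above can be sketched in plain Python as follows. The function name `sparse_to_dense` is hypothetical; in PySpark itself the same result is usually obtained with `DenseVector(v.toArray())`.

```python
def sparse_to_dense(size, indices, values):
    """Expand a (size, indices, values) sparse vector into a full list of floats."""
    if len(indices) != len(values):
        raise ValueError("indices and values must have the same length")
    dense = [0.0] * size
    for index, value in zip(indices, values):
        dense[index] = float(value)
    return dense


# A vector of length 5 with non-zero entries at positions 1 and 3.
dense = sparse_to_dense(5, [1, 3], [2.0, 4.5])
print(dense)  # [0.0, 2.0, 0.0, 4.5, 0.0]
```

Both forms describe the same numbers, so the conversion changes memory layout only, not the values seen by the model.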

    With all this done, I converted the target variable ('deposit') via StringIndexer as well and took the features and labels columns to a separate df. image

    Keras Model: image

    Elephas estimator config: image

    Training via elephas and output df (pred_test). image

    Error Stack when running pred_test.collect() image

    Full Error Stack

    Py4JJavaError                             Traceback (most recent call last)
    Input In [2718], in <cell line: 1>()
    ----> 1 pred_test.collect()

    File /opt/spark/python/lib/pyspark.zip/pyspark/sql/dataframe.py:693, in DataFrame.collect(self)
        683 """Returns all the records as a list of :class:Row.
        684
        685 .. versionadded:: 1.3.0
       (...)
        690 [Row(age=2, name='Alice'), Row(age=5, name='Bob')]
        691 """
        692 with SCCallSiteSync(self._sc) as css:
    --> 693     sock_info = self._jdf.collectToPython()
        694 return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))

    File /opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/java_gateway.py:1321, in JavaMember.__call__(self, *args)
       1315 command = proto.CALL_COMMAND_NAME +\
       1316     self.command_header +\
       1317     args_command +\
       1318     proto.END_COMMAND_PART
       1320 answer = self.gateway_client.send_command(command)
    -> 1321 return_value = get_return_value(
       1322     answer, self.gateway_client, self.target_id, self.name)
       1324 for temp_arg in temp_args:
       1325     temp_arg._detach()

    File /opt/spark/python/lib/pyspark.zip/pyspark/sql/utils.py:111, in capture_sql_exception.<locals>.deco(*a, **kw)
        109 def deco(*a, **kw):
        110     try:
    --> 111         return f(*a, **kw)
        112     except py4j.protocol.Py4JJavaError as e:
        113         converted = convert_exception(e.java_exception)

    File /opt/spark/python/lib/py4j-0.10.9.3-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
        324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
        325 if answer[1] == REFERENCE_TYPE:
    --> 326     raise Py4JJavaError(
        327         "An error occurred while calling {0}{1}{2}.\n".
        328         format(target_id, ".", name), value)
        329 else:
        330     raise Py4JError(
        331         "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
        332         format(target_id, ".", name, value))

    Py4JJavaError: An error occurred while calling o17843.collectToPython.
    : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1258.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1258.0 (TID 985) (.internal executor driver): java.lang.RuntimeException: Error while decoding: java.lang.NullPointerException
    newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize
    at org.apache.spark.sql.errors.QueryExecutionErrors$.expressionDecodingError(QueryExecutionErrors.scala:1047)
    at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Deserializer.apply(ExpressionEncoder.scala:184)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$scalaConverter$2(ScalaUDF.scala:164)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$10(EvalPythonExec.scala:126)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$GroupedIterator.takeDestructively(Iterator.scala:1161)
    at scala.collection.Iterator$GroupedIterator.go(Iterator.scala:1176)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1214)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1217)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:307)
    at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:53)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:434)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2019)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:269)
    Caused by: java.lang.NullPointerException

    Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2454)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2403)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2402)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2402)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1160)
    at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1160)
    at scala.Option.foreach(Option.scala:407)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1160)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2642)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2584)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2573)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:938)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2214)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2235)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2254)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2279)
    at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1030)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:414)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:1029)
    at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:394)
    at org.apache.spark.sql.Dataset.$anonfun$collectToPython$1(Dataset.scala:3538)
    at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3706)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3704)
    at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:3535)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.base/java.lang.Thread.run(Thread.java:829)
    Caused by: java.lang.RuntimeException: Error while decoding: java.lang.NullPointerException
    newInstance(class org.apache.spark.ml.linalg.VectorUDT).deserialize
    at org.apache.spark.sql.errors.QueryExecutionErrors$.expressionDecodingError(QueryExecutionErrors.scala:1047)
    at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Deserializer.apply(ExpressionEncoder.scala:184)
    at org.apache.spark.sql.catalyst.expressions.ScalaUDF.$anonfun$scalaConverter$2(ScalaUDF.scala:164)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$10(EvalPythonExec.scala:126)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
    at scala.collection.Iterator$GroupedIterator.takeDestructively(Iterator.scala:1161)
    at scala.collection.Iterator$GroupedIterator.go(Iterator.scala:1176)
    at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1214)
    at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1217)
    at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
    at scala.collection.Iterator.foreach(Iterator.scala:943)
    at scala.collection.Iterator.foreach$(Iterator.scala:943)
    at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
    at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:307)
    at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:53)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:434)
    at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2019)
    at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:269)
    Caused by: java.lang.NullPointerException

    Please help me out with this. Thanks in advance!

    opened by aman1931998 6
  • Elephas installation Problem

    Elephas installation Problem

    Hello, I face a problem when installing Elephas with pip install elephas: pip downloads every pyspark version below 3.2 in turn. It says it is trying to find the best version that meets the requirements. It installed Keras and TensorFlow, but then the process kept running without making any progress. What should I do? I am in great need of this library. Note: the attached picture shows what happens; I tried many times, and a number of pyspark versions are now cached. What are the requirements before installing Elephas? I am a beginner, so I need details such as the Python version, Spark version, etc. Many thanks for your help. Sama sama

    opened by samasalam 7
  • CNN in elephas

    CNN in elephas

    Hi all. Can we write a CNN with Elephas? I mean, can we add 2D convolution layers to the Keras model and load 2D images? Can we also load pretrained models like VGG for transfer learning?

    opened by Asif-Ejaz 0
  • Load Dataset for Image Classification

    Load Dataset for Image Classification

    In the examples the mnist dataset from keras is used, but it is already loaded as numpy.ndarray. I would like to load my RGB image dataset into a Spark dataframe. In Pyspark there is the method:

    spark.read.format("image").option("dropInvalid", True).load(path)

    which allows you to load all the images contained in the path into a dataframe. In the Dataframe there is a row for each image, and each row contains the binary format of the corresponding image. You can convert the binary format to RGB matrices with numpy's methods, but how do you save a Tensor in each row, and then give the Dataframe as input to a convolutional network in Keras?

    Is there any other way to not provide 3 matrices (RGB) for each image, and just provide a large vector of pixels?

    opened by cicciolado 0
  • datasize < batch size?

    Hi @maxpumperla, would you mind explaining why you placed this if-statement there in the first place? (cc @danielenricocahall )

    https://github.com/maxpumperla/elephas/blob/master/elephas/worker.py#L107 https://github.com/maxpumperla/elephas/blob/master/elephas/worker.py#L116

    Does it matter if my batch size is bigger than the available training data?

    opened by OscarDPan 5
  • Object detection, segmentation using elephas

    Hi Elephas team, awesome project with great effort behind it. I managed to test it with some simple models for classification problems. I wonder whether there is any way to integrate Elephas with popular networks like YOLOv4, EfficientNet, etc. for object detection and segmentation problems. Could you point me to some documents or examples for them? Big thanks. I would also like to ask why there is a difference in accuracy between using RDDs and using DataFrames; a friend tested it and mentioned those differences to me. Thanks a lot.

    opened by nhan-dang 4
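
Regarding the "Load Dataset for Image Classification" question above, here is a minimal sketch of how the raw bytes from Spark's image data source might be turned into an RGB array or a single flat pixel vector. It assumes the documented image schema (height, width, nChannels, data), 8-bit images with 3 channels, and row-wise BGR byte order, which Spark describes as the usual case. The helper names spark_image_to_rgb and to_flat_vector are hypothetical and not part of Elephas or Spark.

```python
import numpy as np


def spark_image_to_rgb(height, width, n_channels, data):
    """Decode the fields of one Spark image struct into an RGB array in [0, 1].

    Assumption: 8-bit pixels stored row by row in BGR order (3 channels).
    """
    if n_channels != 3:
        raise ValueError("this sketch only handles 3-channel images")
    bgr = np.frombuffer(data, dtype=np.uint8).reshape(height, width, n_channels)
    rgb = bgr[:, :, ::-1]  # reverse the channel axis: BGR -> RGB
    return rgb.astype(np.float32) / 255.0


def to_flat_vector(image):
    """Flatten an (height, width, channels) array into one long pixel vector."""
    return image.reshape(-1)


# Tiny 1x2 example image: a pure-blue pixel, then a pure-red pixel (BGR bytes).
raw = bytes([255, 0, 0, 0, 0, 255])
image = spark_image_to_rgb(1, 2, 3, raw)
flat = to_flat_vector(image)
```

One possible design, not confirmed by the Elephas documentation, is to store the flat vector in a single DataFrame column (for example as a dense vector from pyspark.ml.linalg) and to restore the image shape inside the network with a Keras Reshape layer as the first layer.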
Releases(3.2.0)
  • 3.2.0(Sep 1, 2022)

  • 3.0.0(Aug 17, 2021)

    • Updated to support TensorFlow versions through the latest (2.6.0), which required converting YAML to JSON (https://github.com/tensorflow/tensorflow/releases/tag/v2.4.3).
    • Removed the hyperparameter optimization feature, as Hyperas is archived and was causing some compatibility issues.
  • 2.1.0(Apr 23, 2021)

  • 2.0.0(Apr 23, 2021)

  • 1.4.1(Feb 6, 2021)

    • Add support to predict probabilities in classification (https://github.com/maxpumperla/elephas/pull/177)
    • Fix gradient computation in synchronous mode (https://github.com/maxpumperla/elephas/pull/176)
    • Performance improvement for distributed predict (https://github.com/maxpumperla/elephas/pull/176)
  • 1.3.1(Jan 31, 2021)

  • 1.2.1(Jan 25, 2021)

  • 1.2.0(Jan 24, 2021)

  • 1.1.0(Jan 19, 2021)

    • Added support for parallel prediction/evaluation in SparkModel
    • Added backwards compatibility for ElephasEstimator with setFeaturesCol, setLabelCol, and setOutputCol, along with a deprecation warning
    • Fixed a bug with the socket server in hogwild mode (it no longer behaves identically to asynchronous mode)
    • Added type hints, as we're now only supporting Python 3
  • 1.0.0(Jan 14, 2021)

    • Dropped support for Python 2.7
    • Added support for Tensorflow 2.0.x - 2.1.x
    • Added/fixed capability to train with socket client/server

    Originally from: https://github.com/danielenricocahall/elephas/releases/tag/1.0.0

Owner
Max Pumperla
Data Science Professor, Data Scientist & Engineer. DL4J core developer, Hyperopt maintainer, Keras contributor. Author of "Deep Learning and the Game of Go"